add each and broadcast methods.
This commit is contained in:
parent
859f540f4a
commit
6fcff57d68
@ -2,88 +2,85 @@ import 'dart:math';
|
||||
|
||||
import 'package:collection/collection.dart';
|
||||
|
||||
class Tensor extends DelegatingList<DelegatingList<double>> {
|
||||
Tensor(super.base);
|
||||
|
||||
factory Tensor.fromList(List<List<double>> lst) {
|
||||
return Tensor(DelegatingList(lst.map((e) => DelegatingList(e)).toList()));
|
||||
}
|
||||
|
||||
Tensor each(Function(double, int, int) f) {
|
||||
Tensor other = Tensor([]);
|
||||
for (int j = 0; j < length; ++j) {
|
||||
other[j] = const DelegatingList([]);
|
||||
for (int k = 0; k < this[j].length; ++k) {
|
||||
other[j][k] = f(this[j][k], j, k);
|
||||
}
|
||||
}
|
||||
return other;
|
||||
}
|
||||
|
||||
/// Generate a random tensor of shape `shape`
|
||||
static Tensor random(List<int> shape) {
|
||||
Random r = Random();
|
||||
|
||||
int d1 = 0, d2 = 0;
|
||||
List<dynamic> dimensionalList(List<int> shape,
|
||||
{dynamic fillValue = 0.0, dynamic Function(int)? generator}) {
|
||||
if (shape.length == 1) {
|
||||
d1 = shape[0];
|
||||
} else if (shape.length == 2) {
|
||||
d1 = shape[0];
|
||||
d2 = shape[1];
|
||||
} else if (shape.length == 3) {
|
||||
// d3 = shapes[2];
|
||||
if (generator != null) {
|
||||
return List.generate(shape[0], generator);
|
||||
}
|
||||
Tensor ret = Tensor(List.filled(d1, DelegatingList(List.filled(d2, 0.0))));
|
||||
return List.filled(shape[0], fillValue);
|
||||
}
|
||||
return List.generate(shape[0], (int i) {
|
||||
return dimensionalList(shape.sublist(1),
|
||||
fillValue: fillValue, generator: generator);
|
||||
});
|
||||
}
|
||||
|
||||
for (int i = 0; i < d1; ++i) {
|
||||
for (int j = 0; j < d2; ++j) {
|
||||
ret[i][j] = r.nextDouble();
|
||||
List<int> detectShape(List<dynamic> data) {
|
||||
if (data.runtimeType != List) {
|
||||
return [data.length];
|
||||
}
|
||||
return [data.length] + detectShape(data[0]);
|
||||
}
|
||||
|
||||
List<dynamic> deepApply(
|
||||
List<dynamic> data, dynamic Function(dynamic) callback) {
|
||||
if (data[0].runtimeType != List) {
|
||||
return data.map((d) {
|
||||
return callback(d);
|
||||
}).toList();
|
||||
}
|
||||
return data.map((d) {
|
||||
return deepApply(d, callback);
|
||||
}).toList();
|
||||
}
|
||||
|
||||
List<dynamic> listBroadcast(
|
||||
dynamic l1, dynamic l2, dynamic Function(dynamic, dynamic) callback) {
|
||||
if (!(l1.runtimeType == List && l2.runtimeType == List)) {
|
||||
return callback(l1, l2);
|
||||
}
|
||||
if (!detectShape(l1).equals(detectShape(l2))) {
|
||||
throw Exception("l1 != l2");
|
||||
}
|
||||
return List.generate(l1.length, (int i) {
|
||||
return listBroadcast(l1[i], l2[i], callback);
|
||||
});
|
||||
}
|
||||
|
||||
class Tensor {
|
||||
List<dynamic> data = [];
|
||||
List<int> shape = [];
|
||||
|
||||
Tensor(this.shape, this.data);
|
||||
|
||||
factory Tensor.fromShape(List<int> shape) {
|
||||
return Tensor(shape, dimensionalList(shape));
|
||||
}
|
||||
|
||||
return ret;
|
||||
factory Tensor.fromList(List<dynamic> data) {
|
||||
return Tensor(detectShape(data), data);
|
||||
}
|
||||
|
||||
static Tensor stack(Tensor tensors, {int axis = 0}) {
|
||||
if (axis < -1 || axis > tensors.length - 1) {
|
||||
throw ArgumentError('Invalid axis value');
|
||||
factory Tensor.generate(List<int> shape, dynamic Function(int)? generator) {
|
||||
return Tensor(
|
||||
shape, dimensionalList(shape, fillValue: 0.0, generator: generator));
|
||||
}
|
||||
|
||||
int newAxisSize = tensors.length;
|
||||
for (var tensor in tensors) {
|
||||
newAxisSize *= tensor.length;
|
||||
factory Tensor.random(List<int> shape) {
|
||||
Random r = Random();
|
||||
return Tensor.generate(shape, (int _) {
|
||||
return r.nextDouble();
|
||||
});
|
||||
}
|
||||
|
||||
Tensor result = Tensor([]);
|
||||
for (int i = 0; i < newAxisSize; i++) {
|
||||
int index = i;
|
||||
int currentAxisIndex = axis;
|
||||
List<int> currentAxisIndexes = List.filled(tensors.length, -1);
|
||||
int currentTensorIndex = 0;
|
||||
|
||||
while (currentAxisIndexes[currentTensorIndex] < tensors.length) {
|
||||
if (currentAxisIndexes[currentTensorIndex] == currentAxisIndex) {
|
||||
index = currentAxisIndexes[currentTensorIndex] +
|
||||
(index ~/ tensors.length);
|
||||
currentAxisIndexes[currentTensorIndex]++;
|
||||
}
|
||||
currentAxisIndex += (axis > 0 ? 1 : -1);
|
||||
currentTensorIndex =
|
||||
(currentAxisIndex < 0 || currentTensorIndex >= tensors.length)
|
||||
? 0
|
||||
: currentTensorIndex + 1;
|
||||
}
|
||||
|
||||
result.add(tensors[
|
||||
currentTensorIndex]); // Access tensors[currentTensorIndex] as a List<double> rather than using the index operator [] with it
|
||||
}
|
||||
|
||||
return Tensor(result);
|
||||
Tensor each(dynamic Function(dynamic) callback) {
|
||||
return Tensor.fromList(deepApply(data, callback));
|
||||
}
|
||||
|
||||
operator *(Tensor other) {
|
||||
return each((d, i, j) {
|
||||
return [i][j] * other[i][j];
|
||||
});
|
||||
return Tensor.fromList(listBroadcast(data, other.data, (d1, d2) {
|
||||
return d1 * d2;
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
@ -1,36 +1,79 @@
|
||||
import 'dart:math';
|
||||
|
||||
import 'package:archimedes_test/src/splat/tensor.dart';
|
||||
import 'package:collection/collection.dart';
|
||||
import 'package:test/test.dart';
|
||||
|
||||
void main() {
|
||||
group('stack', () {
|
||||
test('A random tensor can be generated', () {
|
||||
Tensor generated = Tensor.random([2, 4]);
|
||||
expect(generated.length, equals(2));
|
||||
expect(generated[0].length, equals(4));
|
||||
});
|
||||
group("tensor", () {
|
||||
test("Tensors can be constructed from a shape", () {
|
||||
Tensor t1 = Tensor.fromShape([2]);
|
||||
|
||||
test("A tensor can be stacked", () {
|
||||
Tensor x = Tensor(const DelegatingList([
|
||||
DelegatingList([0.3367, 0.1288, 0.2345]),
|
||||
DelegatingList([0.2303, -1.1229, -0.1863])
|
||||
]));
|
||||
expect(t1.data, equals([0, 0]));
|
||||
|
||||
Tensor t2 = Tensor.fromShape([2, 3, 4]);
|
||||
|
||||
expect(Tensor.stack(x)[0], equals(x));
|
||||
expect(
|
||||
Tensor.stack(x),
|
||||
equals(Tensor.fromList([
|
||||
t2.data,
|
||||
equals([
|
||||
[
|
||||
[
|
||||
[0.3367, 0.1288, 0.2345],
|
||||
[0.3367, 0.1288, 0.2345],
|
||||
[0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0]
|
||||
],
|
||||
[
|
||||
[0.2303, -1.1229, -0.1863],
|
||||
[0.2303, -1.1229, -0.1863]
|
||||
[0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0]
|
||||
]
|
||||
]
|
||||
])));
|
||||
]));
|
||||
});
|
||||
});
|
||||
test("Tensors can be constructed from a list", () {
|
||||
Tensor t1 = Tensor.fromList([
|
||||
[1.0, 2.0, 3.0],
|
||||
[1.0, 2.0, 3.0]
|
||||
]);
|
||||
|
||||
expect(t1.shape, equals([2, 3]));
|
||||
});
|
||||
|
||||
test("Tensor can be generated", () {
|
||||
Random r = Random();
|
||||
Tensor t = Tensor.generate([1, 5], (int i) {
|
||||
return r.nextDouble();
|
||||
});
|
||||
expect(t.data, isNotNull);
|
||||
});
|
||||
|
||||
test("Function can be applied", () {
|
||||
Random r = Random();
|
||||
Tensor t = Tensor.generate([1, 2], (int i) {
|
||||
return r.nextDouble();
|
||||
});
|
||||
|
||||
Tensor t2 = t.each((dynamic d) {
|
||||
return d * 100.0;
|
||||
});
|
||||
|
||||
for (int i = 0; i < t2.data.length; ++i) {
|
||||
expect(0 < t2.data[0][i] && t2.data[0][i] < 100, isTrue);
|
||||
}
|
||||
});
|
||||
|
||||
test("Can broadcast 2 tensors", () {
|
||||
Tensor t1 = Tensor.fromList([
|
||||
[1, 2, 3],
|
||||
[4, 5, 6]
|
||||
]);
|
||||
Tensor t2 = Tensor.fromList([
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]
|
||||
]);
|
||||
Tensor expected = Tensor.fromList([
|
||||
[2, 4, 6],
|
||||
[8, 10, 12]
|
||||
]);
|
||||
expect(t1 * t2, equals(expected));
|
||||
});
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user