add each and broadcast methods.

This commit is contained in:
Jordan Hewitt 2024-12-13 17:18:12 -08:00
parent 859f540f4a
commit 6fcff57d68
2 changed files with 137 additions and 97 deletions

View File

@ -2,88 +2,85 @@ import 'dart:math';
import 'package:collection/collection.dart'; import 'package:collection/collection.dart';
class Tensor extends DelegatingList<DelegatingList<double>> { List<dynamic> dimensionalList(List<int> shape,
Tensor(super.base); {dynamic fillValue = 0.0, dynamic Function(int)? generator}) {
if (shape.length == 1) {
factory Tensor.fromList(List<List<double>> lst) { if (generator != null) {
return Tensor(DelegatingList(lst.map((e) => DelegatingList(e)).toList())); return List.generate(shape[0], generator);
}
Tensor each(Function(double, int, int) f) {
Tensor other = Tensor([]);
for (int j = 0; j < length; ++j) {
other[j] = const DelegatingList([]);
for (int k = 0; k < this[j].length; ++k) {
other[j][k] = f(this[j][k], j, k);
}
} }
return other; return List.filled(shape[0], fillValue);
}
return List.generate(shape[0], (int i) {
return dimensionalList(shape.sublist(1),
fillValue: fillValue, generator: generator);
});
}
List<int> detectShape(List<dynamic> data) {
if (data.runtimeType != List) {
return [data.length];
}
return [data.length] + detectShape(data[0]);
}
List<dynamic> deepApply(
List<dynamic> data, dynamic Function(dynamic) callback) {
if (data[0].runtimeType != List) {
return data.map((d) {
return callback(d);
}).toList();
}
return data.map((d) {
return deepApply(d, callback);
}).toList();
}
List<dynamic> listBroadcast(
dynamic l1, dynamic l2, dynamic Function(dynamic, dynamic) callback) {
if (!(l1.runtimeType == List && l2.runtimeType == List)) {
return callback(l1, l2);
}
if (!detectShape(l1).equals(detectShape(l2))) {
throw Exception("l1 != l2");
}
return List.generate(l1.length, (int i) {
return listBroadcast(l1[i], l2[i], callback);
});
}
class Tensor {
List<dynamic> data = [];
List<int> shape = [];
Tensor(this.shape, this.data);
factory Tensor.fromShape(List<int> shape) {
return Tensor(shape, dimensionalList(shape));
} }
/// Generate a random tensor of shape `shape` factory Tensor.fromList(List<dynamic> data) {
static Tensor random(List<int> shape) { return Tensor(detectShape(data), data);
}
factory Tensor.generate(List<int> shape, dynamic Function(int)? generator) {
return Tensor(
shape, dimensionalList(shape, fillValue: 0.0, generator: generator));
}
factory Tensor.random(List<int> shape) {
Random r = Random(); Random r = Random();
return Tensor.generate(shape, (int _) {
int d1 = 0, d2 = 0; return r.nextDouble();
if (shape.length == 1) { });
d1 = shape[0];
} else if (shape.length == 2) {
d1 = shape[0];
d2 = shape[1];
} else if (shape.length == 3) {
// d3 = shapes[2];
}
Tensor ret = Tensor(List.filled(d1, DelegatingList(List.filled(d2, 0.0))));
for (int i = 0; i < d1; ++i) {
for (int j = 0; j < d2; ++j) {
ret[i][j] = r.nextDouble();
}
}
return ret;
} }
static Tensor stack(Tensor tensors, {int axis = 0}) { Tensor each(dynamic Function(dynamic) callback) {
if (axis < -1 || axis > tensors.length - 1) { return Tensor.fromList(deepApply(data, callback));
throw ArgumentError('Invalid axis value');
}
int newAxisSize = tensors.length;
for (var tensor in tensors) {
newAxisSize *= tensor.length;
}
Tensor result = Tensor([]);
for (int i = 0; i < newAxisSize; i++) {
int index = i;
int currentAxisIndex = axis;
List<int> currentAxisIndexes = List.filled(tensors.length, -1);
int currentTensorIndex = 0;
while (currentAxisIndexes[currentTensorIndex] < tensors.length) {
if (currentAxisIndexes[currentTensorIndex] == currentAxisIndex) {
index = currentAxisIndexes[currentTensorIndex] +
(index ~/ tensors.length);
currentAxisIndexes[currentTensorIndex]++;
}
currentAxisIndex += (axis > 0 ? 1 : -1);
currentTensorIndex =
(currentAxisIndex < 0 || currentTensorIndex >= tensors.length)
? 0
: currentTensorIndex + 1;
}
result.add(tensors[
currentTensorIndex]); // Access tensors[currentTensorIndex] as a List<double> rather than using the index operator [] with it
}
return Tensor(result);
} }
operator *(Tensor other) { operator *(Tensor other) {
return each((d, i, j) { return Tensor.fromList(listBroadcast(data, other.data, (d1, d2) {
return [i][j] * other[i][j]; return d1 * d2;
}); }));
} }
} }

View File

@ -1,36 +1,79 @@
import 'dart:math';
import 'package:archimedes_test/src/splat/tensor.dart'; import 'package:archimedes_test/src/splat/tensor.dart';
import 'package:collection/collection.dart'; import 'package:collection/collection.dart';
import 'package:test/test.dart'; import 'package:test/test.dart';
void main() { void main() {
group('stack', () { group("tensor", () {
test('A random tensor can be generated', () { test("Tensors can be constructed from a shape", () {
Tensor generated = Tensor.random([2, 4]); Tensor t1 = Tensor.fromShape([2]);
expect(generated.length, equals(2));
expect(generated[0].length, equals(4));
});
test("A tensor can be stacked", () { expect(t1.data, equals([0, 0]));
Tensor x = Tensor(const DelegatingList([
DelegatingList([0.3367, 0.1288, 0.2345]), Tensor t2 = Tensor.fromShape([2, 3, 4]);
DelegatingList([0.2303, -1.1229, -0.1863])
]));
expect(Tensor.stack(x)[0], equals(x));
expect( expect(
Tensor.stack(x), t2.data,
equals(Tensor.fromList([ equals([
[ [
[ [0.0, 0.0, 0.0, 0.0],
[0.3367, 0.1288, 0.2345], [0.0, 0.0, 0.0, 0.0],
[0.3367, 0.1288, 0.2345], [0.0, 0.0, 0.0, 0.0]
], ],
[ [
[0.2303, -1.1229, -0.1863], [0.0, 0.0, 0.0, 0.0],
[0.2303, -1.1229, -0.1863] [0.0, 0.0, 0.0, 0.0],
] [0.0, 0.0, 0.0, 0.0]
] ]
]))); ]));
}); });
}); });
test("Tensors can be constructed from a list", () {
Tensor t1 = Tensor.fromList([
[1.0, 2.0, 3.0],
[1.0, 2.0, 3.0]
]);
expect(t1.shape, equals([2, 3]));
});
test("Tensor can be generated", () {
Random r = Random();
Tensor t = Tensor.generate([1, 5], (int i) {
return r.nextDouble();
});
expect(t.data, isNotNull);
});
test("Function can be applied", () {
Random r = Random();
Tensor t = Tensor.generate([1, 2], (int i) {
return r.nextDouble();
});
Tensor t2 = t.each((dynamic d) {
return d * 100.0;
});
for (int i = 0; i < t2.data.length; ++i) {
expect(0 < t2.data[0][i] && t2.data[0][i] < 100, isTrue);
}
});
test("Can broadcast 2 tensors", () {
Tensor t1 = Tensor.fromList([
[1, 2, 3],
[4, 5, 6]
]);
Tensor t2 = Tensor.fromList([
[2, 2, 2],
[2, 2, 2]
]);
Tensor expected = Tensor.fromList([
[2, 4, 6],
[8, 10, 12]
]);
expect(t1 * t2, equals(expected));
});
} }