completed and tested broadcast.

This commit is contained in:
Jordan 2024-12-14 06:13:00 -08:00
parent 6fcff57d68
commit dffe5cee98
2 changed files with 12 additions and 6 deletions

View File

@ -17,7 +17,7 @@ List<dynamic> dimensionalList(List<int> shape,
}
List<int> detectShape(List<dynamic> data) {
if (data.runtimeType != List) {
if (data[0] is! List) {
return [data.length];
}
return [data.length] + detectShape(data[0]);
@ -35,10 +35,12 @@ List<dynamic> deepApply(
}).toList();
}
List<dynamic> listBroadcast(
dynamic l1, dynamic l2, dynamic Function(dynamic, dynamic) callback) {
if (!(l1.runtimeType == List && l2.runtimeType == List)) {
return callback(l1, l2);
List<dynamic> listBroadcast(List<dynamic> l1, List<dynamic> l2,
dynamic Function(dynamic, dynamic) callback) {
if (l1[0] is! List<dynamic>) {
return List.generate(l1.length, (int i) {
return callback(l1[i], l2[i]);
});
}
if (!detectShape(l1).equals(detectShape(l2))) {
throw Exception("l1 != l2");
@ -78,6 +80,10 @@ class Tensor {
return Tensor.fromList(deepApply(data, callback));
}
bool equals(Tensor other) {
return data.equals(other.data);
}
operator *(Tensor other) {
return Tensor.fromList(listBroadcast(data, other.data, (d1, d2) {
return d1 * d2;

View File

@ -74,6 +74,6 @@ void main() {
[2, 4, 6],
[8, 10, 12]
]);
expect(t1 * t2, equals(expected));
expect((t1 * t2).data, equals(expected.data));
});
}