completed and tested broadcast.

This commit is contained in:
Jordan 2024-12-14 06:13:00 -08:00
parent 6fcff57d68
commit dffe5cee98
2 changed files with 12 additions and 6 deletions

View File

@ -17,7 +17,7 @@ List<dynamic> dimensionalList(List<int> shape,
} }
List<int> detectShape(List<dynamic> data) { List<int> detectShape(List<dynamic> data) {
if (data.runtimeType != List) { if (data[0] is! List) {
return [data.length]; return [data.length];
} }
return [data.length] + detectShape(data[0]); return [data.length] + detectShape(data[0]);
@ -35,10 +35,12 @@ List<dynamic> deepApply(
}).toList(); }).toList();
} }
List<dynamic> listBroadcast( List<dynamic> listBroadcast(List<dynamic> l1, List<dynamic> l2,
dynamic l1, dynamic l2, dynamic Function(dynamic, dynamic) callback) { dynamic Function(dynamic, dynamic) callback) {
if (!(l1.runtimeType == List && l2.runtimeType == List)) { if (l1[0] is! List<dynamic>) {
return callback(l1, l2); return List.generate(l1.length, (int i) {
return callback(l1[i], l2[i]);
});
} }
if (!detectShape(l1).equals(detectShape(l2))) { if (!detectShape(l1).equals(detectShape(l2))) {
throw Exception("l1 != l2"); throw Exception("l1 != l2");
@ -78,6 +80,10 @@ class Tensor {
return Tensor.fromList(deepApply(data, callback)); return Tensor.fromList(deepApply(data, callback));
} }
bool equals(Tensor other) {
return data.equals(other.data);
}
operator *(Tensor other) { operator *(Tensor other) {
return Tensor.fromList(listBroadcast(data, other.data, (d1, d2) { return Tensor.fromList(listBroadcast(data, other.data, (d1, d2) {
return d1 * d2; return d1 * d2;

View File

@ -74,6 +74,6 @@ void main() {
[2, 4, 6], [2, 4, 6],
[8, 10, 12] [8, 10, 12]
]); ]);
expect(t1 * t2, equals(expected)); expect((t1 * t2).data, equals(expected.data));
}); });
} }