diff --git a/lib/src/splat/tensor.dart b/lib/src/splat/tensor.dart index 790570b..46c1efb 100644 --- a/lib/src/splat/tensor.dart +++ b/lib/src/splat/tensor.dart @@ -17,7 +17,7 @@ List dimensionalList(List shape, } List detectShape(List data) { - if (data.runtimeType != List) { + if (data[0] is! List) { return [data.length]; } return [data.length] + detectShape(data[0]); @@ -35,10 +35,12 @@ List deepApply( }).toList(); } -List listBroadcast( - dynamic l1, dynamic l2, dynamic Function(dynamic, dynamic) callback) { - if (!(l1.runtimeType == List && l2.runtimeType == List)) { - return callback(l1, l2); +List listBroadcast(List l1, List l2, + dynamic Function(dynamic, dynamic) callback) { + if (l1[0] is! List) { + return List.generate(l1.length, (int i) { + return callback(l1[i], l2[i]); + }); } if (!detectShape(l1).equals(detectShape(l2))) { throw Exception("l1 != l2"); @@ -78,6 +80,10 @@ class Tensor { return Tensor.fromList(deepApply(data, callback)); } + bool equals(Tensor other) { + return data.equals(other.data); + } + operator *(Tensor other) { return Tensor.fromList(listBroadcast(data, other.data, (d1, d2) { return d1 * d2; diff --git a/lib/src/test/splat/tensor_test.dart b/lib/src/test/splat/tensor_test.dart index 52d8517..07974ba 100644 --- a/lib/src/test/splat/tensor_test.dart +++ b/lib/src/test/splat/tensor_test.dart @@ -74,6 +74,6 @@ void main() { [2, 4, 6], [8, 10, 12] ]); - expect(t1 * t2, equals(expected)); + expect((t1 * t2).data, equals(expected.data)); }); }