completed and tested broadcast.
This commit is contained in:
parent
6fcff57d68
commit
dffe5cee98
@ -17,7 +17,7 @@ List<dynamic> dimensionalList(List<int> shape,
|
||||
}
|
||||
|
||||
List<int> detectShape(List<dynamic> data) {
|
||||
if (data.runtimeType != List) {
|
||||
if (data[0] is! List) {
|
||||
return [data.length];
|
||||
}
|
||||
return [data.length] + detectShape(data[0]);
|
||||
@ -35,10 +35,12 @@ List<dynamic> deepApply(
|
||||
}).toList();
|
||||
}
|
||||
|
||||
List<dynamic> listBroadcast(
|
||||
dynamic l1, dynamic l2, dynamic Function(dynamic, dynamic) callback) {
|
||||
if (!(l1.runtimeType == List && l2.runtimeType == List)) {
|
||||
return callback(l1, l2);
|
||||
List<dynamic> listBroadcast(List<dynamic> l1, List<dynamic> l2,
|
||||
dynamic Function(dynamic, dynamic) callback) {
|
||||
if (l1[0] is! List<dynamic>) {
|
||||
return List.generate(l1.length, (int i) {
|
||||
return callback(l1[i], l2[i]);
|
||||
});
|
||||
}
|
||||
if (!detectShape(l1).equals(detectShape(l2))) {
|
||||
throw Exception("l1 != l2");
|
||||
@ -78,6 +80,10 @@ class Tensor {
|
||||
return Tensor.fromList(deepApply(data, callback));
|
||||
}
|
||||
|
||||
bool equals(Tensor other) {
|
||||
return data.equals(other.data);
|
||||
}
|
||||
|
||||
operator *(Tensor other) {
|
||||
return Tensor.fromList(listBroadcast(data, other.data, (d1, d2) {
|
||||
return d1 * d2;
|
||||
|
@ -74,6 +74,6 @@ void main() {
|
||||
[2, 4, 6],
|
||||
[8, 10, 12]
|
||||
]);
|
||||
expect(t1 * t2, equals(expected));
|
||||
expect((t1 * t2).data, equals(expected.data));
|
||||
});
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user