import 'dart:math';

import 'package:collection/collection.dart';

/// Builds a nested list whose nesting depth and lengths follow [shape]
/// (outermost dimension first).
///
/// Innermost lists are filled with [fillValue], unless [generator] is given,
/// in which case the innermost element at index `i` is `generator(i)`; only
/// the innermost index is passed to [generator].
///
/// [shape] must be non-empty (`shape[0]` is read unconditionally).
List dimensionalList(List shape,
    {dynamic fillValue = 0.0, dynamic Function(int)? generator}) {
  if (shape.length == 1) {
    if (generator != null) {
      return List.generate(shape[0], generator);
    }
    // NOTE(review): List.filled shares a single fillValue instance across all
    // slots; only an issue if a mutable object is passed as fillValue.
    return List.filled(shape[0], fillValue);
  }
  return List.generate(shape[0], (int i) {
    return dimensionalList(shape.sublist(1),
        fillValue: fillValue, generator: generator);
  });
}

/// Returns the shape of a nested list by following the first element at
/// each level (the data is assumed to be rectangular).
///
/// An empty list yields `[0]`.
List detectShape(List data) {
  if (data.isEmpty || data[0] is! List) {
    return [data.length];
  }
  return [data.length, ...detectShape(data[0] as List)];
}

/// Applies [callback] to every leaf of the nested list [data] and returns a
/// new nested list with the same structure.
///
/// Whether a level is nested is decided by its first element, matching
/// [detectShape]. An `is List` test (rather than comparing `runtimeType`)
/// is used so that typed lists such as `List<int>` are recognised as nested.
List deepApply(List data, dynamic Function(dynamic) callback) {
  if (data.isEmpty) {
    return [];
  }
  if (data[0] is! List) {
    return data.map(callback).toList();
  }
  return data.map((d) => deepApply(d as List, callback)).toList();
}

/// Combines two nested lists element-wise with [callback].
///
/// Despite the name, no broadcasting is performed: both lists must have the
/// same shape as reported by [detectShape]. The check is done once, up front.
///
/// Throws an [Exception] when the shapes differ.
List listBroadcast(
    List l1, List l2, dynamic Function(dynamic, dynamic) callback) {
  if (!detectShape(l1).equals(detectShape(l2))) {
    throw Exception("l1 != l2");
  }
  return _zipWith(l1, l2, callback);
}

/// Recursive worker for [listBroadcast]; assumes the shapes already match.
List _zipWith(List l1, List l2, dynamic Function(dynamic, dynamic) callback) {
  if (l1.isEmpty || l1[0] is! List) {
    return List.generate(l1.length, (int i) => callback(l1[i], l2[i]));
  }
  return List.generate(
      l1.length, (int i) => _zipWith(l1[i] as List, l2[i] as List, callback));
}

/// A minimal n-dimensional array backed by nested Dart lists.
class Tensor {
  /// Nested list holding the elements; its nesting is expected to match
  /// [shape].
  List data;

  /// Length of each dimension, outermost first.
  List shape;

  Tensor(this.shape, this.data);

  /// A tensor of the given [shape] filled with `0.0`.
  factory Tensor.fromShape(List shape) {
    return Tensor(shape, dimensionalList(shape));
  }

  /// Wraps an existing nested list, inferring its shape via [detectShape].
  factory Tensor.fromList(List data) {
    return Tensor(detectShape(data), data);
  }

  /// A tensor of the given [shape] whose innermost element at index `i` is
  /// `generator(i)`; falls back to `0.0` when [generator] is null.
  factory Tensor.generate(List shape, dynamic Function(int)? generator) {
    return Tensor(
        shape, dimensionalList(shape, fillValue: 0.0, generator: generator));
  }

  /// A tensor of the given [shape] filled with `Random.nextDouble()` values.
  ///
  /// Pass [random] (e.g. a seeded `Random(42)`) for reproducible output; a
  /// fresh unseeded [Random] is used otherwise.
  factory Tensor.random(List shape, {Random? random}) {
    final Random r = random ?? Random();
    return Tensor.generate(shape, (int _) => r.nextDouble());
  }

  /// Returns a new tensor with [callback] applied to every element.
  Tensor each(dynamic Function(dynamic) callback) {
    return Tensor.fromList(deepApply(data, callback));
  }

  /// Whether [other] has the same shape and element-wise equal data.
  ///
  /// Nested rows are compared by content, not identity.
  bool equals(Tensor other) {
    return shape.equals(other.shape) &&
        const DeepCollectionEquality().equals(data, other.data);
  }

  /// Stacks [tensors] along a new dimension [dim], mirroring `torch.stack`.
  ///
  /// All tensors must share the same shape `S`. The result has shape
  /// `S[0..dim) + [tensors.length] + S[dim..)`, and its slice at index `k`
  /// along [dim] equals `tensors[k]`. The result's data is a fresh copy.
  ///
  /// [dim] may range over `-(S.length + 1) .. S.length`; negative values
  /// count from the end as in PyTorch (`-1` inserts the new axis last).
  /// Any other value falls back to 0, as the original implementation did.
  ///
  /// Throws [ArgumentError] if [tensors] is empty or the shapes differ.
  static Tensor stack(List tensors, [int dim = 0]) {
    if (tensors.isEmpty) {
      throw ArgumentError.value(tensors, 'tensors', 'must not be empty');
    }
    final List<Tensor> items = tensors.cast<Tensor>();
    final List baseShape = items[0].shape;
    for (final Tensor t in items) {
      if (!t.shape.equals(baseShape)) {
        throw ArgumentError(
            'stack expects tensors of equal shape, got ${t.shape} and $baseShape');
      }
    }

    final int rank = baseShape.length;
    int axis = dim < 0 ? dim + rank + 1 : dim;
    if (axis < 0 || axis > rank) {
      axis = 0;
    }

    final List newShape = [
      ...baseShape.sublist(0, axis),
      items.length,
      ...baseShape.sublist(axis),
    ];
    final List stackedData =
        _stackData(items.map((Tensor t) => t.data).toList(), axis);
    return Tensor(newShape, stackedData);
  }

  /// Interleaves [parts] (nested lists of identical shape) so that the part
  /// index becomes axis [axis] of the result.
  static List _stackData(List parts, int axis) {
    if (axis == 0) {
      // Copy nested parts so the stacked tensor does not alias its inputs.
      return parts
          .map((p) => p is List ? deepApply(p, (x) => x) : p)
          .toList();
    }
    final int length = (parts[0] as List).length;
    return List.generate(length, (int i) {
      return _stackData([for (final p in parts) (p as List)[i]], axis - 1);
    });
  }

  /// Element-wise product; both tensors must have the same shape.
  ///
  /// Throws an [Exception] when the shapes differ.
  Tensor operator *(Tensor other) {
    return Tensor.fromList(
        listBroadcast(data, other.data, (d1, d2) => d1 * d2));
  }
}