import 'dart:math';

import 'package:collection/collection.dart';

/// Builds a nested list with the given [shape]. Leaves are filled with
/// [fillValue], or with the result of [generator] when one is provided.
List<dynamic> dimensionalList(List<int> shape,
    {dynamic fillValue = 0.0, dynamic Function(int)? generator}) {
  if (shape.length == 1) {
    if (generator != null) {
      return List.generate(shape[0], generator);
    }
    return List.filled(shape[0], fillValue);
  }
  return List.generate(shape[0], (int i) {
    return dimensionalList(shape.sublist(1),
        fillValue: fillValue, generator: generator);
  });
}
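
// For example:
//   dimensionalList([2, 3])                       => [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
//   dimensionalList([3], generator: (i) => i * 2) => [0, 2, 4]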

/// Returns the shape of a regularly nested list, descending into the first
/// element of every nesting level.
List<int> detectShape(List<dynamic> data) {
  if (data[0] is! List) {
    return [data.length];
  }
  return [data.length] + detectShape(data[0]);
}
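
// For example:
//   detectShape([[1, 2, 3], [4, 5, 6]]) => [2, 3]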

/// Applies [callback] to every leaf element of a nested list, preserving the
/// nesting structure.
List<dynamic> deepApply(
    List<dynamic> data, dynamic Function(dynamic) callback) {
  if (data[0] is! List) {
    return data.map((d) {
      return callback(d);
    }).toList();
  }
  return data.map((d) {
    return deepApply(d, callback);
  }).toList();
}
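
// For example:
//   deepApply([[1, 2], [3, 4]], (d) => d * 10) => [[10, 20], [30, 40]]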

/// Combines two identically shaped nested lists element-wise using [callback].
List<dynamic> listBroadcast(List<dynamic> l1, List<dynamic> l2,
    dynamic Function(dynamic, dynamic) callback) {
  if (l1[0] is! List<dynamic>) {
    return List.generate(l1.length, (int i) {
      return callback(l1[i], l2[i]);
    });
  }
  if (!const ListEquality().equals(detectShape(l1), detectShape(l2))) {
    throw Exception("listBroadcast: shapes of l1 and l2 do not match");
  }
  return List.generate(l1.length, (int i) {
    return listBroadcast(l1[i], l2[i], callback);
  });
}
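
// For example:
//   listBroadcast([[1, 2], [3, 4]], [[10, 20], [30, 40]], (a, b) => a + b)
//     => [[11, 22], [33, 44]]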

/// A minimal n-dimensional tensor backed by nested lists.
class Tensor {
  List<dynamic> data = [];
  List<int> shape = [];

  Tensor(this.shape, this.data);

  /// Creates a zero-filled tensor of the given [shape].
  factory Tensor.fromShape(List<int> shape) {
    return Tensor(shape, dimensionalList(shape));
  }

  /// Creates a tensor from an existing nested list, detecting its shape.
  factory Tensor.fromList(List<dynamic> data) {
    return Tensor(detectShape(data), data);
  }

  /// Creates a tensor of the given [shape] whose leaves come from [generator].
  factory Tensor.generate(List<int> shape, dynamic Function(int)? generator) {
    return Tensor(
        shape, dimensionalList(shape, fillValue: 0.0, generator: generator));
  }

  /// Creates a tensor of the given [shape] filled with random doubles in [0, 1).
  factory Tensor.random(List<int> shape) {
    Random r = Random();
    return Tensor.generate(shape, (int _) {
      return r.nextDouble();
    });
  }

  /// Returns a new tensor with [callback] applied to every element.
  Tensor each(dynamic Function(dynamic) callback) {
    return Tensor.fromList(deepApply(data, callback));
  }

  /// Deep element-wise equality with [other].
  bool equals(Tensor other) {
    return const DeepCollectionEquality().equals(data, other.data);
  }

  /// Static method 'stack' that takes a list of Tensors and optionally an
  /// integer `dim` as arguments, implementing pytorch's stack method: the
  /// inputs are joined along a new dimension inserted at position [dim].
  static Tensor stack(List<Tensor> tensors, [int dim = 0]) {
    // If the provided dimension is less than 0 or not a valid dimension of
    // all input Tensors, use 0 as the default value for dim.
    dim = (dim < 0 || tensors.any((tensor) => tensor.shape.length <= dim))
        ? 0
        : dim;

    // Stacking along the first dimension simply collects the data of each
    // tensor into a new outer list.
    if (dim == 0) {
      return Tensor.fromList(tensors.map((t) => t.data).toList());
    }

    // For dim > 0, stack the i-th sub-lists of every tensor along dim - 1 and
    // collect the results.
    List<dynamic> stackedData = List.generate(tensors[0].shape[0], (int i) {
      return Tensor.stack(
              tensors.map((t) => Tensor.fromList(t.data[i])).toList(),
              dim - 1)
          .data;
    });

    return Tensor.fromList(stackedData);
  }
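
  // For example, stacking three tensors of shape [2, 4]:
  //   Tensor.stack(ts)    => shape [3, 2, 4]
  //   Tensor.stack(ts, 1) => shape [2, 3, 4]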

  /// Element-wise multiplication of two identically shaped tensors.
  Tensor operator *(Tensor other) {
    return Tensor.fromList(listBroadcast(data, other.data, (d1, d2) {
      return d1 * d2;
    }));
  }
}
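
// A small usage sketch (illustrative only, not part of the library above): it
// builds two tensors, multiplies them element-wise, applies a function to
// every element, stacks them along a new dimension, and compares tensors.
void main() {
  Tensor a = Tensor.fromList([
    [1.0, 2.0],
    [3.0, 4.0]
  ]);
  Tensor b = Tensor.fromList([
    [10.0, 20.0],
    [30.0, 40.0]
  ]);

  print(a.shape); // [2, 2]

  Tensor product = a * b;
  print(product.data); // [[10.0, 40.0], [90.0, 160.0]]

  Tensor doubled = a.each((d) => d * 2);
  print(doubled.data); // [[2.0, 4.0], [6.0, 8.0]]

  Tensor stacked = Tensor.stack([a, b], 1);
  print(stacked.shape); // [2, 2, 2]

  print(a.equals(Tensor.fromList([
    [1.0, 2.0],
    [3.0, 4.0]
  ]))); // true
}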