2024-12-19 06:20:19 -08:00

137 lines
3.8 KiB
Dart

import 'dart:math';
import 'package:collection/collection.dart';
/// Builds a fresh nested list whose nesting is described by [shape].
///
/// Each innermost list has `shape.last` entries. When [generator] is given,
/// it supplies every innermost entry from that entry's index within its own
/// innermost list. Otherwise every entry is [fillValue]. Outer levels are
/// built with [List.generate], so sibling sub-lists are distinct objects.
List<dynamic> dimensionalList(List<int> shape,
    {dynamic fillValue = 0.0, dynamic Function(int)? generator}) {
  final count = shape[0];
  if (shape.length > 1) {
    // The remaining dimensions are the same for every outer index.
    final innerShape = shape.sublist(1);
    return List.generate(
      count,
      (_) => dimensionalList(innerShape,
          fillValue: fillValue, generator: generator),
    );
  }
  return generator == null
      ? List.filled(count, fillValue)
      : List.generate(count, generator);
}
/// Returns the shape of the nested list [data], outermost dimension first.
///
/// The shape is read along the first element of each level. Ragged input is
/// not validated. An empty level contributes a dimension of length 0 and ends
/// the descent, so `detectShape([])` is `[0]` and `detectShape([[], []])` is
/// `[2, 0]`. This matches PyTorch, where an empty tensor has shape `(0,)`.
List<int> detectShape(List<dynamic> data) {
  // Check for emptiness first: `data[0]` on an empty list throws RangeError.
  if (data.isEmpty || data[0] is! List) {
    return [data.length];
  }
  return [data.length, ...detectShape(data[0])];
}
/// Applies [callback] to every leaf of the nested list [data].
///
/// Returns a new nested list of the results with the same nesting. Nesting is
/// decided from the first element of each level. If that element is a [List],
/// every element of the level is recursed into. Otherwise every element of the
/// level is treated as a leaf. An empty list yields an empty list.
List<dynamic> deepApply(
    List<dynamic> data, dynamic Function(dynamic) callback) {
  if (data.isEmpty) return [];
  // Use an `is` test rather than comparing `runtimeType` with `List`. The type
  // literal `List` means `List<dynamic>`, so a `List<int>` or `List<double>`
  // sub-list would not compare equal and would be mistaken for a leaf.
  if (data[0] is! List) {
    return data.map(callback).toList();
  }
  return [for (final d in data) deepApply(d, callback)];
}
/// Combines two nested lists of identical shape element by element.
///
/// Returns a new nested list in which each leaf is `callback(a, b)` for the
/// corresponding leaves `a` of [l1] and `b` of [l2]. Despite the name, no
/// NumPy-style broadcasting is done: the shapes must match exactly.
///
/// Throws an [Exception] with message `l1 != l2` when the lengths or shapes
/// differ at any level, including the innermost one.
List<dynamic> listBroadcast(List<dynamic> l1, List<dynamic> l2,
    dynamic Function(dynamic, dynamic) callback) {
  // Compare lengths at every level. The innermost level used to index l2 with
  // l1's length, so a longer l2 was silently truncated and a shorter one threw
  // a RangeError.
  if (l1.length != l2.length) {
    throw Exception("l1 != l2");
  }
  if (l1.isEmpty) return [];
  // Pair leaves directly only when both sides are flat at this level. If only
  // one side is nested, the shape comparison below rejects it.
  if (l1[0] is! List && l2[0] is! List) {
    return List.generate(l1.length, (int i) => callback(l1[i], l2[i]));
  }
  if (!detectShape(l1).equals(detectShape(l2))) {
    throw Exception("l1 != l2");
  }
  return List.generate(
      l1.length, (int i) => listBroadcast(l1[i], l2[i], callback));
}
/// A minimal n-dimensional array backed by nested Dart lists.
///
/// [shape] lists the length of each dimension, outermost first. [data] holds
/// the values as nested lists with that shape. The unnamed constructor does
/// not cross-check the two.
class Tensor {
  /// The values, stored as nested lists matching [shape].
  List<dynamic> data;

  /// The length of each dimension, outermost first.
  List<int> shape;

  /// Creates a tensor from an explicit [shape] and matching nested [data].
  Tensor(this.shape, this.data);

  /// Creates a tensor of the given [shape] filled with `0.0`.
  factory Tensor.fromShape(List<int> shape) {
    return Tensor(shape, dimensionalList(shape));
  }

  /// Creates a tensor from nested [data], inferring its shape with
  /// [detectShape].
  factory Tensor.fromList(List<dynamic> data) {
    return Tensor(detectShape(data), data);
  }

  /// Creates a tensor of the given [shape] whose innermost entries come from
  /// [generator], which receives the index within the innermost list.
  ///
  /// Entries are `0.0` when [generator] is null.
  factory Tensor.generate(List<int> shape, dynamic Function(int)? generator) {
    return Tensor(
        shape, dimensionalList(shape, fillValue: 0.0, generator: generator));
  }

  /// Creates a tensor of the given [shape] filled with values from
  /// [Random.nextDouble].
  factory Tensor.random(List<int> shape) {
    final rng = Random();
    return Tensor.generate(shape, (int _) => rng.nextDouble());
  }

  /// Returns a new tensor with `callback(element)` for every element.
  Tensor each(dynamic Function(dynamic) callback) {
    return Tensor.fromList(deepApply(data, callback));
  }

  /// Whether [other] holds the same values as this tensor.
  ///
  /// Nested lists are compared element by element. A shallow comparison would
  /// compare the sub-lists of a rank >= 2 tensor by identity, so two equal
  /// tensors would be reported as different.
  bool equals(Tensor other) {
    return const DeepCollectionEquality().equals(data, other.data);
  }

  /// Joins equally shaped [tensors] along a new dimension inserted at index
  /// [dim], like PyTorch's `torch.stack`.
  ///
  /// The result has shape `shape[0..dim) + [tensors.length] + shape[dim..)`.
  /// A negative [dim] counts from the end as in PyTorch, so `-1` inserts the
  /// new dimension last. A [dim] that is still outside `0..rank` after that
  /// adjustment falls back to `0`. The result owns freshly copied nested
  /// lists and does not alias the inputs' data.
  ///
  /// Throws an [ArgumentError] if [tensors] is empty or if their shapes
  /// differ.
  static Tensor stack(List<Tensor> tensors, [int dim = 0]) {
    if (tensors.isEmpty) {
      throw ArgumentError.value(tensors, 'tensors', 'must not be empty');
    }
    final baseShape = tensors.first.shape;
    for (final tensor in tensors) {
      if (!tensor.shape.equals(baseShape)) {
        throw ArgumentError.value(
            tensors, 'tensors', 'all tensors must have the same shape');
      }
    }

    final rank = baseShape.length;
    // The new axis can go anywhere in 0..rank inclusive (rank = append last).
    var axis = dim < 0 ? dim + rank + 1 : dim;
    if (axis < 0 || axis > rank) axis = 0;

    final stackedShape = [
      ...baseShape.sublist(0, axis),
      tensors.length,
      ...baseShape.sublist(axis),
    ];
    // Build the result directly with its known shape rather than via
    // [Tensor.fromList], so shapes containing zero-length dimensions also work.
    return Tensor(
        stackedShape, _stackData([for (final t in tensors) t.data], axis));
  }

  /// Stacks the same-shaped nested lists [items] along a new axis at depth
  /// [axis].
  static List<dynamic> _stackData(List<dynamic> items, int axis) {
    if (axis == 0) {
      return [for (final item in items) _deepCopy(item)];
    }
    // Descend one level: result row i stacks row i of every item.
    final List<dynamic> firstItem = items.first;
    return [
      for (var i = 0; i < firstItem.length; i++)
        _stackData([for (final item in items) item[i]], axis - 1),
    ];
  }

  /// Returns [value] with every nested list duplicated. Non-list values are
  /// returned unchanged.
  static dynamic _deepCopy(dynamic value) =>
      value is List ? [for (final e in value) _deepCopy(e)] : value;

  /// Returns the element-wise product of this tensor and [other].
  ///
  /// Both tensors are expected to have the same shape. [listBroadcast] pairs
  /// up their elements.
  Tensor operator *(Tensor other) {
    return Tensor.fromList(listBroadcast(data, other.data, (d1, d2) {
      return d1 * d2;
    }));
  }
}