work more on tensors.

Jordan Hewitt 2025-01-08 19:11:16 -08:00
parent 372b3d8442
commit 0f1b7bd5c4
2 changed files with 48 additions and 51 deletions

(file 1 of 2: the Tensor implementation)

@@ -50,6 +50,16 @@ List<dynamic> listBroadcast(List<dynamic> l1, List<dynamic> l2,
   });
 }
 
+// Group the i-th elements of every input list into sub-lists
+// (a zip/transpose; `dim` is accepted but not consulted yet).
+List<dynamic> listStack(List<List<dynamic>> lists, int dim) {
+  List<dynamic> newData = [];
+  for (int i = 0; i < lists.first.length; ++i) {
+    List<dynamic> subData = lists.map((l) => l[i]).toList();
+    newData.add(subData);
+  }
+  return newData;
+}
 class Tensor {
   List<dynamic> data = [];
   List<int> shape = [];
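For reference, here is what the new helper produces (a hypothetical quick check, not part of the commit, assuming the corrected body above is in scope). Pairing the i-th elements across the inputs behaves like a transpose of two flat lists rather than pytorch's dim-0 stack:

void main() {
  final zipped = listStack([
    [1, 2, 3],
    [4, 5, 6],
  ], 0);
  print(zipped); // [[1, 4], [2, 5], [3, 6]]
}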
@@ -84,48 +94,30 @@ class Tensor {
     return data.equals(other.data);
   }
 
-  // Static method 'stack' that takes a list of Tensors and optionally an integer `dim` as arguments,
-  // implementing pytorch's stack method.
-  static Tensor stack(List<Tensor> tensors, [int dim = 0]) {
-    // If the provided dimension is less than 0 or greater than the number of dimensions in all Tensors, use 0 as the default value for dim.
-    dim = (dim < 0 || tensors.any((tensor) => tensor.shape.length <= dim))
-        ? 0
-        : dim;
-    List<dynamic> newShape = [
-      ...List<dynamic>.filled(tensors.length, 1),
-      ...tensors[0].shape
-    ];
-
-    // If the provided dimension is not the first dimension of all Tensors, adjust the shape accordingly.
-    if (dim > 0) {
-      for (var i = 0; i < tensors.length; i++) {
-        newShape[i] *= tensors[i].shape[dim];
-      }
-    }
-    // If the provided dimension is not the first dimension of all Tensors, adjust the shape accordingly.
-    if (dim > 0) {
-      for (var i = 0; i < tensors.length; i++) {
-        newShape[i] *= tensors[i].shape[dim];
-      }
-    }
-    List<dynamic> stackedData = [];
-    // Iterate through the data of each tensor and concatenate it to the stackedData list.
-    for (var i = 0; i < newShape[0]; i++) {
-      int currentIndex = 0;
-      for (var j = 0; j < tensors.length; j++) {
-        if (i >= currentIndex * tensors[j].shape[dim] &&
-            i < (currentIndex + 1) * tensors[j].shape[dim]) {
-          stackedData.add(tensors[j].data[currentIndex]);
-          currentIndex++;
-        }
-      }
-    }
-    return Tensor.fromList(stackedData);
-  }
+  // Intended as pytorch's `cat`; note this currently nests each
+  // tensor's data rather than concatenating it.
+  static Tensor cat(List<Tensor> tensors) {
+    return Tensor.fromList(tensors.map((el) {
+      return el.data;
+    }).toList());
+  }
+
+  // Stack tensors along a new dimension `dim`, analogous to pytorch's
+  // `stack`; the body is still a stub.
+  static Tensor stack(List<Tensor> tensors, {int dim = 0}) {
+    // `max` here is the IterableExtension getter from package:collection.
+    int maxShape = tensors.map((Tensor t) {
+      return t.shape.max;
+    }).max;
+    if (dim < 0 || dim > maxShape) {
+      throw Exception("Invalid dimension");
+    }
+    // Scaffolding for the flattened result and the index path (unused so far).
+    List<double> newList = [];
+    List<int> path = [];
+    for (int i = 0; i < tensors.length; i++) {
+      // TODO: walk tensors[i].data and interleave it along `dim`.
+    }
+    throw UnimplementedError("stack is still a work in progress");
+  }
+
+  // Unwrap the tensor into (possibly nested) Dart lists.
+  List<dynamic> toList({int depth = 0}) {
+    if (depth == 0 || (data.isNotEmpty && data[0].runtimeType == double)) {
+      return data;
+    }
+    return data.map((d) => d.toList(depth: depth - 1)).toList();
+  }
 
   operator *(Tensor other) {
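For reference, pytorch's stack inserts a brand-new axis at position `dim`, so every input keeps its shape and the output rank grows by one. A minimal sketch of those semantics over plain nested lists, valid for 0 <= dim < rank (`stackLists` is a hypothetical helper, not part of this commit):

// Pytorch-style stack over nested lists: insert a new axis at `dim`.
List<dynamic> stackLists(List<List<dynamic>> lists, int dim) {
  if (dim == 0) {
    // New leading axis: the inputs themselves become the elements.
    return List<dynamic>.from(lists);
  }
  // Otherwise group the i-th children of every input and recurse,
  // stacking them one dimension deeper.
  return List<dynamic>.generate(
    lists.first.length,
    (i) => stackLists(
      lists.map((l) => l[i] as List<dynamic>).toList(),
      dim - 1,
    ),
  );
}

void main() {
  final l = [
    [1, 2, 3],
    [4, 5, 6],
  ];
  print(stackLists([l, l], 0)); // [[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]]
  print(stackLists([l, l], 1)); // [[[1, 2, 3], [1, 2, 3]], [[4, 5, 6], [4, 5, 6]]]
}

One detail worth noting: the usual bound for `dim` is the rank (`shape.length`), while the `maxShape` check in the new `stack` takes the largest extent (`shape.max`), a looser rule than pytorch's.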

(file 2 of 2: the tests)

@@ -79,22 +79,27 @@ void main() {
   test("tensors can be stacked", () {
     List<List<double>> l = [
-      [0.3367, 0.1288, 02345],
-      [0.2303, -1.1229, -0.1863]
+      [1, 2, 3],
+      [4, 5, 6]
     ];
     Tensor baseTensor = Tensor.fromList(l);
     Tensor stacked_1 = Tensor.stack([baseTensor, baseTensor]);
-    expect(
-        stacked_1.data,
-        equals([
-          [
-            [0.3367, 0.1288, 0.2345],
-            [0.2303, -1.1229, -0.1863]
-          ],
-          [
-            [0.3367, 0.1288, 0.2345],
-            [0.2303, -1.1229, -0.1863]
-          ]
-        ]));
+    Tensor expectedStacked1 = Tensor.fromList([l, l]);
+    expect(stacked_1.equals(expectedStacked1), isTrue);
+
+    List<List<List<double>>> ld1 = [
+      [
+        [1, 2, 3],
+        [1, 2, 3],
+      ],
+      [
+        [4, 5, 6],
+        [4, 5, 6]
+      ]
+    ];
+    Tensor stacked_2 = Tensor.stack([baseTensor, baseTensor], dim: 1);
+    Tensor expectedStacked2 = Tensor.fromList(ld1);
+    expect(stacked_2.equals(expectedStacked2), isTrue);
   });
 }
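A closing note on shapes: stacking n tensors inserts n at position `dim` of the input shape, so both assertions above expect shape (2, 2, 3) and only the element order differs between dim 0 and dim 1. A sketch of that rule (`stackedShape` is a hypothetical helper, not part of the library):

// Output shape of stack: the input shape with the group size
// inserted at position `dim` (valid dims run 0 through the rank).
List<int> stackedShape(List<int> shape, int count, int dim) {
  return [...shape.sublist(0, dim), count, ...shape.sublist(dim)];
}

void main() {
  print(stackedShape([2, 3], 2, 0)); // [2, 2, 3]
  print(stackedShape([2, 3], 2, 1)); // [2, 2, 3]
  print(stackedShape([2, 3], 4, 1)); // [2, 4, 3]
}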