From 0f1b7bd5c48a15e00a490c067d482c19e5be9367 Mon Sep 17 00:00:00 2001 From: Jordan Hewitt Date: Wed, 8 Jan 2025 19:11:16 -0800 Subject: [PATCH] work more on tensors. --- lib/src/splat/tensor.dart | 66 +++++++++++++---------------- lib/src/test/splat/tensor_test.dart | 33 +++++++++------ 2 files changed, 48 insertions(+), 51 deletions(-) diff --git a/lib/src/splat/tensor.dart b/lib/src/splat/tensor.dart index fa1a8b7..c5954ea 100644 --- a/lib/src/splat/tensor.dart +++ b/lib/src/splat/tensor.dart @@ -50,6 +50,16 @@ List listBroadcast(List l1, List l2, }); } +List listStack(List> lists, int dim) { + List newData = []; + + for (int i = 0; i < lists.first.data.length; ++i) { + List subData = lists.map((l => l.data[i])) + } + + return newData; +} + class Tensor { List data = []; List shape = []; @@ -84,48 +94,30 @@ class Tensor { return data.equals(other.data); } -// Static method 'stack' that takes a list of Tensors and optionally an integer `dim` as arguments, -// implementing pytorch's stack method. - static Tensor stack(List tensors, [int dim = 0]) { - // If the provided dimension is less than 0 or greater than the number of dimensions in all Tensors, use 0 as the default value for dim. - dim = (dim < 0 || tensors.any((tensor) => tensor.shape.length <= dim)) - ? 0 - : dim; + static Tensor cat(List tensors) { + return Tensor.fromList(tensors.map((el) { + return el.data; + }).toList()); + } - List newShape = [ - ...List.filled(tensors.length, 1), - ...tensors[0].shape - ]; - - // If the provided dimension is not the first dimension of all Tensors, adjust the shape accordingly. - if (dim > 0) { - for (var i = 0; i < tensors.length; i++) { - newShape[i] *= tensors[i].shape[dim]; - } + static Tensor stack(List tensors, {int dim = 0}) { + int maxShape = tensors.map((Tensor t) { + return t.shape.max; + }).max; + if (dim < 0 || dim > maxShape) { + throw Exception("Invalid dimension"); } - - // If the provided dimension is not the first dimension of all Tensors, adjust the shape accordingly. - if (dim > 0) { - for (var i = 0; i < tensors.length; i++) { - newShape[i] *= tensors[i].shape[dim]; - } + List newList = []; + List path = []; + for (int i = 0; i < tensors.length; i++) { } + } - List stackedData = []; - - // Iterate through the data of each tensor and concatenate it to the stackedData list. - for (var i = 0; i < newShape[0]; i++) { - int currentIndex = 0; - for (var j = 0; j < tensors.length; j++) { - if (i >= currentIndex * tensors[j].shape[dim] && - i < (currentIndex + 1) * tensors[j].shape[dim]) { - stackedData.add(tensors[j].data[currentIndex]); - currentIndex++; - } - } + List toList({int depth = 0}) { + if (depth == 0 || ((data.length > 0 && this.data[0].runtimetype == double))) { + return data } - - return Tensor.fromList(stackedData); + return data.map((d) { return d.toList(depth-1) }) } operator *(Tensor other) { diff --git a/lib/src/test/splat/tensor_test.dart b/lib/src/test/splat/tensor_test.dart index 48dec92..b68f689 100644 --- a/lib/src/test/splat/tensor_test.dart +++ b/lib/src/test/splat/tensor_test.dart @@ -79,22 +79,27 @@ void main() { test("tensors can be stacked", () { List> l = [ - [0.3367, 0.1288, 02345], - [0.2303, -1.1229, -0.1863] + [1, 2, 3], + [4, 5, 6] ]; Tensor baseTensor = Tensor.fromList(l); Tensor stacked_1 = Tensor.stack([baseTensor, baseTensor]); - expect( - stacked_1.data, - equals([ - [ - [0.3367, 0.1288, 0.2345], - [0.2303, -1.1229, -0.1863] - ], - [ - [0.3367, 0.1288, 0.2345], - [0.2303, -1.1229, -0.1863] - ] - ])); + Tensor expectedStacked1 = Tensor.fromList([l, l]); + expect(stacked_1.equals(expectedStacked1), isTrue); + + List>> ld1 = [ + [ + [1, 2, 3], + [1, 2, 3], + ], + [ + [4, 5, 6], + [4, 5, 6] + ] + ]; + + Tensor stacked_2 = Tensor.stack([baseTensor, baseTensor], dim: 1); + Tensor expectedStacked2 = Tensor.fromList(ld1); + expect(stacked_2.equals(expectedStacked2), isTrue); }); }