From 372b3d84429342fdac8cfc9f2dad0a8229ef1401 Mon Sep 17 00:00:00 2001 From: Jordan Date: Thu, 19 Dec 2024 06:20:19 -0800 Subject: [PATCH] add more tensor stuff. --- lib/src/splat/model.dart | 34 +++++++++++----------- lib/src/splat/tensor.dart | 44 +++++++++++++++++++++++++++++ lib/src/test/splat/tensor_test.dart | 21 ++++++++++++++ 3 files changed, 82 insertions(+), 17 deletions(-) diff --git a/lib/src/splat/model.dart b/lib/src/splat/model.dart index a026b95..877fb2d 100644 --- a/lib/src/splat/model.dart +++ b/lib/src/splat/model.dart @@ -11,38 +11,38 @@ Tensor randomQuantTensor(int n) { Tensor v = Tensor.random([n]); Tensor w = Tensor.random([n]); - Tensor a1 = u.each((f, _, __) { - return sin(2 * pi * f); + Tensor a1 = u.each((d) { + return sin(2 * pi * d); }); - Tensor a2 = u.each((f, _, __) { - return sqrt(1 - f); + Tensor a2 = u.each((d) { + return sqrt(1 - d); }); Tensor a = a1 * a2; - Tensor b1 = v.each((f, _, __) { - return cos(2 * pi * f); + Tensor b1 = v.each((d) { + return cos(2 * pi * d); }); - Tensor b2 = v.each((f, _, __) { - return sqrt(1 - f); + Tensor b2 = v.each((d) { + return sqrt(1 - d); }); Tensor b = b1 * b2; - Tensor c1 = u.each((f, _, __) { - return sqrt(f); + Tensor c1 = u.each((d) { + return sqrt(d); }); - Tensor c2 = w.each((f, _, __) { - return sin(2 * pi * f); + Tensor c2 = w.each((d) { + return sin(2 * pi * d); }); Tensor c = c1 * c2; - Tensor d1 = u.each((f, _, __) { - return sqrt(f); + Tensor d1 = u.each((d) { + return sqrt(d); }); - Tensor d2 = w.each((f, _, __) { - return cos(2 * pi * f); + Tensor d2 = w.each((d) { + return cos(2 * pi * d); }); Tensor d = d1 * d2; - return Tensor.stack(Tensor([a[0], b[0], c[0], d[0]])); + return Tensor.stack([a, b, c, d]); } diff --git a/lib/src/splat/tensor.dart b/lib/src/splat/tensor.dart index 46c1efb..fa1a8b7 100644 --- a/lib/src/splat/tensor.dart +++ b/lib/src/splat/tensor.dart @@ -84,6 +84,50 @@ class Tensor { return data.equals(other.data); } +// Static method 'stack' that takes a list of 
Tensors and optionally an integer `dim` as arguments, +// implementing pytorch's stack method. + static Tensor stack(List tensors, [int dim = 0]) { + // If the provided dimension is less than 0 or greater than the number of dimensions in all Tensors, use 0 as the default value for dim. + dim = (dim < 0 || tensors.any((tensor) => tensor.shape.length <= dim)) + ? 0 + : dim; + + List newShape = [ + ...List.filled(tensors.length, 1), + ...tensors[0].shape + ]; + + // If the provided dimension is not the first dimension of all Tensors, adjust the shape accordingly. + if (dim > 0) { + for (var i = 0; i < tensors.length; i++) { + newShape[i] *= tensors[i].shape[dim]; + } + } + + // If the provided dimension is not the first dimension of all Tensors, adjust the shape accordingly. + if (dim > 0) { + for (var i = 0; i < tensors.length; i++) { + newShape[i] *= tensors[i].shape[dim]; + } + } + + List stackedData = []; + + // Iterate through the data of each tensor and concatenate it to the stackedData list. 
+ for (var i = 0; i < newShape[0]; i++) { + int currentIndex = 0; + for (var j = 0; j < tensors.length; j++) { + if (i >= currentIndex * tensors[j].shape[dim] && + i < (currentIndex + 1) * tensors[j].shape[dim]) { + stackedData.add(tensors[j].data[currentIndex]); + currentIndex++; + } + } + } + + return Tensor.fromList(stackedData); + } + operator *(Tensor other) { return Tensor.fromList(listBroadcast(data, other.data, (d1, d2) { return d1 * d2; diff --git a/lib/src/test/splat/tensor_test.dart b/lib/src/test/splat/tensor_test.dart index 07974ba..48dec92 100644 --- a/lib/src/test/splat/tensor_test.dart +++ b/lib/src/test/splat/tensor_test.dart @@ -76,4 +76,25 @@ void main() { ]); expect((t1 * t2).data, equals(expected.data)); }); + + test("tensors can be stacked", () { + List<List<double>> l = [ + [0.3367, 0.1288, 0.2345], + [0.2303, -1.1229, -0.1863] + ]; + Tensor baseTensor = Tensor.fromList(l); + Tensor stacked_1 = Tensor.stack([baseTensor, baseTensor]); + expect( + stacked_1.data, + equals([ + [ + [0.3367, 0.1288, 0.2345], + [0.2303, -1.1229, -0.1863] + ], + [ + [0.3367, 0.1288, 0.2345], + [0.2303, -1.1229, -0.1863] + ] + ])); + }); }