add more tensor stuff.

This commit is contained in:
Jordan 2024-12-19 06:20:19 -08:00
parent dffe5cee98
commit 372b3d8442
3 changed files with 82 additions and 17 deletions

View File

@ -11,38 +11,38 @@ Tensor randomQuantTensor(int n) {
Tensor v = Tensor.random([n]);
Tensor w = Tensor.random([n]);
Tensor a1 = u.each((f, _, __) {
return sin(2 * pi * f);
Tensor a1 = u.each((d) {
return sin(2 * pi * d);
});
Tensor a2 = u.each((f, _, __) {
return sqrt(1 - f);
Tensor a2 = u.each((d) {
return sqrt(1 - d);
});
Tensor a = a1 * a2;
Tensor b1 = v.each((f, _, __) {
return cos(2 * pi * f);
Tensor b1 = v.each((d) {
return cos(2 * pi * d);
});
Tensor b2 = v.each((f, _, __) {
return sqrt(1 - f);
Tensor b2 = v.each((d) {
return sqrt(1 - d);
});
Tensor b = b1 * b2;
Tensor c1 = u.each((f, _, __) {
return sqrt(f);
Tensor c1 = u.each((d) {
return sqrt(d);
});
Tensor c2 = w.each((f, _, __) {
return sin(2 * pi * f);
Tensor c2 = w.each((d) {
return sin(2 * pi * d);
});
Tensor c = c1 * c2;
Tensor d1 = u.each((f, _, __) {
return sqrt(f);
Tensor d1 = u.each((d) {
return sqrt(d);
});
Tensor d2 = w.each((f, _, __) {
return cos(2 * pi * f);
Tensor d2 = w.each((d) {
return cos(2 * pi * d);
});
Tensor d = d1 * d2;
return Tensor.stack(Tensor([a[0], b[0], c[0], d[0]]));
return Tensor.stack([a, b, c, d]);
}

View File

@ -84,6 +84,50 @@ class Tensor {
return data.equals(other.data);
}
// Static method 'stack' that takes a list of Tensors and optionally an integer `dim` as arguments,
// implementing pytorch's stack method.
static Tensor stack(List<Tensor> tensors, [int dim = 0]) {
// If the provided dimension is less than 0 or greater than the number of dimensions in all Tensors, use 0 as the default value for dim.
dim = (dim < 0 || tensors.any((tensor) => tensor.shape.length <= dim))
? 0
: dim;
List<dynamic> newShape = [
...List<dynamic>.filled(tensors.length, 1),
...tensors[0].shape
];
// If the provided dimension is not the first dimension of all Tensors, adjust the shape accordingly.
if (dim > 0) {
for (var i = 0; i < tensors.length; i++) {
newShape[i] *= tensors[i].shape[dim];
}
}
// If the provided dimension is not the first dimension of all Tensors, adjust the shape accordingly.
if (dim > 0) {
for (var i = 0; i < tensors.length; i++) {
newShape[i] *= tensors[i].shape[dim];
}
}
List<dynamic> stackedData = [];
// Iterate through the data of each tensor and concatenate it to the stackedData list.
for (var i = 0; i < newShape[0]; i++) {
int currentIndex = 0;
for (var j = 0; j < tensors.length; j++) {
if (i >= currentIndex * tensors[j].shape[dim] &&
i < (currentIndex + 1) * tensors[j].shape[dim]) {
stackedData.add(tensors[j].data[currentIndex]);
currentIndex++;
}
}
}
return Tensor.fromList(stackedData);
}
operator *(Tensor other) {
return Tensor.fromList(listBroadcast(data, other.data, (d1, d2) {
return d1 * d2;

View File

@ -76,4 +76,25 @@ void main() {
]);
expect((t1 * t2).data, equals(expected.data));
});
test("tensors can be stacked", () {
List<List<double>> l = [
[0.3367, 0.1288, 02345],
[0.2303, -1.1229, -0.1863]
];
Tensor baseTensor = Tensor.fromList(l);
Tensor stacked_1 = Tensor.stack([baseTensor, baseTensor]);
expect(
stacked_1.data,
equals([
[
[0.3367, 0.1288, 0.2345],
[0.2303, -1.1229, -0.1863]
],
[
[0.3367, 0.1288, 0.2345],
[0.2303, -1.1229, -0.1863]
]
]));
});
}