add more tensor stuff.
This commit is contained in:
parent
dffe5cee98
commit
372b3d8442
@ -11,38 +11,38 @@ Tensor randomQuantTensor(int n) {
|
|||||||
Tensor v = Tensor.random([n]);
|
Tensor v = Tensor.random([n]);
|
||||||
Tensor w = Tensor.random([n]);
|
Tensor w = Tensor.random([n]);
|
||||||
|
|
||||||
Tensor a1 = u.each((f, _, __) {
|
Tensor a1 = u.each((d) {
|
||||||
return sin(2 * pi * f);
|
return sin(2 * pi * d);
|
||||||
});
|
});
|
||||||
Tensor a2 = u.each((f, _, __) {
|
Tensor a2 = u.each((d) {
|
||||||
return sqrt(1 - f);
|
return sqrt(1 - d);
|
||||||
});
|
});
|
||||||
Tensor a = a1 * a2;
|
Tensor a = a1 * a2;
|
||||||
|
|
||||||
Tensor b1 = v.each((f, _, __) {
|
Tensor b1 = v.each((d) {
|
||||||
return cos(2 * pi * f);
|
return cos(2 * pi * d);
|
||||||
});
|
});
|
||||||
Tensor b2 = v.each((f, _, __) {
|
Tensor b2 = v.each((d) {
|
||||||
return sqrt(1 - f);
|
return sqrt(1 - d);
|
||||||
});
|
});
|
||||||
|
|
||||||
Tensor b = b1 * b2;
|
Tensor b = b1 * b2;
|
||||||
|
|
||||||
Tensor c1 = u.each((f, _, __) {
|
Tensor c1 = u.each((d) {
|
||||||
return sqrt(f);
|
return sqrt(d);
|
||||||
});
|
});
|
||||||
Tensor c2 = w.each((f, _, __) {
|
Tensor c2 = w.each((d) {
|
||||||
return sin(2 * pi * f);
|
return sin(2 * pi * d);
|
||||||
});
|
});
|
||||||
Tensor c = c1 * c2;
|
Tensor c = c1 * c2;
|
||||||
|
|
||||||
Tensor d1 = u.each((f, _, __) {
|
Tensor d1 = u.each((d) {
|
||||||
return sqrt(f);
|
return sqrt(d);
|
||||||
});
|
});
|
||||||
Tensor d2 = w.each((f, _, __) {
|
Tensor d2 = w.each((d) {
|
||||||
return cos(2 * pi * f);
|
return cos(2 * pi * d);
|
||||||
});
|
});
|
||||||
Tensor d = d1 * d2;
|
Tensor d = d1 * d2;
|
||||||
|
|
||||||
return Tensor.stack(Tensor([a[0], b[0], c[0], d[0]]));
|
return Tensor.stack([a, b, c, d]);
|
||||||
}
|
}
|
||||||
|
@ -84,6 +84,50 @@ class Tensor {
|
|||||||
return data.equals(other.data);
|
return data.equals(other.data);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Static method 'stack' that takes a list of Tensors and optionally an integer `dim` as arguments,
|
||||||
|
// implementing pytorch's stack method.
|
||||||
|
static Tensor stack(List<Tensor> tensors, [int dim = 0]) {
|
||||||
|
// If the provided dimension is less than 0 or greater than the number of dimensions in all Tensors, use 0 as the default value for dim.
|
||||||
|
dim = (dim < 0 || tensors.any((tensor) => tensor.shape.length <= dim))
|
||||||
|
? 0
|
||||||
|
: dim;
|
||||||
|
|
||||||
|
List<dynamic> newShape = [
|
||||||
|
...List<dynamic>.filled(tensors.length, 1),
|
||||||
|
...tensors[0].shape
|
||||||
|
];
|
||||||
|
|
||||||
|
// If the provided dimension is not the first dimension of all Tensors, adjust the shape accordingly.
|
||||||
|
if (dim > 0) {
|
||||||
|
for (var i = 0; i < tensors.length; i++) {
|
||||||
|
newShape[i] *= tensors[i].shape[dim];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the provided dimension is not the first dimension of all Tensors, adjust the shape accordingly.
|
||||||
|
if (dim > 0) {
|
||||||
|
for (var i = 0; i < tensors.length; i++) {
|
||||||
|
newShape[i] *= tensors[i].shape[dim];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
List<dynamic> stackedData = [];
|
||||||
|
|
||||||
|
// Iterate through the data of each tensor and concatenate it to the stackedData list.
|
||||||
|
for (var i = 0; i < newShape[0]; i++) {
|
||||||
|
int currentIndex = 0;
|
||||||
|
for (var j = 0; j < tensors.length; j++) {
|
||||||
|
if (i >= currentIndex * tensors[j].shape[dim] &&
|
||||||
|
i < (currentIndex + 1) * tensors[j].shape[dim]) {
|
||||||
|
stackedData.add(tensors[j].data[currentIndex]);
|
||||||
|
currentIndex++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Tensor.fromList(stackedData);
|
||||||
|
}
|
||||||
|
|
||||||
operator *(Tensor other) {
|
operator *(Tensor other) {
|
||||||
return Tensor.fromList(listBroadcast(data, other.data, (d1, d2) {
|
return Tensor.fromList(listBroadcast(data, other.data, (d1, d2) {
|
||||||
return d1 * d2;
|
return d1 * d2;
|
||||||
|
@ -76,4 +76,25 @@ void main() {
|
|||||||
]);
|
]);
|
||||||
expect((t1 * t2).data, equals(expected.data));
|
expect((t1 * t2).data, equals(expected.data));
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test("tensors can be stacked", () {
|
||||||
|
List<List<double>> l = [
|
||||||
|
[0.3367, 0.1288, 02345],
|
||||||
|
[0.2303, -1.1229, -0.1863]
|
||||||
|
];
|
||||||
|
Tensor baseTensor = Tensor.fromList(l);
|
||||||
|
Tensor stacked_1 = Tensor.stack([baseTensor, baseTensor]);
|
||||||
|
expect(
|
||||||
|
stacked_1.data,
|
||||||
|
equals([
|
||||||
|
[
|
||||||
|
[0.3367, 0.1288, 0.2345],
|
||||||
|
[0.2303, -1.1229, -0.1863]
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[0.3367, 0.1288, 0.2345],
|
||||||
|
[0.2303, -1.1229, -0.1863]
|
||||||
|
]
|
||||||
|
]));
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user