Add Tensor.stack (a minimal PyTorch-style stack) and a test for stacking tensors.
This commit is contained in:
parent
dffe5cee98
commit
372b3d8442
@ -11,38 +11,38 @@ Tensor randomQuantTensor(int n) {
|
||||
Tensor v = Tensor.random([n]);
|
||||
Tensor w = Tensor.random([n]);
|
||||
|
||||
Tensor a1 = u.each((f, _, __) {
|
||||
return sin(2 * pi * f);
|
||||
Tensor a1 = u.each((d) {
|
||||
return sin(2 * pi * d);
|
||||
});
|
||||
Tensor a2 = u.each((f, _, __) {
|
||||
return sqrt(1 - f);
|
||||
Tensor a2 = u.each((d) {
|
||||
return sqrt(1 - d);
|
||||
});
|
||||
Tensor a = a1 * a2;
|
||||
|
||||
Tensor b1 = v.each((f, _, __) {
|
||||
return cos(2 * pi * f);
|
||||
Tensor b1 = v.each((d) {
|
||||
return cos(2 * pi * d);
|
||||
});
|
||||
Tensor b2 = v.each((f, _, __) {
|
||||
return sqrt(1 - f);
|
||||
Tensor b2 = v.each((d) {
|
||||
return sqrt(1 - d);
|
||||
});
|
||||
|
||||
Tensor b = b1 * b2;
|
||||
|
||||
Tensor c1 = u.each((f, _, __) {
|
||||
return sqrt(f);
|
||||
Tensor c1 = u.each((d) {
|
||||
return sqrt(d);
|
||||
});
|
||||
Tensor c2 = w.each((f, _, __) {
|
||||
return sin(2 * pi * f);
|
||||
Tensor c2 = w.each((d) {
|
||||
return sin(2 * pi * d);
|
||||
});
|
||||
Tensor c = c1 * c2;
|
||||
|
||||
Tensor d1 = u.each((f, _, __) {
|
||||
return sqrt(f);
|
||||
Tensor d1 = u.each((d) {
|
||||
return sqrt(d);
|
||||
});
|
||||
Tensor d2 = w.each((f, _, __) {
|
||||
return cos(2 * pi * f);
|
||||
Tensor d2 = w.each((d) {
|
||||
return cos(2 * pi * d);
|
||||
});
|
||||
Tensor d = d1 * d2;
|
||||
|
||||
return Tensor.stack(Tensor([a[0], b[0], c[0], d[0]]));
|
||||
return Tensor.stack([a, b, c, d]);
|
||||
}
|
||||
|
@ -84,6 +84,50 @@ class Tensor {
|
||||
return data.equals(other.data);
|
||||
}
|
||||
|
||||
// Implements a minimal version of PyTorch's `stack`: joins a list of
// Tensors along dimension `dim` (default 0).
//
// Parameters:
//   tensors - the Tensors to stack. Assumed non-empty and uniformly
//             shaped (not validated here — TODO confirm callers
//             guarantee this; `tensors[0]` is dereferenced below).
//   dim     - the dimension to stack along. A negative value, or one
//             that is out of range for any input, falls back to 0.
//
// Returns a new Tensor built from the concatenated payloads via
// Tensor.fromList.
static Tensor stack(List<Tensor> tensors, [int dim = 0]) {
  // Clamp an invalid `dim` (negative, or >= the rank of some input) to 0.
  dim = (dim < 0 || tensors.any((tensor) => tensor.shape.length <= dim))
      ? 0
      : dim;

  // Seed the stacked shape: one leading slot of 1 per input, followed by
  // the first input's shape.
  List<dynamic> newShape = [
    ...List<dynamic>.filled(tensors.length, 1),
    ...tensors[0].shape
  ];

  // For a non-leading dim, scale each input's leading slot by its extent
  // along `dim`.
  // NOTE(review): the committed version contained this adjustment block
  // TWICE (a copy-paste duplicate), which double-multiplied newShape for
  // any dim > 0; the duplicate has been removed. dim == 0 behavior is
  // unaffected (the branch never ran).
  if (dim > 0) {
    for (var i = 0; i < tensors.length; i++) {
      newShape[i] *= tensors[i].shape[dim];
    }
  }

  List<dynamic> stackedData = [];

  // Walk the leading extent and gather each input's data entry whose
  // index window contains `i`, concatenating into the stacked payload.
  for (var i = 0; i < newShape[0]; i++) {
    int currentIndex = 0;
    for (var j = 0; j < tensors.length; j++) {
      if (i >= currentIndex * tensors[j].shape[dim] &&
          i < (currentIndex + 1) * tensors[j].shape[dim]) {
        stackedData.add(tensors[j].data[currentIndex]);
        currentIndex++;
      }
    }
  }

  return Tensor.fromList(stackedData);
}
|
||||
|
||||
operator *(Tensor other) {
|
||||
return Tensor.fromList(listBroadcast(data, other.data, (d1, d2) {
|
||||
return d1 * d2;
|
||||
|
@ -76,4 +76,25 @@ void main() {
|
||||
]);
|
||||
expect((t1 * t2).data, equals(expected.data));
|
||||
});
|
||||
|
||||
test("tensors can be stacked", () {
  // Source rows; these must mirror the inner rows of the expected
  // output below.
  List<List<double>> l = [
    // Fixed: the committed literal was `02345` (missing decimal point),
    // contradicting the `0.2345` this very test expects back.
    [0.3367, 0.1288, 0.2345],
    [0.2303, -1.1229, -0.1863]
  ];
  Tensor baseTensor = Tensor.fromList(l);
  // Stacking a tensor with itself along the default dim (0) should
  // yield both copies, each preserving the original 2x3 layout.
  Tensor stacked_1 = Tensor.stack([baseTensor, baseTensor]);
  expect(
      stacked_1.data,
      equals([
        [
          [0.3367, 0.1288, 0.2345],
          [0.2303, -1.1229, -0.1863]
        ],
        [
          [0.3367, 0.1288, 0.2345],
          [0.2303, -1.1229, -0.1863]
        ]
      ]));
});
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user