work more on tensors.
This commit is contained in:
parent
372b3d8442
commit
0f1b7bd5c4
@ -50,6 +50,16 @@ List<dynamic> listBroadcast(List<dynamic> l1, List<dynamic> l2,
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
List<dynamic> listStack(List<List<dynamic>> lists, int dim) {
|
||||||
|
List<dynamic> newData = [];
|
||||||
|
|
||||||
|
for (int i = 0; i < lists.first.data.length; ++i) {
|
||||||
|
List<dynamic> subData = lists.map((l => l.data[i]))
|
||||||
|
}
|
||||||
|
|
||||||
|
return newData;
|
||||||
|
}
|
||||||
|
|
||||||
class Tensor {
|
class Tensor {
|
||||||
List<dynamic> data = [];
|
List<dynamic> data = [];
|
||||||
List<int> shape = [];
|
List<int> shape = [];
|
||||||
@ -84,48 +94,30 @@ class Tensor {
|
|||||||
return data.equals(other.data);
|
return data.equals(other.data);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Static method 'stack' that takes a list of Tensors and optionally an integer `dim` as arguments,
|
static Tensor cat(List<Tensor> tensors) {
|
||||||
// implementing pytorch's stack method.
|
return Tensor.fromList(tensors.map((el) {
|
||||||
static Tensor stack(List<Tensor> tensors, [int dim = 0]) {
|
return el.data;
|
||||||
// If the provided dimension is less than 0 or greater than the number of dimensions in all Tensors, use 0 as the default value for dim.
|
}).toList());
|
||||||
dim = (dim < 0 || tensors.any((tensor) => tensor.shape.length <= dim))
|
}
|
||||||
? 0
|
|
||||||
: dim;
|
|
||||||
|
|
||||||
List<dynamic> newShape = [
|
static Tensor stack(List<Tensor> tensors, {int dim = 0}) {
|
||||||
...List<dynamic>.filled(tensors.length, 1),
|
int maxShape = tensors.map((Tensor t) {
|
||||||
...tensors[0].shape
|
return t.shape.max;
|
||||||
];
|
}).max;
|
||||||
|
if (dim < 0 || dim > maxShape) {
|
||||||
// If the provided dimension is not the first dimension of all Tensors, adjust the shape accordingly.
|
throw Exception("Invalid dimension");
|
||||||
if (dim > 0) {
|
|
||||||
for (var i = 0; i < tensors.length; i++) {
|
|
||||||
newShape[i] *= tensors[i].shape[dim];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
List<double> newList = [];
|
||||||
// If the provided dimension is not the first dimension of all Tensors, adjust the shape accordingly.
|
List<int> path = [];
|
||||||
if (dim > 0) {
|
for (int i = 0; i < tensors.length; i++) {
|
||||||
for (var i = 0; i < tensors.length; i++) {
|
|
||||||
newShape[i] *= tensors[i].shape[dim];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
List<dynamic> stackedData = [];
|
List<dynamic> toList({int depth = 0}) {
|
||||||
|
if (depth == 0 || ((data.length > 0 && this.data[0].runtimetype == double))) {
|
||||||
// Iterate through the data of each tensor and concatenate it to the stackedData list.
|
return data
|
||||||
for (var i = 0; i < newShape[0]; i++) {
|
|
||||||
int currentIndex = 0;
|
|
||||||
for (var j = 0; j < tensors.length; j++) {
|
|
||||||
if (i >= currentIndex * tensors[j].shape[dim] &&
|
|
||||||
i < (currentIndex + 1) * tensors[j].shape[dim]) {
|
|
||||||
stackedData.add(tensors[j].data[currentIndex]);
|
|
||||||
currentIndex++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
return data.map((d) { return d.toList(depth-1) })
|
||||||
return Tensor.fromList(stackedData);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
operator *(Tensor other) {
|
operator *(Tensor other) {
|
||||||
|
@ -79,22 +79,27 @@ void main() {
|
|||||||
|
|
||||||
test("tensors can be stacked", () {
|
test("tensors can be stacked", () {
|
||||||
List<List<double>> l = [
|
List<List<double>> l = [
|
||||||
[0.3367, 0.1288, 02345],
|
[1, 2, 3],
|
||||||
[0.2303, -1.1229, -0.1863]
|
[4, 5, 6]
|
||||||
];
|
];
|
||||||
Tensor baseTensor = Tensor.fromList(l);
|
Tensor baseTensor = Tensor.fromList(l);
|
||||||
Tensor stacked_1 = Tensor.stack([baseTensor, baseTensor]);
|
Tensor stacked_1 = Tensor.stack([baseTensor, baseTensor]);
|
||||||
expect(
|
Tensor expectedStacked1 = Tensor.fromList([l, l]);
|
||||||
stacked_1.data,
|
expect(stacked_1.equals(expectedStacked1), isTrue);
|
||||||
equals([
|
|
||||||
[
|
List<List<List<double>>> ld1 = [
|
||||||
[0.3367, 0.1288, 0.2345],
|
[
|
||||||
[0.2303, -1.1229, -0.1863]
|
[1, 2, 3],
|
||||||
],
|
[1, 2, 3],
|
||||||
[
|
],
|
||||||
[0.3367, 0.1288, 0.2345],
|
[
|
||||||
[0.2303, -1.1229, -0.1863]
|
[4, 5, 6],
|
||||||
]
|
[4, 5, 6]
|
||||||
]));
|
]
|
||||||
|
];
|
||||||
|
|
||||||
|
Tensor stacked_2 = Tensor.stack([baseTensor, baseTensor], dim: 1);
|
||||||
|
Tensor expectedStacked2 = Tensor.fromList(ld1);
|
||||||
|
expect(stacked_2.equals(expectedStacked2), isTrue);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user