add each and broadcast methods.
@@ -2,88 +2,85 @@ import 'dart:math';
 
 import 'package:collection/collection.dart';
 
-class Tensor extends DelegatingList<DelegatingList<double>> {
-  Tensor(super.base);
-
-  factory Tensor.fromList(List<List<double>> lst) {
-    return Tensor(DelegatingList(lst.map((e) => DelegatingList(e)).toList()));
-  }
-
-  Tensor each(Function(double, int, int) f) {
-    Tensor other = Tensor([]);
-    for (int j = 0; j < length; ++j) {
-      other[j] = const DelegatingList([]);
-      for (int k = 0; k < this[j].length; ++k) {
-        other[j][k] = f(this[j][k], j, k);
-      }
-    }
-    return other;
-  }
+List<dynamic> dimensionalList(List<int> shape,
+    {dynamic fillValue = 0.0, dynamic Function(int)? generator}) {
+  if (shape.length == 1) {
+    if (generator != null) {
+      return List.generate(shape[0], generator);
+    }
+    return List.filled(shape[0], fillValue);
+  }
+  return List.generate(shape[0], (int i) {
+    return dimensionalList(shape.sublist(1),
+        fillValue: fillValue, generator: generator);
+  });
+}
+
+List<int> detectShape(List<dynamic> data) {
+  if (data.runtimeType != List) {
+    return [data.length];
+  }
+  return [data.length] + detectShape(data[0]);
+}
+
+List<dynamic> deepApply(
+    List<dynamic> data, dynamic Function(dynamic) callback) {
+  if (data[0].runtimeType != List) {
+    return data.map((d) {
+      return callback(d);
+    }).toList();
+  }
+  return data.map((d) {
+    return deepApply(d, callback);
+  }).toList();
+}
+
+List<dynamic> listBroadcast(
+    dynamic l1, dynamic l2, dynamic Function(dynamic, dynamic) callback) {
+  if (!(l1.runtimeType == List && l2.runtimeType == List)) {
+    return callback(l1, l2);
+  }
+  if (!detectShape(l1).equals(detectShape(l2))) {
+    throw Exception("l1 != l2");
+  }
+  return List.generate(l1.length, (int i) {
+    return listBroadcast(l1[i], l2[i], callback);
+  });
+}
+
+class Tensor {
+  List<dynamic> data = [];
+  List<int> shape = [];
+
+  Tensor(this.shape, this.data);
+
+  factory Tensor.fromShape(List<int> shape) {
+    return Tensor(shape, dimensionalList(shape));
+  }
 
-  /// Generate a random tensor of shape `shape`
-  static Tensor random(List<int> shape) {
-    Random r = Random();
-
-    int d1 = 0, d2 = 0;
-    if (shape.length == 1) {
-      d1 = shape[0];
-    } else if (shape.length == 2) {
-      d1 = shape[0];
-      d2 = shape[1];
-    } else if (shape.length == 3) {
-      // d3 = shapes[2];
-    }
-    Tensor ret = Tensor(List.filled(d1, DelegatingList(List.filled(d2, 0.0))));
-
-    for (int i = 0; i < d1; ++i) {
-      for (int j = 0; j < d2; ++j) {
-        ret[i][j] = r.nextDouble();
-      }
-    }
-
-    return ret;
-  }
+  factory Tensor.fromList(List<dynamic> data) {
+    return Tensor(detectShape(data), data);
+  }
+
+  factory Tensor.generate(List<int> shape, dynamic Function(int)? generator) {
+    return Tensor(
+        shape, dimensionalList(shape, fillValue: 0.0, generator: generator));
+  }
+
+  factory Tensor.random(List<int> shape) {
+    Random r = Random();
+    return Tensor.generate(shape, (int _) {
+      return r.nextDouble();
+    });
+  }
 
-  static Tensor stack(Tensor tensors, {int axis = 0}) {
-    if (axis < -1 || axis > tensors.length - 1) {
-      throw ArgumentError('Invalid axis value');
-    }
-
-    int newAxisSize = tensors.length;
-    for (var tensor in tensors) {
-      newAxisSize *= tensor.length;
-    }
-
-    Tensor result = Tensor([]);
-    for (int i = 0; i < newAxisSize; i++) {
-      int index = i;
-      int currentAxisIndex = axis;
-      List<int> currentAxisIndexes = List.filled(tensors.length, -1);
-      int currentTensorIndex = 0;
-
-      while (currentAxisIndexes[currentTensorIndex] < tensors.length) {
-        if (currentAxisIndexes[currentTensorIndex] == currentAxisIndex) {
-          index = currentAxisIndexes[currentTensorIndex] +
-              (index ~/ tensors.length);
-          currentAxisIndexes[currentTensorIndex]++;
-        }
-        currentAxisIndex += (axis > 0 ? 1 : -1);
-        currentTensorIndex =
-            (currentAxisIndex < 0 || currentTensorIndex >= tensors.length)
-                ? 0
-                : currentTensorIndex + 1;
-      }
-
-      result.add(tensors[
-          currentTensorIndex]); // Access tensors[currentTensorIndex] as a List<double> rather than using the index operator [] with it
-    }
-
-    return Tensor(result);
-  }
+  Tensor each(dynamic Function(dynamic) callback) {
+    return Tensor.fromList(deepApply(data, callback));
+  }
 
   operator *(Tensor other) {
-    return each((d, i, j) {
-      return [i][j] * other[i][j];
-    });
+    return Tensor.fromList(listBroadcast(data, other.data, (d1, d2) {
+      return d1 * d2;
+    }));
  }
 }
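One caveat worth flagging: detectShape, deepApply, and listBroadcast branch on runtimeType comparisons such as data.runtimeType != List, but a Dart list's runtime type is a parameterized type like List<double>, which does not equal the bare List type literal, so the recursive branches may not trigger as intended; core List also has no equals method, though package:collection provides ListEquality. A sketch of the same checks written with is tests, offered as a suggested correction rather than what the commit contains (the *Safe names are hypothetical):

import 'package:collection/collection.dart';

List<int> detectShapeSafe(List<dynamic> data) {
  // Recurse while the first element is itself a list.
  if (data.isNotEmpty && data[0] is List) {
    return [data.length] + detectShapeSafe(data[0] as List<dynamic>);
  }
  return [data.length];
}

dynamic listBroadcastSafe(
    dynamic l1, dynamic l2, dynamic Function(dynamic, dynamic) callback) {
  // Leaves: apply the callback to a pair of scalars.
  if (l1 is! List || l2 is! List) {
    return callback(l1, l2);
  }
  // Shape equality via ListEquality from package:collection.
  if (!const ListEquality<int>()
      .equals(detectShapeSafe(l1), detectShapeSafe(l2))) {
    throw Exception("shape mismatch");
  }
  return List.generate(l1.length, (int i) {
    return listBroadcastSafe(l1[i], l2[i], callback);
  });
}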
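Finally, the intended shape of the new API, assuming the helpers behave as described above; the import path is hypothetical and the snippet is a usage sketch, not part of the commit:

import 'tensor.dart';

void main() {
  // Shape is inferred from the nesting of the input list.
  Tensor a = Tensor.fromList([
    [1.0, 2.0],
    [3.0, 4.0],
  ]);

  // A zero-filled tensor and a uniformly random one of the same shape.
  Tensor zeros = Tensor.fromShape([2, 2]);
  Tensor b = Tensor.random([2, 2]);

  // each maps a callback over every element; note that the new signature no
  // longer passes row/column indices the way the removed version did.
  Tensor doubled = a.each((d) => d * 2);

  // operator * multiplies element-wise, throwing on mismatched shapes.
  Tensor c = a * b;
  print([zeros.shape, doubled.data, c.data]);
}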