add each and broadcast methods.
This commit is contained in:
parent
859f540f4a
commit
6fcff57d68
@ -2,88 +2,85 @@ import 'dart:math';
|
|||||||
|
|
||||||
import 'package:collection/collection.dart';
|
import 'package:collection/collection.dart';
|
||||||
|
|
||||||
class Tensor extends DelegatingList<DelegatingList<double>> {
|
List<dynamic> dimensionalList(List<int> shape,
|
||||||
Tensor(super.base);
|
{dynamic fillValue = 0.0, dynamic Function(int)? generator}) {
|
||||||
|
|
||||||
factory Tensor.fromList(List<List<double>> lst) {
|
|
||||||
return Tensor(DelegatingList(lst.map((e) => DelegatingList(e)).toList()));
|
|
||||||
}
|
|
||||||
|
|
||||||
Tensor each(Function(double, int, int) f) {
|
|
||||||
Tensor other = Tensor([]);
|
|
||||||
for (int j = 0; j < length; ++j) {
|
|
||||||
other[j] = const DelegatingList([]);
|
|
||||||
for (int k = 0; k < this[j].length; ++k) {
|
|
||||||
other[j][k] = f(this[j][k], j, k);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return other;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generate a random tensor of shape `shape`
|
|
||||||
static Tensor random(List<int> shape) {
|
|
||||||
Random r = Random();
|
|
||||||
|
|
||||||
int d1 = 0, d2 = 0;
|
|
||||||
if (shape.length == 1) {
|
if (shape.length == 1) {
|
||||||
d1 = shape[0];
|
if (generator != null) {
|
||||||
} else if (shape.length == 2) {
|
return List.generate(shape[0], generator);
|
||||||
d1 = shape[0];
|
|
||||||
d2 = shape[1];
|
|
||||||
} else if (shape.length == 3) {
|
|
||||||
// d3 = shapes[2];
|
|
||||||
}
|
}
|
||||||
Tensor ret = Tensor(List.filled(d1, DelegatingList(List.filled(d2, 0.0))));
|
return List.filled(shape[0], fillValue);
|
||||||
|
|
||||||
for (int i = 0; i < d1; ++i) {
|
|
||||||
for (int j = 0; j < d2; ++j) {
|
|
||||||
ret[i][j] = r.nextDouble();
|
|
||||||
}
|
}
|
||||||
|
return List.generate(shape[0], (int i) {
|
||||||
|
return dimensionalList(shape.sublist(1),
|
||||||
|
fillValue: fillValue, generator: generator);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
return ret;
|
List<int> detectShape(List<dynamic> data) {
|
||||||
|
if (data.runtimeType != List) {
|
||||||
|
return [data.length];
|
||||||
|
}
|
||||||
|
return [data.length] + detectShape(data[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
static Tensor stack(Tensor tensors, {int axis = 0}) {
|
List<dynamic> deepApply(
|
||||||
if (axis < -1 || axis > tensors.length - 1) {
|
List<dynamic> data, dynamic Function(dynamic) callback) {
|
||||||
throw ArgumentError('Invalid axis value');
|
if (data[0].runtimeType != List) {
|
||||||
|
return data.map((d) {
|
||||||
|
return callback(d);
|
||||||
|
}).toList();
|
||||||
|
}
|
||||||
|
return data.map((d) {
|
||||||
|
return deepApply(d, callback);
|
||||||
|
}).toList();
|
||||||
}
|
}
|
||||||
|
|
||||||
int newAxisSize = tensors.length;
|
List<dynamic> listBroadcast(
|
||||||
for (var tensor in tensors) {
|
dynamic l1, dynamic l2, dynamic Function(dynamic, dynamic) callback) {
|
||||||
newAxisSize *= tensor.length;
|
if (!(l1.runtimeType == List && l2.runtimeType == List)) {
|
||||||
|
return callback(l1, l2);
|
||||||
|
}
|
||||||
|
if (!detectShape(l1).equals(detectShape(l2))) {
|
||||||
|
throw Exception("l1 != l2");
|
||||||
|
}
|
||||||
|
return List.generate(l1.length, (int i) {
|
||||||
|
return listBroadcast(l1[i], l2[i], callback);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor result = Tensor([]);
|
class Tensor {
|
||||||
for (int i = 0; i < newAxisSize; i++) {
|
List<dynamic> data = [];
|
||||||
int index = i;
|
List<int> shape = [];
|
||||||
int currentAxisIndex = axis;
|
|
||||||
List<int> currentAxisIndexes = List.filled(tensors.length, -1);
|
|
||||||
int currentTensorIndex = 0;
|
|
||||||
|
|
||||||
while (currentAxisIndexes[currentTensorIndex] < tensors.length) {
|
Tensor(this.shape, this.data);
|
||||||
if (currentAxisIndexes[currentTensorIndex] == currentAxisIndex) {
|
|
||||||
index = currentAxisIndexes[currentTensorIndex] +
|
factory Tensor.fromShape(List<int> shape) {
|
||||||
(index ~/ tensors.length);
|
return Tensor(shape, dimensionalList(shape));
|
||||||
currentAxisIndexes[currentTensorIndex]++;
|
|
||||||
}
|
|
||||||
currentAxisIndex += (axis > 0 ? 1 : -1);
|
|
||||||
currentTensorIndex =
|
|
||||||
(currentAxisIndex < 0 || currentTensorIndex >= tensors.length)
|
|
||||||
? 0
|
|
||||||
: currentTensorIndex + 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
result.add(tensors[
|
factory Tensor.fromList(List<dynamic> data) {
|
||||||
currentTensorIndex]); // Access tensors[currentTensorIndex] as a List<double> rather than using the index operator [] with it
|
return Tensor(detectShape(data), data);
|
||||||
}
|
}
|
||||||
|
|
||||||
return Tensor(result);
|
factory Tensor.generate(List<int> shape, dynamic Function(int)? generator) {
|
||||||
|
return Tensor(
|
||||||
|
shape, dimensionalList(shape, fillValue: 0.0, generator: generator));
|
||||||
|
}
|
||||||
|
|
||||||
|
factory Tensor.random(List<int> shape) {
|
||||||
|
Random r = Random();
|
||||||
|
return Tensor.generate(shape, (int _) {
|
||||||
|
return r.nextDouble();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor each(dynamic Function(dynamic) callback) {
|
||||||
|
return Tensor.fromList(deepApply(data, callback));
|
||||||
}
|
}
|
||||||
|
|
||||||
operator *(Tensor other) {
|
operator *(Tensor other) {
|
||||||
return each((d, i, j) {
|
return Tensor.fromList(listBroadcast(data, other.data, (d1, d2) {
|
||||||
return [i][j] * other[i][j];
|
return d1 * d2;
|
||||||
});
|
}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,36 +1,79 @@
|
|||||||
|
import 'dart:math';
|
||||||
|
|
||||||
import 'package:archimedes_test/src/splat/tensor.dart';
|
import 'package:archimedes_test/src/splat/tensor.dart';
|
||||||
import 'package:collection/collection.dart';
|
import 'package:collection/collection.dart';
|
||||||
import 'package:test/test.dart';
|
import 'package:test/test.dart';
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
group('stack', () {
|
group("tensor", () {
|
||||||
test('A random tensor can be generated', () {
|
test("Tensors can be constructed from a shape", () {
|
||||||
Tensor generated = Tensor.random([2, 4]);
|
Tensor t1 = Tensor.fromShape([2]);
|
||||||
expect(generated.length, equals(2));
|
|
||||||
expect(generated[0].length, equals(4));
|
|
||||||
});
|
|
||||||
|
|
||||||
test("A tensor can be stacked", () {
|
expect(t1.data, equals([0, 0]));
|
||||||
Tensor x = Tensor(const DelegatingList([
|
|
||||||
DelegatingList([0.3367, 0.1288, 0.2345]),
|
Tensor t2 = Tensor.fromShape([2, 3, 4]);
|
||||||
DelegatingList([0.2303, -1.1229, -0.1863])
|
|
||||||
]));
|
|
||||||
|
|
||||||
expect(Tensor.stack(x)[0], equals(x));
|
|
||||||
expect(
|
expect(
|
||||||
Tensor.stack(x),
|
t2.data,
|
||||||
equals(Tensor.fromList([
|
equals([
|
||||||
[
|
[
|
||||||
[
|
[0.0, 0.0, 0.0, 0.0],
|
||||||
[0.3367, 0.1288, 0.2345],
|
[0.0, 0.0, 0.0, 0.0],
|
||||||
[0.3367, 0.1288, 0.2345],
|
[0.0, 0.0, 0.0, 0.0]
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
[0.2303, -1.1229, -0.1863],
|
[0.0, 0.0, 0.0, 0.0],
|
||||||
[0.2303, -1.1229, -0.1863]
|
[0.0, 0.0, 0.0, 0.0],
|
||||||
|
[0.0, 0.0, 0.0, 0.0]
|
||||||
]
|
]
|
||||||
]
|
]));
|
||||||
])));
|
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
test("Tensors can be constructed from a list", () {
|
||||||
|
Tensor t1 = Tensor.fromList([
|
||||||
|
[1.0, 2.0, 3.0],
|
||||||
|
[1.0, 2.0, 3.0]
|
||||||
|
]);
|
||||||
|
|
||||||
|
expect(t1.shape, equals([2, 3]));
|
||||||
|
});
|
||||||
|
|
||||||
|
test("Tensor can be generated", () {
|
||||||
|
Random r = Random();
|
||||||
|
Tensor t = Tensor.generate([1, 5], (int i) {
|
||||||
|
return r.nextDouble();
|
||||||
|
});
|
||||||
|
expect(t.data, isNotNull);
|
||||||
|
});
|
||||||
|
|
||||||
|
test("Function can be applied", () {
|
||||||
|
Random r = Random();
|
||||||
|
Tensor t = Tensor.generate([1, 2], (int i) {
|
||||||
|
return r.nextDouble();
|
||||||
|
});
|
||||||
|
|
||||||
|
Tensor t2 = t.each((dynamic d) {
|
||||||
|
return d * 100.0;
|
||||||
|
});
|
||||||
|
|
||||||
|
for (int i = 0; i < t2.data.length; ++i) {
|
||||||
|
expect(0 < t2.data[0][i] && t2.data[0][i] < 100, isTrue);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
test("Can broadcast 2 tensors", () {
|
||||||
|
Tensor t1 = Tensor.fromList([
|
||||||
|
[1, 2, 3],
|
||||||
|
[4, 5, 6]
|
||||||
|
]);
|
||||||
|
Tensor t2 = Tensor.fromList([
|
||||||
|
[2, 2, 2],
|
||||||
|
[2, 2, 2]
|
||||||
|
]);
|
||||||
|
Tensor expected = Tensor.fromList([
|
||||||
|
[2, 4, 6],
|
||||||
|
[8, 10, 12]
|
||||||
|
]);
|
||||||
|
expect(t1 * t2, equals(expected));
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user