From 6fcff57d685e8ca6e90658be265b63d554c11cb7 Mon Sep 17 00:00:00 2001
From: Jordan Hewitt
Date: Fri, 13 Dec 2024 17:18:12 -0800
Subject: [PATCH] add each and broadcast methods.

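Tensor now stores its values in a plain nested `data` list next to an
explicit `shape`; `each` maps a callback over every element and `*`
multiplies two tensors of the same shape element-wise. A rough usage
sketch (illustrative only; the variable names and values below are made
up, not taken from the tests):

    Tensor a = Tensor.fromList([
      [1.0, 2.0],
      [3.0, 4.0]
    ]);

    // Apply a callback to every element.
    Tensor scaled = a.each((d) => d * 2.0);

    // Element-wise product; both operands must have the same shape.
    Tensor product = a * Tensor.fromList([
      [2.0, 2.0],
      [2.0, 2.0]
    ]);
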
---
 lib/src/splat/tensor.dart           | 145 ++++++++++++++--------------
 lib/src/test/splat/tensor_test.dart |  89 ++++++++++++-----
 2 files changed, 137 insertions(+), 97 deletions(-)

diff --git a/lib/src/splat/tensor.dart b/lib/src/splat/tensor.dart
index e02efc8..790570b 100644
--- a/lib/src/splat/tensor.dart
+++ b/lib/src/splat/tensor.dart
@@ -2,88 +2,85 @@ import 'dart:math';
 
 import 'package:collection/collection.dart';
 
-class Tensor extends DelegatingList<DelegatingList<double>> {
-  Tensor(super.base);
-
-  factory Tensor.fromList(List<List<double>> lst) {
-    return Tensor(DelegatingList(lst.map((e) => DelegatingList(e)).toList()));
-  }
-
-  Tensor each(Function(double, int, int) f) {
-    Tensor other = Tensor([]);
-    for (int j = 0; j < length; ++j) {
-      other[j] = const DelegatingList([]);
-      for (int k = 0; k < this[j].length; ++k) {
-        other[j][k] = f(this[j][k], j, k);
-      }
+/// Builds a nested list of the given [shape], filled with [fillValue] or,
+/// at the innermost level, with values produced by [generator].
+List dimensionalList(List shape,
+    {dynamic fillValue = 0.0, dynamic Function(int)? generator}) {
+  if (shape.length == 1) {
+    if (generator != null) {
+      return List.generate(shape[0], generator);
     }
-    return other;
+    return List.filled(shape[0], fillValue);
+  }
+  return List.generate(shape[0], (int i) {
+    return dimensionalList(shape.sublist(1),
+        fillValue: fillValue, generator: generator);
+  });
+}
+
+/// Infers the shape of a nested list, e.g. a 2x3 list yields [2, 3].
+List detectShape(List data) {
+  if (data.isEmpty || data[0] is! List) {
+    return [data.length];
+  }
+  return [data.length, ...detectShape(data[0])];
+}
+
+/// Applies [callback] to every scalar element of a nested list while
+/// preserving the nesting structure.
+List deepApply(List data, dynamic Function(dynamic) callback) {
+  if (data.isEmpty || data[0] is! List) {
+    return data.map((d) {
+      return callback(d);
+    }).toList();
+  }
+  return data.map((d) {
+    return deepApply(d, callback);
+  }).toList();
+}
+
+/// Combines two nested lists of identical shape element by element
+/// using [callback].
+dynamic listBroadcast(
+    dynamic l1, dynamic l2, dynamic Function(dynamic, dynamic) callback) {
+  if (l1 is! List || l2 is! List) {
+    return callback(l1, l2);
+  }
+  if (!const ListEquality().equals(detectShape(l1), detectShape(l2))) {
+    throw ArgumentError(
+        'shape mismatch: ${detectShape(l1)} vs ${detectShape(l2)}');
+  }
+  return List.generate(l1.length, (int i) {
+    return listBroadcast(l1[i], l2[i], callback);
+  });
+}
+
+/// A minimal n-dimensional array backed by a nested [List].
+class Tensor {
+  List data = [];
+  List shape = [];
+
+  Tensor(this.shape, this.data);
+
+  factory Tensor.fromShape(List shape) {
+    return Tensor(shape, dimensionalList(shape));
   }
 
-  /// Generate a random tensor of shape `shape`
-  static Tensor random(List shape) {
+  factory Tensor.fromList(List data) {
+    return Tensor(detectShape(data), data);
+  }
+
+  factory Tensor.generate(List shape, dynamic Function(int)? generator) {
+    return Tensor(
+        shape, dimensionalList(shape, fillValue: 0.0, generator: generator));
+  }
+
+  factory Tensor.random(List shape) {
     Random r = Random();
-
-    int d1 = 0, d2 = 0;
-    if (shape.length == 1) {
-      d1 = shape[0];
-    } else if (shape.length == 2) {
-      d1 = shape[0];
-      d2 = shape[1];
-    } else if (shape.length == 3) {
-      // d3 = shapes[2];
-    }
-    Tensor ret = Tensor(List.filled(d1, DelegatingList(List.filled(d2, 0.0))));
-
-    for (int i = 0; i < d1; ++i) {
-      for (int j = 0; j < d2; ++j) {
-        ret[i][j] = r.nextDouble();
-      }
-    }
-
-    return ret;
+    return Tensor.generate(shape, (int _) {
+      return r.nextDouble();
+    });
   }
 
-  static Tensor stack(Tensor tensors, {int axis = 0}) {
-    if (axis < -1 || axis > tensors.length - 1) {
-      throw ArgumentError('Invalid axis value');
-    }
-
-    int newAxisSize = tensors.length;
-    for (var tensor in tensors) {
-      newAxisSize *= tensor.length;
-    }
-
-    Tensor result = Tensor([]);
-    for (int i = 0; i < newAxisSize; i++) {
-      int index = i;
-      int currentAxisIndex = axis;
-      List currentAxisIndexes = List.filled(tensors.length, -1);
-      int currentTensorIndex = 0;
-
-      while (currentAxisIndexes[currentTensorIndex] < tensors.length) {
-        if (currentAxisIndexes[currentTensorIndex] == currentAxisIndex) {
-          index = currentAxisIndexes[currentTensorIndex] +
-              (index ~/ tensors.length);
-          currentAxisIndexes[currentTensorIndex]++;
-        }
-        currentAxisIndex += (axis > 0 ? 1 : -1);
-        currentTensorIndex =
-            (currentAxisIndex < 0 || currentTensorIndex >= tensors.length)
-                ? 0
-                : currentTensorIndex + 1;
-      }
-
-      result.add(tensors[
-          currentTensorIndex]); // Access tensors[currentTensorIndex] as a List rather than using the index operator [] with it
-    }
-
-    return Tensor(result);
+  /// Returns a new tensor with [callback] applied to every element.
+  Tensor each(dynamic Function(dynamic) callback) {
+    return Tensor.fromList(deepApply(data, callback));
   }
 
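+  /// Element-wise product computed with [listBroadcast]; both operands
+  /// must have the same shape.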
   operator *(Tensor other) {
-    return each((d, i, j) {
-      return [i][j] * other[i][j];
-    });
+    return Tensor.fromList(listBroadcast(data, other.data, (d1, d2) {
+      return d1 * d2;
+    }));
   }
 }
diff --git a/lib/src/test/splat/tensor_test.dart b/lib/src/test/splat/tensor_test.dart
index d83952f..52d8517 100644
--- a/lib/src/test/splat/tensor_test.dart
+++ b/lib/src/test/splat/tensor_test.dart
@@ -1,36 +1,79 @@
+import 'dart:math';
+
 import 'package:archimedes_test/src/splat/tensor.dart';
 import 'package:collection/collection.dart';
 import 'package:test/test.dart';
 
 void main() {
-  group('stack', () {
-    test('A random tensor can be generated', () {
-      Tensor generated = Tensor.random([2, 4]);
-      expect(generated.length, equals(2));
-      expect(generated[0].length, equals(4));
-    });
+  group("tensor", () {
+    test("Tensors can be constructed from a shape", () {
+      Tensor t1 = Tensor.fromShape([2]);
 
-    test("A tensor can be stacked", () {
-      Tensor x = Tensor(const DelegatingList([
-        DelegatingList([0.3367, 0.1288, 0.2345]),
-        DelegatingList([0.2303, -1.1229, -0.1863])
-      ]));
+      expect(t1.data, equals([0, 0]));
+
+      Tensor t2 = Tensor.fromShape([2, 3, 4]);
 
-      expect(Tensor.stack(x)[0], equals(x));
       expect(
-          Tensor.stack(x),
-          equals(Tensor.fromList([
+          t2.data,
+          equals([
             [
-              [
-                [0.3367, 0.1288, 0.2345],
-                [0.3367, 0.1288, 0.2345],
-              ],
-              [
-                [0.2303, -1.1229, -0.1863],
-                [0.2303, -1.1229, -0.1863]
-              ]
+              [0.0, 0.0, 0.0, 0.0],
+              [0.0, 0.0, 0.0, 0.0],
+              [0.0, 0.0, 0.0, 0.0]
+            ],
+            [
+              [0.0, 0.0, 0.0, 0.0],
+              [0.0, 0.0, 0.0, 0.0],
+              [0.0, 0.0, 0.0, 0.0]
             ]
-          ])));
+          ]));
     });
   });
+  test("Tensors can be constructed from a list", () {
+    Tensor t1 = Tensor.fromList([
+      [1.0, 2.0, 3.0],
+      [1.0, 2.0, 3.0]
+    ]);
+
+    expect(t1.shape, equals([2, 3]));
+  });
+
+  test("Tensor can be generated", () {
+    Random r = Random();
+    Tensor t = Tensor.generate([1, 5], (int i) {
+      return r.nextDouble();
+    });
+    expect(t.data, isNotNull);
+  });
+
+  test("Function can be applied", () {
+    Random r = Random();
+    Tensor t = Tensor.generate([1, 2], (int i) {
+      return r.nextDouble();
+    });
+
+    Tensor t2 = t.each((dynamic d) {
+      return d * 100.0;
+    });
+
+    // The shape is [1, 2], so check every element of the single row.
+    for (int i = 0; i < t2.data[0].length; ++i) {
+      expect(0 <= t2.data[0][i] && t2.data[0][i] < 100, isTrue);
+    }
+  });
+
+  test("Can broadcast 2 tensors", () {
+    Tensor t1 = Tensor.fromList([
+      [1, 2, 3],
+      [4, 5, 6]
+    ]);
+    Tensor t2 = Tensor.fromList([
+      [2, 2, 2],
+      [2, 2, 2]
+    ]);
+    Tensor expected = Tensor.fromList([
+      [2, 4, 6],
+      [8, 10, 12]
+    ]);
+    // Tensor does not override ==, so compare the underlying data lists.
+    expect((t1 * t2).data, equals(expected.data));
+  });
 }