From 3e7e4e0aed5f9375612d582911674bd6e44e1e58 Mon Sep 17 00:00:00 2001 From: Marcel Luethi Date: Wed, 1 Jul 2026 07:08:03 +0200 Subject: [PATCH] add api doc for all tensorops --- .../scala/dimwit/autodiff/TensorTree.scala | 6 +- .../src/main/scala/dimwit/tensor/Tensor.scala | 3 +- .../tensor/tensorops/ContractionOps.scala | 5 +- .../tensor/tensorops/ConvolutionOps.scala | 55 +++- .../tensor/tensorops/ElementWiseOps.scala | 64 +++- .../tensor/tensorops/FunctionalOps.scala | 67 ++++- .../tensor/tensorops/LinearAlgebraOps.scala | 60 +++- .../tensor/tensorops/ReductionOps.scala | 66 ++-- .../tensor/tensorops/StructuralOps.scala | 284 ++++++++++++------ .../dimwit/tensor/tensorops/Tensor0Ops.scala | 21 ++ .../dimwit/tensor/tensorops/Tensor1Ops.scala | 4 + .../dimwit/tensor/tensorops/Tensor2Ops.scala | 8 +- 12 files changed, 481 insertions(+), 162 deletions(-) diff --git a/core/src/main/scala/dimwit/autodiff/TensorTree.scala b/core/src/main/scala/dimwit/autodiff/TensorTree.scala index 90fcfb8..99eca73 100644 --- a/core/src/main/scala/dimwit/autodiff/TensorTree.scala +++ b/core/src/main/scala/dimwit/autodiff/TensorTree.scala @@ -2,12 +2,12 @@ package dimwit.autodiff import dimwit.jax.Jax import dimwit.tensor.* -import scala.compiletime.* -import scala.deriving.* - import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters +import scala.compiletime.* +import scala.deriving.* + /** A typeclass for structures that can be represented as a tree of tensors, * which can be mapped over. Most often, a tensor tree is used to structure * parameters of a model. diff --git a/core/src/main/scala/dimwit/tensor/Tensor.scala b/core/src/main/scala/dimwit/tensor/Tensor.scala index be29a81..2a88246 100644 --- a/core/src/main/scala/dimwit/tensor/Tensor.scala +++ b/core/src/main/scala/dimwit/tensor/Tensor.scala @@ -46,7 +46,8 @@ class Tensor[T <: Tuple: Labels, V] private[dimwit] ( /** The device on which the tensor is stored. 
*/ lazy val device: Device = Device(jaxValue.device) - /** Converts the tensor to a different value type, if compatible. */ + /** Converts the tensor to the given vtype. + */ def asType[V2](vtype: VType[V2]): Tensor[T, V2] = new Tensor(Jax.jnp.astype(jaxValue, JaxDType.jaxDtype(vtype.dtype))) /** Moves the tensor to a different device. */ diff --git a/core/src/main/scala/dimwit/tensor/tensorops/ContractionOps.scala b/core/src/main/scala/dimwit/tensor/tensorops/ContractionOps.scala index 9193e3e..e38429f 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/ContractionOps.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/ContractionOps.scala @@ -12,6 +12,9 @@ import me.shadaj.scalapy.readwrite.Writer import scala.annotation.targetName +/** Provides extension methods for tensor contraction operations, + * including outer products and dot products. + */ object ContractionOps: extension [T <: Tuple: Labels, V](tensor: Tensor[T, V]) @@ -62,7 +65,7 @@ object ContractionOps: * {{{ * val t1: Tensor[("A", "B", "C"), Float] = ??? * val t2: Tensor[("D", "E, "F), Float] = ??? - * val result = t1.dot(Axis["B" ~ "F])(t2) + * val result = t1.dot(Axis[A]->Axis[D])(t2) * }}} */ @targetName("dotOn") diff --git a/core/src/main/scala/dimwit/tensor/tensorops/ConvolutionOps.scala b/core/src/main/scala/dimwit/tensor/tensorops/ConvolutionOps.scala index 5f5f1ff..1656d25 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/ConvolutionOps.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/ConvolutionOps.scala @@ -13,8 +13,19 @@ import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters import me.shadaj.scalapy.readwrite.Writer +/** Provides extension methods for convolution operations on tensors. + * Convolution operations are restricted to 1D, 2D and 3D convolutions, + * and support both standard and transposed convolutions. + */ object ConvolutionOps: + /** Padding options for convolution operations. 
+ * SAME: Output size is the same as input size (with appropriate padding). + * VALID: No padding, output size is reduced based on kernel size. + * + * Refer to JAX documentation for more details on padding behavior. + * https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_general_dilated.html + */ enum Padding: case SAME, VALID @@ -22,8 +33,15 @@ object ConvolutionOps: extension [S1: Label, InChannel: Label, V: IsFloating](input: Tensor[S1 *: InChannel *: EmptyTuple, V]) + /** Computes the 1D convolution of this tensor with the specified kernel tensor. + * + * @param kernel - The convolution kernel + * @param stride - Stride for the convolution. + * @param padding - Padding mode for the convolution. + * @return A new tensor representing the result of the convolution operation. + */ def conv1d[OutChannel: Label]( - kernel: Tensor[S1 *: InChannel *: OutChannel *: EmptyTuple, V], + kernel: Tensor[(S1, InChannel, OutChannel), V], stride: Stride1[S1] | Int = 1, padding: Padding = Padding.SAME ): Tensor[S1 *: OutChannel *: EmptyTuple, V] = @@ -48,6 +66,13 @@ object ConvolutionOps: extension [S1: Label, OutChannel: Label, V: IsFloating](input: Tensor[S1 *: OutChannel *: EmptyTuple, V]) + /** Computes the transposed 1D convolution of + * this tensor with the specified kernel tensor. + * @param kernel - The convolution kernel + * @param stride - Stride for the convolution. + * @param padding - Padding mode for the convolution. + * @return A new tensor representing the result of the transposed convolution operation. + */ def transposeConv1d[InChannel: Label]( kernel: Tensor[S1 *: InChannel *: OutChannel *: EmptyTuple, V], stride: Stride1[S1] | Int = 1, @@ -80,6 +105,13 @@ object ConvolutionOps: extension [S1: Label, S2: Label, InChannel: Label, V: IsFloating](input: Tensor[S1 *: S2 *: InChannel *: EmptyTuple, V]) + /** Computes the 2D convolution of this tensor with the specified kernel tensor. 
+ * + * @param kernel - The convolution kernel tensor with shape (S1, S2, InChannel, OutChannel). + * @param stride - Stride for the convolution. + * @param padding - Padding mode for the convolution. + * @return A new tensor representing the result of the convolution operation. + */ def conv2d[OutChannel: Label]( kernel: Tensor[S1 *: S2 *: InChannel *: OutChannel *: EmptyTuple, V], stride: Stride2[S1, S2] | Int = 1, @@ -106,6 +138,13 @@ object ConvolutionOps: extension [S1: Label, S2: Label, OutChannel: Label, V: IsFloating](input: Tensor[S1 *: S2 *: OutChannel *: EmptyTuple, V]) + /** Computes the transposed 2D convolution of this tensor with the specified kernel tensor. + * + * @param kernel - The convolution kernel tensor with shape (S1, S2, InChannel, OutChannel). + * @param stride - Stride for the convolution. + * @param padding - Padding mode for the convolution. + * @return A new tensor representing the result of the transposed convolution operation. + */ def transposeConv2d[InChannel: Label]( kernel: Tensor[S1 *: S2 *: InChannel *: OutChannel *: EmptyTuple, V], stride: Stride2[S1, S2] | Int = 1, @@ -141,6 +180,13 @@ object ConvolutionOps: extension [S1: Label, S2: Label, S3: Label, InChannel: Label, V: IsFloating](input: Tensor[S1 *: S2 *: S3 *: InChannel *: EmptyTuple, V]) + /** Computes the 3D convolution of this tensor with the specified kernel tensor. + * + * @param kernel - The convolution kernel tensor + * @param stride - Stride for the convolution. + * @param padding - Padding mode for the convolution. + * @return A new tensor representing the result of the convolution operation. 
+ */ def conv3d[OutChannel: Label]( kernel: Tensor[S1 *: S2 *: S3 *: InChannel *: OutChannel *: EmptyTuple, V], stride: Stride3[S1, S2, S3] | Int = 1, @@ -169,6 +215,13 @@ object ConvolutionOps: extension [S1: Label, S2: Label, S3: Label, OutChannel: Label, V: IsFloating](input: Tensor[S1 *: S2 *: S3 *: OutChannel *: EmptyTuple, V]) + /** Computes the transposed 3D convolution of this tensor with the specified kernel tensor. + * + * @param kernel - The convolution kernel tensor + * @param stride - Stride for the convolution. + * @param padding - Padding mode for the convolution. + * @return A new tensor representing the result of the transposed convolution operation. + */ def transposeConv3d[InChannel: Label]( kernel: Tensor[S1 *: S2 *: S3 *: InChannel *: OutChannel *: EmptyTuple, V], stride: Stride3[S1, S2, S3] | Int = 1, diff --git a/core/src/main/scala/dimwit/tensor/tensorops/ElementWiseOps.scala b/core/src/main/scala/dimwit/tensor/tensorops/ElementWiseOps.scala index 1a30cba..4105633 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/ElementWiseOps.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/ElementWiseOps.scala @@ -15,8 +15,9 @@ import dimwit.tensor.VType import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast object ElementWiseOps: + // --------------------------------------------------------- - // General operations + // General operations on any tensor type // --------------------------------------------------------- /** Elementwise maximum of two tensors. */ @@ -25,6 +26,7 @@ object ElementWiseOps: /** Elementwise minimum of two tensors. 
*/ def minimum[T <: Tuple: Labels, V](t1: Tensor[T, V], t2: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.minimum(t1.jaxValue, t2.jaxValue)) + // extension methods for comparisons extension [T <: Tuple: Labels, V](t: Tensor[T, V]) def <(other: Tensor[T, V]): Tensor[T, Bool] = Tensor(Jax.jnp.less(t.jaxValue, other.jaxValue)) @@ -40,33 +42,64 @@ object ElementWiseOps: require(t.shape.dimensions == other.shape.dimensions, s"Shape mismatch: ${t.shape.dimensions} vs ${other.shape.dimensions}") Tensor(jaxValue = Jax.jnp.equal(t.jaxValue, other.jaxValue)) + /** Casts the elements of this tensor to a tensor of type Bool. */ def asBool: Tensor[T, Bool] = t.asType(VType[Bool]) + + /** Casts the elements of this tensor to a tensor of the given boolean type. + * @param vtype the type to cast to + */ def asBoolean[NewV: IsBoolean](vtype: VType[NewV]): Tensor[T, NewV] = t.asType(vtype) + + /** Cast the elements of this tensor to a tensor of type Int32. */ def asInt32: Tensor[T, Int32] = t.asType(VType[Int32]) + + /** Casts the elements of this tensor to a tensor of the given integer type. + * + * @param vtype - the type to cast to + */ def asInt[NewV: IsInteger](vtype: VType[NewV]): Tensor[T, NewV] = t.asType(vtype) + + /** Casts the elements of this tensor to a tensor of type Float32. */ def asFloat32: Tensor[T, Float32] = t.asType(VType[Float32]) - def asFloat[NewV: IsFloating](vtype: VType[NewV]): Tensor[T, NewV] = t.asType(vtype) - // --------------------------------------------------------- - // IsNumber operations (IsFloat or IsInt) - // --------------------------------------------------------- + /** Casts the elements of this tensor to a tensor of the given floating point type. + * + * @param vtype - the type to cast to + */ + def asFloat[NewV: IsFloating](vtype: VType[NewV]): Tensor[T, NewV] = t.asType(vtype) + /** Performs element-wise addition of two tensors of the same shape and type. 
*/ def add[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], t2: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.add(t1.jaxValue, t2.jaxValue)) - def addScalar[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], t2: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.add(t1.jaxValue, t2.jaxValue)) + /** Adds a scalar tensor to each element of a tensor. */ + def addScalar[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], s: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.add(t1.jaxValue, s.jaxValue)) + + /** Returns a new tensor with each element negated. */ def negate[T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.negative(t.jaxValue)) + + /** Subtracts one tensor from another of the same shape and type, returning a new tensor. */ def subtract[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], t2: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.subtract(t1.jaxValue, t2.jaxValue)) - def subtractScalar[T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V], t2: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.subtract(t.jaxValue, t2.jaxValue)) - def multiply[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], t2: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.multiply(t1.jaxValue, t2.jaxValue)) - def multiplyScalar[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], t2: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.multiply(t1.jaxValue, t2.jaxValue)) + /** Subtracts a scalar tensor from each element of a tensor. */ + def subtractScalar[T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V], s: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.subtract(t.jaxValue, s.jaxValue)) + + /** Multiplies two tensors of the same shape and type element-wise, returning a new tensor. */ + def multiply[T <: Tuple: Labels, V: IsNumber]( + t1: Tensor[T, V], + t2: Tensor[T, V] + ): Tensor[T, V] = Tensor(Jax.jnp.multiply(t1.jaxValue, t2.jaxValue)) + + /** Multiplies each element of a tensor by a scalar tensor, returning a new tensor. 
*/ + def multiplyScalar[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], s: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.multiply(t1.jaxValue, s.jaxValue)) + // extension methods for the binary operations on two tensors extension [T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V]) def +(other: Tensor[T, V]): Tensor[T, V] = add(t, other) def -(other: Tensor[T, V]): Tensor[T, V] = subtract(t, other) def *(other: Tensor[T, V]): Tensor[T, V] = multiply(t, other) + // extension methods for the scalar operations. extension [T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V]) def +![O <: Tuple](other: Tensor[O, V])(using bc: Broadcast[T, O, V]): Tensor[bc.Out, V] = bc.applyTo(t, other)(add) @@ -77,18 +110,24 @@ object ElementWiseOps: def *![O <: Tuple](other: Tensor[O, V])(using bc: Broadcast[T, O, V]): Tensor[bc.Out, V] = bc.applyTo(t, other)(multiply) def scale(other: Tensor0[V]): Tensor[T, V] = multiplyScalar(t, other) + // extension methods + extension [T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V]) def abs: Tensor[T, V] = Tensor(Jax.jnp.abs(t.jaxValue)) def sign: Tensor[T, V] = Tensor(Jax.jnp.sign(t.jaxValue)) def clip(min: Tensor0[V], max: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.clip(t.jaxValue, min.jaxValue, max.jaxValue)) def pow(n: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.power(t.jaxValue, n.jaxValue)) // --------------------------------------------------------- - // IsFloat operations + // Operations on Floating tensors // --------------------------------------------------------- + /** Divides two tensors of the same shape and type element-wise, returning a new tensor. */ def divide[T <: Tuple: Labels, V: IsFloating](t1: Tensor[T, V], t2: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.divide(t1.jaxValue, t2.jaxValue)) + + /** Divides each element of a tensor by a scalar tensor, returning a new tensor. 
*/ def divideScalar[T <: Tuple: Labels, V: IsFloating](t1: Tensor[T, V], t2: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.divide(t1.jaxValue, t2.jaxValue)) + // extension methods on floating tensors extension [T <: Tuple: Labels, V: IsFloating](t: Tensor[T, V]) def /(other: Tensor[T, V]): Tensor[T, V] = divide(t, other) @@ -115,10 +154,13 @@ object ElementWiseOps: // --------------------------------------------------------- // IsBoolean operations // --------------------------------------------------------- - extension [T <: Tuple: Labels, V: IsBoolean](t: Tensor[T, V]) + /** returns true if all elements of the tensor are true, false otherwise */ def all: Tensor0[V] = Tensor0(Jax.jnp.all(t.jaxValue)) + + /** returns true if any element of the tensor is true, false otherwise */ def any: Tensor0[V] = Tensor0(Jax.jnp.any(t.jaxValue)) + /** returns a tensor of the same shape with each element negated (logical NOT) */ def unary_! : Tensor[T, V] = Tensor(Jax.jnp.logical_not(t.jaxValue)) diff --git a/core/src/main/scala/dimwit/tensor/tensorops/FunctionalOps.scala b/core/src/main/scala/dimwit/tensor/tensorops/FunctionalOps.scala index 8e0aa04..eec4b3d 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/FunctionalOps.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/FunctionalOps.scala @@ -18,10 +18,6 @@ import me.shadaj.scalapy.readwrite.Reader import me.shadaj.scalapy.readwrite.Writer object FunctionalOps: - // ----------------------------------------------------------- - // 5. Functional Operations (Higher Order) - // Lifting functions over axes - // ----------------------------------------------------------- object ZipVmap: @@ -38,6 +34,23 @@ object FunctionalOps: type ShapesOf[Tensors <: Tuple] = Tuple.Map[Tensors, ExtractShape] type ValuesOf[Tensors <: Tuple] = Tuple.Map[Tensors, ExtractValue] + /** Zips the given tensors along the specified axis + * and applies the function `f` to the zipped tensors. 
+ * + * @param axis The axis along which to zip the tensors. + * @param tensors A tuple of tensors to be zipped. + * @param f A function that takes a tuple of tensors (with the specified axis removed) and returns a new tensor. + * @return A new tensor resulting from applying `f` + * + * Example usage: + * {{{ + * val tensor1: Tensor[(A, B), Int] = ... + * val tensor2: Tensor[(A, B), Int] = ... + * val result: Tensor[(A, C), Int] = ZipVmap.zipvmap(Axis[A])(tensor1, tensor2) { case (t1, t2) => + * // Perform operations on t1 and t2, which are tensors with axis A removed, and return a new tensor + * ... + * } + */ def zipvmap[L: Label, Inputs <: Tuple, OutShape <: Tuple: Labels, OutV]( axis: Axis[L] )( @@ -69,6 +82,29 @@ object FunctionalOps: extension [T <: Tuple: Labels, V](t: Tensor[T, V]) + /** Zips the current tensor with another tensor along the specified axis + * and applies the function `f` to the zipped tensors. + * + * @param axis The axis along which to zip the tensors. + * @param other The other tensor to be zipped with the current tensor. + * @param f A function that takes a tuple of tensors (with the specified axis removed) and returns a new tensor. + * @return A new tensor resulting from applying `f` to the zipped tensors. + */ + def zipvmap[L: Label, T2 <: Tuple, OutShape <: Tuple: Labels, OutV](axis: Axis[L])( + other: Tensor[T2, V] + )(using + ev: SharedAxisRemover[(T, T2), L] + )( + f: TensorsOf[ev.RemainingAxes, (V, V)] => Tensor[OutShape, OutV] + ): Tensor[L *: OutShape, OutV] = + ZipVmap.zipvmap(axis)(t, other)(f) + + /** Vectorized mapping over a specified axis of the tensor. + * + * @param axis The axis along which to apply the function `f`. + * @param f A function that takes a tensor with the specified axis removed and returns a new tensor. + * @return A new tensor resulting from applying `f` to each slice along the specified axis. 
+ */ def vmap[VmapAxis: Label, OuterShape <: Tuple: Labels, V2]( axis: Axis[VmapAxis] )(using @@ -86,6 +122,13 @@ object FunctionalOps: Tensor(Jax.jax_helper.vmap(fpy, ev.index)(t.jaxValue)) + /** Apply a function independently to each 1D slice along a labeled axis. + * + * @param axis The axis along which to apply the function `f`. + * @param f A function f that is applied to each L-axis slice; it may rename that axis to NewL and change the element type. + * + * @return A new tensor resulting from applying `f` to each slice along the specified axis. + */ def vapply[L: Label, NewL, R <: Tuple, NewV]( axis: Axis[L] )( @@ -108,15 +151,13 @@ object FunctionalOps: ) ) - def zipvmap[L: Label, T2 <: Tuple, OutShape <: Tuple: Labels, OutV](axis: Axis[L])( - other: Tensor[T2, V] - )(using - ev: SharedAxisRemover[(T, T2), L] - )( - f: TensorsOf[ev.RemainingAxes, (V, V)] => Tensor[OutShape, OutV] - ): Tensor[L *: OutShape, OutV] = - ZipVmap.zipvmap(axis)(t, other)(f) - + /** Reduce a tensor along a labeled axis by applying a function to each 1D slice. + * + * @param axis The axis along which to reduce the tensor. + * @param f A function f that is applied to each L-axis slice; it must return a scalar (Tensor0). + * + * @return A new tensor resulting from reducing the specified axis. + */ def vreduce[L: Label]( axis: Axis[L] )( diff --git a/core/src/main/scala/dimwit/tensor/tensorops/LinearAlgebraOps.scala b/core/src/main/scala/dimwit/tensor/tensorops/LinearAlgebraOps.scala index 0616325..db52ae1 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/LinearAlgebraOps.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/LinearAlgebraOps.scala @@ -19,6 +19,14 @@ object LinearAlgebraOps: extension [T <: Tuple: Labels, V](t: Tensor[T, V]) + /** Extracts the diagonal along the given two axes (with optional offset), + * replacing them by a new 1D axis labeled L1. + * + * @param axis1 The first axis along which to extract the diagonal. 
+ * @param axis2 The second axis along which to extract the diagonal. + * @param offset The offset of the diagonal from the main diagonal. Positive values indicate diagonals above the main diagonal, while negative values indicate diagonals below it. + * @return A new tensor with the diagonal extracted, where the two specified axes are replaced by a new 1D axis labeled L1. + */ def diagonal[L1: Label, L2: Label](axis1: Axis[L1], axis2: Axis[L2], offset: Int = 0)(using ev: AxesRemover[T, (L1, L2)], labels: Labels[ev.RemainingAxes] @@ -27,7 +35,19 @@ object LinearAlgebraOps: extension [L1: Label, L2: Label, V](t: Tensor2[L1, L2, V]) + /** Returns the main diagonal of the 2D tensor `t`. + * The resulting 1D tensor has a single axis labeled L1, representing the diagonal index over the original (L1, L2) axes. + * + * @return A new Tensor1 representing the diagonal. It uses the Label of the first axis (L1) as the label for the resulting 1D tensor. + */ def diagonal: Tensor1[L1, V] = t.diagonal(0) + + /** Returns the diagonal of the 2D tensor `t` at the given offset. + * The resulting 1D tensor has a single axis labeled L1, representing the diagonal index over the original (L1, L2) axes. + * + * @param offset The offset of the diagonal from the main diagonal. Positive values indicate diagonals above the main diagonal, while negative values indicate diagonals below it. + * @return A new Tensor1 representing the diagonal. It uses the Label of the first axis (L1) as the label for the resulting 1D tensor. + */ def diagonal(offset: Int): Tensor1[L1, V] = Tensor(Jax.jnp.diagonal(t.jaxValue, offset = offset)) // --------------------------------------------------------- @@ -36,6 +56,15 @@ object LinearAlgebraOps: extension [T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V]) + /** Computes the trace of the tensor `t` along the specified axes (L1, L2) with an optional offset. 
+ * The resulting tensor has the two specified axes removed, and the remaining axes are preserved. + * + * @param axis1 The first axis along which to compute the trace. + * @param axis2 The second axis along which to compute the trace. + * @param offset The offset of the diagonal from the main diagonal. Positive values indicate diagonals above the main diagonal, while negative values indicate diagonals below it. + * + * @return A new tensor with the trace computed, where the two specified axes are removed, and the remaining axes are preserved. + */ def trace[L1: Label, L2: Label](axis1: Axis[L1], axis2: Axis[L2], offset: Int = 0)(using ev: AxesRemover[T, (L1, L2)], labels: Labels[ev.RemainingAxes] @@ -43,17 +72,38 @@ object LinearAlgebraOps: extension [L1: Label, L2: Label, V: IsNumber](t: Tensor2[L1, L2, V]) + /** Computes the trace of the tensor + */ def trace: Tensor0[V] = t.trace(0) - def trace(offset: Int): Tensor0[V] = Tensor0(Jax.jnp.trace(t.jaxValue, offset = offset)) - // --------------------------------------------------------- - // IsFloat operations - // --------------------------------------------------------- + /** Computes the trace of the tensor with an optional offset. + * + * @param offset The offset of the diagonal from the main diagonal. Positive values indicate diagonals above the main diagonal, + * while negative values indicate diagonals below it. + */ + def trace(offset: Int): Tensor0[V] = Tensor0(Jax.jnp.trace(t.jaxValue, offset = offset)) extension [T <: Tuple: Labels, V: IsFloating](t: Tensor[T, V]) + /** Computes the L2 norm of the tensor t. + */ def norm: Tensor0[V] = Tensor0(Jax.jnp.linalg.norm(t.jaxValue)) + + /** Computes the inverse of the tensor t along the last two axes. + * The first axes of the tensors are preserved, while the last + * two axes are replaced by their inverses. + * The tensor must be square along the last two axes. 
+ * + * @return a new tensor with the same shape as t, but with the last two axes replaced by their inverses. + */ def inv: Tensor[T, V] = Tensor(Jax.jnp.linalg.inv(t.jaxValue)) + + /** Computes the determinant of the tensor `t` along the specified axes (L1, L2) + * + * @param axis1 The first axis along which to compute the determinant. + * @param axis2 The second axis along which to compute the determinant. + * @return A new tensor with the determinant computed, where the two specified axes are removed + */ def det[L1: Label, L2: Label](axis1: Axis[L1], axis2: Axis[L2])(using ev: AxesRemover[T, (L1, L2)], labels: Labels[ev.RemainingAxes] @@ -67,5 +117,5 @@ object LinearAlgebraOps: Tensor(Jax.jnp.linalg.det(moved)) extension [L1: Label, L2: Label, V: IsFloating](t: Tensor2[L1, L2, V]) - + /** computes the determinant of the 2-D tensor t */ def det: Tensor0[V] = Tensor0(Jax.jnp.linalg.det(t.jaxValue)) diff --git a/core/src/main/scala/dimwit/tensor/tensorops/ReductionOps.scala b/core/src/main/scala/dimwit/tensor/tensorops/ReductionOps.scala index ff3131b..7481004 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/ReductionOps.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/ReductionOps.scala @@ -20,41 +20,37 @@ import me.shadaj.scalapy.readwrite.Writer object ReductionOps: - // --------------------------------------------------------- - // IsNumber operations (IsFloat or IsInt) - // --------------------------------------------------------- - extension [T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V]) - // --- Sum --- - def sum: Tensor0[V] = Tensor0(Jax.jnp.sum(t.jaxValue)) - def sum[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.sum(t.jaxValue, axis = ev.index)) + /** sums the tensor `t` along the specified axes, returning a new tensor with those axes removed. 
*/ def sum[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.sum(t.jaxValue, axis = ev.indices.toPythonProxy)) + def sum[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.sum(t.jaxValue, axis = ev.index)) + def sum: Tensor0[V] = Tensor0(Jax.jnp.sum(t.jaxValue)) - // --- Max --- - def max: Tensor0[V] = Tensor0(Jax.jnp.max(t.jaxValue)) - def max[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.max(t.jaxValue, axis = ev.index)) + /** takes the maximum of the tensor `t` along the specified axes, returning a new tensor with those axes removed. */ def max[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.max(t.jaxValue, axis = ev.indices.toPythonProxy)) + def max[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.max(t.jaxValue, axis = ev.index)) + def max: Tensor0[V] = Tensor0(Jax.jnp.max(t.jaxValue)) - // --- Min --- - def min: Tensor0[V] = Tensor0(Jax.jnp.min(t.jaxValue)) - def min[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.min(t.jaxValue, axis = ev.index)) + /** takes the minimum of the tensor `t` along the specified axes, returning a new tensor with those axes removed. 
*/ def min[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.min(t.jaxValue, axis = ev.indices.toPythonProxy)) + def min[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.min(t.jaxValue, axis = ev.index)) + def min: Tensor0[V] = Tensor0(Jax.jnp.min(t.jaxValue)) - // --- Argmax --- - def argmax: Tensor0[Int32] = Tensor0(Jax.jnp.argmax(t.jaxValue)) - def argmax[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, Int32] = Tensor(Jax.jnp.argmax(t.jaxValue, axis = ev.index)) + /** argument of the maximum of the tensor `t` along the specified axes, returning a new tensor with those axes removed. */ def argmax[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, Int32] = Tensor(Jax.jnp.argmax(t.jaxValue, axis = ev.indices.toPythonProxy)) + def argmax[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, Int32] = Tensor(Jax.jnp.argmax(t.jaxValue, axis = ev.index)) + def argmax: Tensor0[Int32] = Tensor0(Jax.jnp.argmax(t.jaxValue)) - // --- Argmin --- - def argmin: Tensor0[Int32] = Tensor0(Jax.jnp.argmin(t.jaxValue)) - def argmin[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, Int32] = Tensor(Jax.jnp.argmin(t.jaxValue, axis = ev.index)) + /** argument of the minimum of the tensor `t` along the specified axes, returning a new tensor with those axes removed. 
*/ def argmin[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, Int32] = Tensor(Jax.jnp.argmin(t.jaxValue, axis = ev.indices.toPythonProxy)) + def argmin[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, Int32] = Tensor(Jax.jnp.argmin(t.jaxValue, axis = ev.index)) + def argmin: Tensor0[Int32] = Tensor0(Jax.jnp.argmin(t.jaxValue)) - // --- Argsort --- - def argsort: Tensor[T, Int32] = Tensor(Jax.jnp.argsort(t.jaxValue)) - def argsort[L: Label](axis: Axis[L])(using ev: AxisIndex[T, L]): Tensor[T, Int32] = Tensor(Jax.jnp.argsort(t.jaxValue, axis = ev.index)) + /** Returns a tensor of indices that would sort `t` along the specified axes */ def argsort[Inputs <: Tuple](axes: Inputs)(using ev: AxisIndices[T, UnwrapAxes[Inputs]]): Tensor[T, Int32] = Tensor(Jax.jnp.argsort(t.jaxValue, axis = ev.indices.toPythonProxy)) + def argsort[L: Label](axis: Axis[L])(using ev: AxisIndex[T, L]): Tensor[T, Int32] = Tensor(Jax.jnp.argsort(t.jaxValue, axis = ev.index)) + def argsort: Tensor[T, Int32] = Tensor(Jax.jnp.argsort(t.jaxValue)) // --------------------------------------------------------- // IsFloat operations (IsFloat or IsInt) @@ -62,30 +58,32 @@ object ReductionOps: extension [T <: Tuple: Labels, V: IsFloating](t: Tensor[T, V]) - // --- Mean --- - def mean: Tensor0[V] = Tensor0(Jax.jnp.mean(t.jaxValue)) - def mean[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.mean(t.jaxValue, axis = ev.index)) + /** computes the mean of the tensor `t` along the specified axes, returning a new tensor with those axes removed. 
*/ def mean[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.mean(t.jaxValue, axis = ev.indices.toPythonProxy)) + def mean[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.mean(t.jaxValue, axis = ev.index)) + def mean: Tensor0[V] = Tensor0(Jax.jnp.mean(t.jaxValue)) - // --- Std --- - def std: Tensor0[V] = Tensor0(Jax.jnp.std(t.jaxValue)) - def std[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.std(t.jaxValue, axis = ev.index)) + /** computes the standard deviation of the tensor `t` along the specified axes, returning a new tensor with those axes removed. */ def std[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.std(t.jaxValue, axis = ev.indices.toPythonProxy)) + def std[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.std(t.jaxValue, axis = ev.index)) + def std: Tensor0[V] = Tensor0(Jax.jnp.std(t.jaxValue)) - // --- Quantile --- - def quantile(q: Float): Tensor0[V] = Tensor0(Jax.jnp.quantile(t.jaxValue, q)) - def quantile[L: Label](q: Float, axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.quantile(t.jaxValue, q, axis = ev.index)) + /** computes the qth quantile of the tensor `t` along the specified axes, returning a new tensor with those axes removed. 
*/ def quantile[Inputs <: Tuple](q: Float, axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.quantile(t.jaxValue, q, axis = ev.indices.toPythonProxy)) + def quantile[L: Label](q: Float, axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.quantile(t.jaxValue, q, axis = ev.index)) + def quantile(q: Float): Tensor0[V] = Tensor0(Jax.jnp.quantile(t.jaxValue, q)) - // --- Median --- + /** computes the median of the tensor `t` along the specified axes, returning a new tensor with those axes removed. */ def median: Tensor0[V] = Tensor0(Jax.jnp.median(t.jaxValue)) def median[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.median(t.jaxValue, axis = ev.index)) def median[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.median(t.jaxValue, axis = ev.indices.toPythonProxy)) + /** computes the mean of the tensor `t` along the specified axes, ignoring NaN values and returning a new tensor with those axes removed. 
*/ def nanmean: Tensor0[V] = Tensor0(Jax.jnp.nanmean(t.jaxValue)) def nanmean[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.nanmean(t.jaxValue, axis = ev.index)) def nanmean[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.nanmean(t.jaxValue, axis = ev.indices.toPythonProxy)) - def nanmedian: Tensor0[V] = Tensor0(Jax.jnp.nanmedian(t.jaxValue)) - def nanmedian[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.nanmedian(t.jaxValue, axis = ev.index)) + /** computes the median of the tensor `t` along the specified axes, ignoring NaN values and returning a new tensor with those axes removed. */ def nanmedian[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.nanmedian(t.jaxValue, axis = ev.indices.toPythonProxy)) + def nanmedian[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.nanmedian(t.jaxValue, axis = ev.index)) + def nanmedian: Tensor0[V] = Tensor0(Jax.jnp.nanmedian(t.jaxValue)) diff --git a/core/src/main/scala/dimwit/tensor/tensorops/StructuralOps.scala b/core/src/main/scala/dimwit/tensor/tensorops/StructuralOps.scala index 700d074..afc7f0a 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/StructuralOps.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/StructuralOps.scala @@ -43,11 +43,6 @@ import me.shadaj.scalapy.readwrite.Writer import scala.annotation.implicitNotFound import scala.util.NotGiven -// ----------------------------------------------------------- -// 4. 
Structural Operations (Isomorphisms) -// Permutations and Views: T1 -> T2 (Size(T1) == Size(T2)) -// ----------------------------------------------------------- - object StructuralOps: private object Util: @@ -134,6 +129,15 @@ object StructuralOps: import Util.* object TensorWhere: + /** Returns a new tensor where elements are selected from `x` or `y` + * depending on the boolean condition. + * + * @param condition A tensor of boolean values that determines which elements to select. + * @param x A tensor from which to select elements when the condition is true. + * @param y A tensor from which to select elements when the condition is false. + * + * @return A new tensor with elements from `x` where the condition is true, and elements from `y` where the condition is false. + */ def where[T <: Tuple: Labels, V]( condition: Tensor[T, Bool], x: Tensor[T, V], @@ -143,12 +147,35 @@ object StructuralOps: export TensorWhere.where + /** Returns a new tensor with the upper triangular part of the input tensor, + * setting elements below the kth diagonal to zero. + * + * @param tensor The input tensor from which to extract the upper triangular part. + * @param kthDiagonal The diagonal below which elements are set to zero. + * + * @return A new tensor with the upper triangular part of the input tensor. + */ def triu[T <: Tuple: Labels, V](tensor: Tensor[T, V], kthDiagonal: Int = 0): Tensor[T, V] = Tensor(Jax.jnp.triu(tensor.jaxValue, k = kthDiagonal)) + /** Returns a new tensor with the lower triangular part of the input tensor, + * setting elements above the kth diagonal to zero. + * + * @param tensor The input tensor from which to extract the lower triangular part. + * @param kthDiagonal The diagonal above which elements are set to zero. + * + * @return A new tensor with the lower triangular part of the input tensor. 
+ */ def tril[T <: Tuple: Labels, V](tensor: Tensor[T, V], kthDiagonal: Int = 0): Tensor[T, V] = Tensor(Jax.jnp.tril(tensor.jaxValue, k = kthDiagonal)) + /** Stacks a sequence of tensors along a new axis. + * The new axis is inserted as the first axis of the resulting tensor. + * + * @param tensors A sequence of tensors to be stacked. All tensors must have the same shape and type. + * @param newAxis The new axis to be inserted. + * @return A new tensor with the stacked tensors. + */ def stack[L: Label, T <: Tuple: Labels, V]( tensors: Seq[Tensor[T, V]], newAxis: Axis[L] @@ -158,6 +185,14 @@ object StructuralOps: val stackedJaxValue = Jax.jnp.stack(jaxValuesSeq, axis = 0) Tensor(stackedJaxValue) + /** Stacks a sequence of tensors along a new axis, inserting the new axis + * after the specified existing axis. + * + * @param tensors A sequence of tensors to be stacked. All tensors must have the same shape and type. + * @param newAxis The new axis to be inserted. + * @param afterAxis The existing axis after which the new axis will be inserted. + * @return A new tensor with the stacked tensors. + */ def stack[NewL, L, T <: Tuple: Labels, V]( tensors: Seq[Tensor[T, V]], newAxis: Axis[NewL], @@ -176,6 +211,14 @@ object StructuralOps: val names = newNames.toSeq Tensor(stackedJaxValue) + /** Concatenates a sequence of tensors along the specified axis, returning a new tensor with the concatenated values. + * + * @param tensors A sequence of tensors to be concatenated. + * All tensors must have the same shape and type, + * except for the dimension corresponding to the concatenation axis. + * @param concatAxis The axis along which the tensors will be concatenated. + * @return A new tensor with the concatenated values. 
+ */ def concatenate[L: Label, T <: Tuple: Labels, V]( tensors: Seq[Tensor[T, V]], concatAxis: Axis[L] @@ -188,6 +231,14 @@ object StructuralOps: val concatenatedJaxValue = Jax.jnp.concatenate(jaxValuesSeq, axis = axisIdx) Tensor(concatenatedJaxValue) + /** Concatenates two tensors along the specified axis, + * returning a new tensor with the concatenated values. + * + * @param t1 The first tensor to be concatenated. + * @param t2 The second tensor to be concatenated. + * @param concatAxis The axis along which the tensors will be concatenated. + * @return A new tensor with the concatenated values. + */ def concatenate[L: Label, T <: Tuple: Labels, V]( t1: Tensor[T, V], t2: Tensor[T, V], @@ -196,6 +247,18 @@ object StructuralOps: axisIndex: AxisIndex[T, L] ): Tensor[T, V] = concatenate(Seq(t1, t2), concatAxis) + /** Concatenates two tensors along the common axis, returning a new tensor with the concatenated values. + */ + def concatenate[T1 <: Tuple, T2 <: Tuple, V, R <: Tuple]( + t1: Tensor[T1, V], + t2: Tensor[T2, V] + )(using + canConcat: ValidConcat.Aux[T1, T2, R], + label: Labels[R] + ): Tensor[R, V] = + val jaxValues = List(t1.jaxValue, t2.jaxValue).toPythonProxy + Tensor(Jax.jnp.concatenate(jaxValues, axis = canConcat.index)) + trait ValidConcat[T1 <: Tuple, T2 <: Tuple]: type Out <: Tuple def index: Int @@ -215,16 +278,6 @@ object StructuralOps: type Out = (H1 |+| H2) *: Tail def index: Int = 0 - def concatenate[T1 <: Tuple, T2 <: Tuple, V, R <: Tuple]( - t1: Tensor[T1, V], - t2: Tensor[T2, V] - )(using - canConcat: ValidConcat.Aux[T1, T2, R], - label: Labels[R] - ): Tensor[R, V] = - val jaxValues = List(t1.jaxValue, t2.jaxValue).toPythonProxy - Tensor(Jax.jnp.concatenate(jaxValues, axis = canConcat.index)) - type SplitComponents[L, I <: Tuple] <: Tuple = I match case EmptyTuple => L *: EmptyTuple case _ *: tail => L *: SplitComponents[L, tail] @@ -282,6 +335,19 @@ object StructuralOps: extension [T <: Tuple, V](tensor: Tensor[T, V]) + /** takes a 
concatenated tensor and splits it into a tuple of tensors along the specified axis, + * using the provided dimensions for each component. + * + * @param axis The axis along which to deconcatenate the tensor. + * @param dims A tuple of AxisExtent specifying the sizes of each component along the specified axis. + * @return A tuple of tensors corresponding to the deconcatenated components. + * + * Example usage: + * {{{ + * val t: Tensor2[A, B |+| C, Float] = ??? + * val (partB, partC) = t.deconcatenate(Axis[B |+| C], (Axis[B] -> 2, Axis[C] -> 3)) + * }}} + */ def deconcatenate[L, Dims <: Tuple, Comps <: Tuple, Result]( axis: Axis[L], dims: Dims @@ -306,77 +372,6 @@ object StructuralOps: maker.apply(splitArrays, decon.labels, originalNames, axisIndex.index) - /** Splits the tensor along the specified axis at the given indices, returning a tuple of tensors corresponding to the splits. - * - * @param selector of the form Axis[L].at((idx1, idx2, ...)) specifying the axis to split and the indices to split at - * @return the tuple of tensors resulting from the split - */ - def split[L: Label, I <: NonEmptyTuple](selector: AxisAtTupleIndices[L, I])(using - axisIndex: AxisIndex[T, L], - maker: TensorTupleMaker[SplitComponents[L, I], T, L, V], - labels: Labels[T] - ): maker.Out = - val splitList = selector.indices.toList.asInstanceOf[List[Int]] - val pyIndices = me.shadaj.scalapy.py.Dynamic.global.list(splitList.toPythonProxy) - val splitArrays = Jax.jnp.split(tensor.jaxValue, pyIndices, axis = axisIndex.index).as[Seq[Jax.PyDynamic]] - val axisLabelInstance = summon[Label[L]] - val compLabels = List.fill(splitList.size + 1)(axisLabelInstance.asInstanceOf[Label[?]]) - maker.apply(splitArrays, compLabels, labels.names.toSeq, axisIndex.index) - - /** Splits the tensor along the specified axis at the given index, - * returning a tuple of two tensors corresponding to the splits. 
- * - * @param selector of the form Axis[L].at(idx) specifying the axis to split and the index to split at - * @return a tuple of two tensors resulting from the split - */ - def split[L: Label](selector: AxisAtIndex[L])(using - axisIndex: AxisIndex[T, L], - maker: TensorTupleMaker[L *: L *: EmptyTuple, T, L, V], - labels: Labels[T] - ): maker.Out = - split(AxisAtTupleIndices(selector.axis, Tuple1(selector.index))) - - private def calcPyIndices[Inputs <: Tuple]( - inputs: Inputs, - targetDims: List[Int] - ) = - - val PySlice = py.Dynamic.global.slice - val Colon = PySlice(py.None) - val rank = tensor.shape.rank - val indicesBuffer = collection.mutable.ArrayBuffer.fill[py.Any](rank)(Colon) - - val inputList = inputs.toList.asInstanceOf[List[Any]] - - targetDims.zip(inputList).foreach { case (dimIndex, input) => - val dimSize = tensor.shape.dimensions(dimIndex) - input match - // New AxisSelector types - case AxisAtIndex(_, idx) => - indicesBuffer(dimIndex) = py.Any.from(idx) - case AxisAtRange(_, range) => - indicesBuffer(dimIndex) = PySlice(range.head, range.last + 1, range.step) - case AxisAtIndices(_, indices) => - indicesBuffer(dimIndex) = indices.map(py.Any.from).toPythonCopy // TODO find out why Copy is needed here - case AxisAtTupleIndices(_, indices) => - indicesBuffer(dimIndex) = indices.toList.asInstanceOf[List[Int]].map(py.Any.from).toPythonCopy - case AxisAtTensorIndex(_, tensorIdx) => - indicesBuffer(dimIndex) = tensorIdx.jaxValue - // Backward compatibility with tuples - case (_, sliceIndex) => - sliceIndex match - case sliceSeq: List[Int] @unchecked => - indicesBuffer(dimIndex) = sliceSeq.map(py.Any.from).toPythonProxy - case range: Range @unchecked => - indicesBuffer(dimIndex) = PySlice(range.head, range.last + 1, range.step) - case idx: Int => - indicesBuffer(dimIndex) = py.Any.from(idx) - case tensorId: Tensor0[Int32] @unchecked => - indicesBuffer(dimIndex) = tensorId.jaxValue - } - - Jax.Dynamic.global.tuple(indicesBuffer.toSeq.toPythonProxy) - /** 
Flattens all axes of the tensor into a single axis. * The resulting tensor will have a single axis named by concatenating the original axis names with "*". * @@ -463,6 +458,11 @@ object StructuralOps: ) ) + /** Transposes the tensor according to the specified new order of axes. + * + * @param newOrder A tuple representing the new order of axes for the tensor. + * @return A new tensor with the axes transposed according to the specified order. + */ def transpose[NewOrder <: Tuple, Status <: ValidationResult](newOrder: NewOrder)(using ev: AxisIndices[T, UnwrapAxes[NewOrder]], newLabels: Labels[UnwrapAxes[NewOrder]] @@ -472,7 +472,78 @@ object StructuralOps: val indices = ev.indices Tensor(Jax.jnp.transpose(tensor.jaxValue, indices.toPythonProxy)) - /** Splits the tensor along the specified axis at the given indices, returning a sequence of tensors corresponding to the splits. + /** Splits the tensor along the specified axis at the given indices, returning a tuple of tensors corresponding to the splits. 
+ * + * @param selector of the form Axis[L].at((idx1, idx2, ...)) specifying the axis to split and the indices to split at + * @return the tuple of tensors resulting from the split + */ + def split[L: Label, I <: NonEmptyTuple](selector: AxisAtTupleIndices[L, I])(using + axisIndex: AxisIndex[T, L], + maker: TensorTupleMaker[SplitComponents[L, I], T, L, V], + labels: Labels[T] + ): maker.Out = + val splitList = selector.indices.toList.asInstanceOf[List[Int]] + val pyIndices = me.shadaj.scalapy.py.Dynamic.global.list(splitList.toPythonProxy) + val splitArrays = Jax.jnp.split(tensor.jaxValue, pyIndices, axis = axisIndex.index).as[Seq[Jax.PyDynamic]] + val axisLabelInstance = summon[Label[L]] + val compLabels = List.fill(splitList.size + 1)(axisLabelInstance.asInstanceOf[Label[?]]) + maker.apply(splitArrays, compLabels, labels.names.toSeq, axisIndex.index) + + /** Splits the tensor along the specified axis at the given index, + * returning a tuple of two tensors corresponding to the splits. 
+ * + * @param selector of the form Axis[L].at(idx) specifying the axis to split and the index to split at + * @return a tuple of two tensors resulting from the split + */ + def split[L: Label](selector: AxisAtIndex[L])(using + axisIndex: AxisIndex[T, L], + maker: TensorTupleMaker[L *: L *: EmptyTuple, T, L, V], + labels: Labels[T] + ): maker.Out = + split(AxisAtTupleIndices(selector.axis, Tuple1(selector.index))) + + private def calcPyIndices[Inputs <: Tuple]( + inputs: Inputs, + targetDims: List[Int] + ) = + + val PySlice = py.Dynamic.global.slice + val Colon = PySlice(py.None) + val rank = tensor.shape.rank + val indicesBuffer = collection.mutable.ArrayBuffer.fill[py.Any](rank)(Colon) + + val inputList = inputs.toList.asInstanceOf[List[Any]] + + targetDims.zip(inputList).foreach { case (dimIndex, input) => + val dimSize = tensor.shape.dimensions(dimIndex) + input match + // New AxisSelector types + case AxisAtIndex(_, idx) => + indicesBuffer(dimIndex) = py.Any.from(idx) + case AxisAtRange(_, range) => + indicesBuffer(dimIndex) = PySlice(range.head, range.last + 1, range.step) + case AxisAtIndices(_, indices) => + indicesBuffer(dimIndex) = indices.map(py.Any.from).toPythonCopy // TODO find out why Copy is needed here + case AxisAtTupleIndices(_, indices) => + indicesBuffer(dimIndex) = indices.toList.asInstanceOf[List[Int]].map(py.Any.from).toPythonCopy + case AxisAtTensorIndex(_, tensorIdx) => + indicesBuffer(dimIndex) = tensorIdx.jaxValue + // Backward compatibility with tuples + case (_, sliceIndex) => + sliceIndex match + case sliceSeq: List[Int] @unchecked => + indicesBuffer(dimIndex) = sliceSeq.map(py.Any.from).toPythonProxy + case range: Range @unchecked => + indicesBuffer(dimIndex) = PySlice(range.head, range.last + 1, range.step) + case idx: Int => + indicesBuffer(dimIndex) = py.Any.from(idx) + case tensorId: Tensor0[Int32] @unchecked => + indicesBuffer(dimIndex) = tensorId.jaxValue + } + + Jax.Dynamic.global.tuple(indicesBuffer.toSeq.toPythonProxy) + + 
/** Unstacks the tensor along the specified axis, returning a sequence of tensors corresponding to the splits. * * @param unstackAxis the axis to split, specified as an Axis (e.g. Axis[Ax1]) * @return a sequence of tensors resulting from the split, each with the specified axis removed @@ -486,6 +557,9 @@ object StructuralOps: val unstacked = Jax.jnp.split(tensor.jaxValue, tensor.shape.dimensions(axisIdx), axis = axisIdx).as[Seq[Jax.PyDynamic]] unstacked.map(x => Tensor[ev.RemainingAxes, V](x)) + /** splits the tensor into `chunkSize` equal sections along the given axis (`chunkSize` is forwarded to `jnp.split` as the section count), + * returning a sequence of tensors corresponding to the chunks. + */ def chunk[splitL: Label](splitAxis: Axis[splitL], chunkSize: Int)(using labels: Labels[T], axisIndex: AxisIndex[T, splitL] @@ -493,6 +567,12 @@ object StructuralOps: val res = Jax.jnp.split(tensor.jaxValue, chunkSize, axis = axisIndex.index).as[Seq[Jax.PyDynamic]] res.map(x => Tensor[T, V](x)) + /** Slices the tensor according to the specified inputs, + * removing the specified labels from the resulting tensor. + * + * @param inputs A tuple of inputs specifying how to slice the tensor. + * @return The sliced tensor with the specified labels removed from its shape. + */ def slice[Inputs <: Tuple, LabelsToRemove <: Tuple]( inputs: Inputs )(using @@ -503,7 +583,11 @@ object StructuralOps: val pyIndices = tensor.calcPyIndices(inputs, ev.indices) Tensor(tensor.jaxValue.bracketAccess(pyIndices)) - // Convenience overload for AxisAtIndex + /** Slice the given tensor, specifying the axis and index to slice at. + * + * @param selector An AxisAtIndex specifying the axis and index to slice at. + * @return A sliced tensor with the specified axis removed from its shape. 
+ */ def slice[L, LabelsToRemove <: Tuple]( selector: AxisAtIndex[L] )(using @@ -512,7 +596,11 @@ object StructuralOps: labels: Labels[ev.RemainingAxes] ): Tensor[ev.RemainingAxes, V] = slice(Tuple1(selector)) - // Convenience overload for AxisAtRange + /** Slice the given tensor, specifying the axis and a given range to slice at. + * + * @param selector An AxisAtRange specifying the axis and range to slice at. + * @return A sliced tensor with the specified axis removed from its shape. + */ def slice[L, LabelsToRemove <: Tuple]( selector: AxisAtRange[L] )(using @@ -521,7 +609,11 @@ object StructuralOps: labels: Labels[ev.RemainingAxes] ): Tensor[ev.RemainingAxes, V] = slice(Tuple1(selector)) - // Convenience overload for AxisAtIndices + /** Slice the given tensor, specifying the axis and a list of indices to slice at. + * + * @param selector An AxisAtIndices specifying the axis and indices to slice at. + * @return A sliced tensor with the specified axis removed from its shape. + */ def slice[L, LabelsToRemove <: Tuple]( selector: AxisAtIndices[L] )(using @@ -530,7 +622,11 @@ object StructuralOps: labels: Labels[ev.RemainingAxes] ): Tensor[ev.RemainingAxes, V] = slice(Tuple1(selector)) - // Convenience overload for AxisAtTensorIndex + /** Slice the given tensor, specifying the axis and a tensor of indices to slice at. + * + * @param selector An AxisAtTensorIndex specifying the axis and tensor of indices to slice at. + * @return A sliced tensor with the specified axis removed from its shape. + */ def slice[L, LabelsToRemove <: Tuple]( selector: AxisAtTensorIndex[L] )(using @@ -539,7 +635,11 @@ object StructuralOps: labels: Labels[ev.RemainingAxes] ): Tensor[ev.RemainingAxes, V] = slice(Tuple1(selector)) - // Convenience overload for AxisAtTupleIndices + /** Slice the given tensor, specifying the axis and a tuple of indices to slice at. + * + * @param selector An AxisAtTupleIndices specifying the axis and tuple of indices to slice at. 
+ * @return A sliced tensor with the specified axis removed from its shape. + */ def slice[L, U <: NonEmptyTuple, LabelsToRemove <: Tuple]( selector: AxisAtTupleIndices[L, U] )(using diff --git a/core/src/main/scala/dimwit/tensor/tensorops/Tensor0Ops.scala b/core/src/main/scala/dimwit/tensor/tensorops/Tensor0Ops.scala index 32644db..8ea24de 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/Tensor0Ops.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/Tensor0Ops.scala @@ -16,36 +16,57 @@ object Tensor0Ops: ) extension (scalar: Tensor0[Bool]) + /** return the underlying boolean value of the scalar tensor. + * Attention! Breaks the computational graph. + */ def item: Boolean = checkTracer(scalar) scalar.jaxValue.item().as[Boolean] extension (scalar: Tensor0[Int8]) + /** return the underlying Byte value of the scalar tensor. + * Attention! Breaks the computational graph. + */ def item: Byte = checkTracer(scalar) scalar.jaxValue.item().as[Byte] extension (scalar: Tensor0[Int16]) + /** return the underlying Short value of the scalar tensor. + * Attention! Breaks the computational graph. + */ def item: Short = checkTracer(scalar) scalar.jaxValue.item().as[Int].toShort extension (scalar: Tensor0[Int32]) + /** return the underlying Int value of the scalar tensor. + * Attention! Breaks the computational graph. + */ def item: Int = checkTracer(scalar) scalar.jaxValue.item().as[Int] extension (scalar: Tensor0[Int64]) + /** return the underlying Long value of the scalar tensor. + * Attention! Breaks the computational graph. + */ def item: Long = checkTracer(scalar) scalar.jaxValue.item().as[Long] extension (scalar: Tensor0[Float32]) + /** return the underlying Float value of the scalar tensor. + * Attention! Breaks the computational graph. + */ def item: Float = checkTracer(scalar) scalar.jaxValue.item().as[Float] extension (scalar: Tensor0[Float64]) + /** return the underlying Double value of the scalar tensor. + * Attention! Breaks the computational graph. 
+ */ def item: Double = checkTracer(scalar) scalar.jaxValue.item().as[Double] diff --git a/core/src/main/scala/dimwit/tensor/tensorops/Tensor1Ops.scala b/core/src/main/scala/dimwit/tensor/tensorops/Tensor1Ops.scala index 98b2045..93eb1b4 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/Tensor1Ops.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/Tensor1Ops.scala @@ -17,6 +17,10 @@ object Tensor1Ops: extension [L, V](t: Tensor1[L, V]) + /** relabels the axis of the tensor to a new label. The underlying data remains unchanged. + * @param newAxis the new axis label + * @return a new Tensor1 with the same data but with the axis relabeled to newAxis + */ def relabelTo[NewL: Label](newAxis: Axis[NewL]): Tensor1[NewL, V] = Tensor[Tuple1[NewL], V](t.jaxValue) // TODO generalize to TensorN (like slice) diff --git a/core/src/main/scala/dimwit/tensor/tensorops/Tensor2Ops.scala b/core/src/main/scala/dimwit/tensor/tensorops/Tensor2Ops.scala index 0bf95e9..c376c82 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/Tensor2Ops.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/Tensor2Ops.scala @@ -10,8 +10,14 @@ object Tensor2Ops: extension [L1: Label, L2: Label, V](t: Tensor2[L1, L2, V]) - // Support .transpose without arguments for 2D tensors while keeping (not shadowing) the general .transpose with arguments + /** Transposes the two axes of the tensor, effectively swapping their positions. */ def transpose: Tensor2[L2, L1, V] = t.transpose(Axis[L2], Axis[L1]) + + /** Transposes the two axes of the tensor, effectively swapping their positions. + * @param axis2 the first axis to swap + * @param axis1 the second axis to swap + * @return a new Tensor2 with the specified axes transposed + */ def transpose(axis2: Axis[L2], axis1: Axis[L1]): Tensor2[L2, L1, V] = StructuralOps.transpose(t)(axis2, axis1) extension [L1, L2, V, X](t: Tensor2[L1, L2, V])(using ev: HasScalar[V, X])