From 1603a9803626442dcf14a83b5e044f46b97927b8 Mon Sep 17 00:00:00 2001 From: Marcel Luethi Date: Tue, 23 Jun 2026 17:50:47 +0200 Subject: [PATCH 1/4] refactor tensorops into separate files --- .../main/scala/dimwit/tensor/TensorOps.scala | 1589 +---------------- .../main/scala/dimwit/tensor/ValueOps.scala | 37 + .../tensor/tensorops/ContractionOps.scala | 98 + .../tensor/tensorops/ConvolutionOps.scala | 214 +++ .../tensor/tensorops/ElementWiseOps.scala | 124 ++ .../tensor/tensorops/FunctionalOps.scala | 153 ++ .../tensor/tensorops/LinearAlgebraOps.scala | 83 + .../tensor/tensorops/ReductionOps.scala | 90 + .../tensor/tensorops/StructuralOps.scala | 804 +++++++++ .../dimwit/tensor/tensorops/Tensor0Ops.scala | 61 + .../dimwit/tensor/tensorops/Tensor1Ops.scala | 49 + .../dimwit/tensor/tensorops/Tensor2Ops.scala | 43 + .../dimwit/tensor/tensorops/Tensor3Ops.scala | 39 + .../tensor/tensorops/TensorOpsUtils.scala | 40 + .../tensor/TensorOpsConvolutionSuite.scala | 2 +- 15 files changed, 1850 insertions(+), 1576 deletions(-) create mode 100644 core/src/main/scala/dimwit/tensor/ValueOps.scala create mode 100644 core/src/main/scala/dimwit/tensor/tensorops/ContractionOps.scala create mode 100644 core/src/main/scala/dimwit/tensor/tensorops/ConvolutionOps.scala create mode 100644 core/src/main/scala/dimwit/tensor/tensorops/ElementWiseOps.scala create mode 100644 core/src/main/scala/dimwit/tensor/tensorops/FunctionalOps.scala create mode 100644 core/src/main/scala/dimwit/tensor/tensorops/LinearAlgebraOps.scala create mode 100644 core/src/main/scala/dimwit/tensor/tensorops/ReductionOps.scala create mode 100644 core/src/main/scala/dimwit/tensor/tensorops/StructuralOps.scala create mode 100644 core/src/main/scala/dimwit/tensor/tensorops/Tensor0Ops.scala create mode 100644 core/src/main/scala/dimwit/tensor/tensorops/Tensor1Ops.scala create mode 100644 core/src/main/scala/dimwit/tensor/tensorops/Tensor2Ops.scala create mode 100644 
core/src/main/scala/dimwit/tensor/tensorops/Tensor3Ops.scala create mode 100644 core/src/main/scala/dimwit/tensor/tensorops/TensorOpsUtils.scala diff --git a/core/src/main/scala/dimwit/tensor/TensorOps.scala b/core/src/main/scala/dimwit/tensor/TensorOps.scala index 9a9e6950..a9cc74c1 100644 --- a/core/src/main/scala/dimwit/tensor/TensorOps.scala +++ b/core/src/main/scala/dimwit/tensor/TensorOps.scala @@ -3,12 +3,11 @@ package dimwit.tensor import dimwit.DType.* import dimwit.DType.given import dimwit.OnError -import dimwit.jax.Einops import dimwit.jax.Jax import dimwit.tensor.HasScalar import dimwit.tensor.{Label, Labels} import dimwit.tensor.ShapeTypeHelpers.* -import dimwit.tensor.TensorOps.Functional.ZipVmap.{ShapesOf, TensorsOf} +import dimwit.tensor.TensorOps.ZipVmap.{ShapesOf, TensorsOf} import dimwit.tensor.TupleHelpers.* import dimwit.{`|*|`, `|+|`} @@ -23,10 +22,11 @@ import scala.util.NotGiven import Tuple.:* import Tuple.++ +import dimwit.tensor.tensorops.StructuralOps object TensorOps: - import TensorOpsUtil.* + import dimwit.tensor.tensorops.TensorOpsUtil.* /** Typeclass to map a type V to its corresponding DType. */ @@ -68,1579 +68,18 @@ object TensorOps: object IsBoolean: def apply[V](using ev: IsBoolean[V]): IsBoolean[V] = ev - // ----------------------------------------------------------- - // 1. Elementwise Operations (The Field) - // Preserves Shape: T -> T - // ----------------------------------------------------------- - object Elementwise: + export tensorops.ElementWiseOps.* + export tensorops.ReductionOps.* + export tensorops.ContractionOps.* + export tensorops.ConvolutionOps.* + export tensorops.LinearAlgebraOps.* + export tensorops.StructuralOps.* + export tensorops.FunctionalOps.* - // --------------------------------------------------------- - // General operations - // --------------------------------------------------------- - - /** Elementwise maximum of two tensors. 
*/ - def maximum[T <: Tuple: Labels, V](t1: Tensor[T, V], t2: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.maximum(t1.jaxValue, t2.jaxValue)) - - /** Elementwise minimum of two tensors. */ - def minimum[T <: Tuple: Labels, V](t1: Tensor[T, V], t2: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.minimum(t1.jaxValue, t2.jaxValue)) - - extension [T <: Tuple: Labels, V](t: Tensor[T, V]) - - def <(other: Tensor[T, V]): Tensor[T, Bool] = Tensor(Jax.jnp.less(t.jaxValue, other.jaxValue)) - def <=(other: Tensor[T, V]): Tensor[T, Bool] = Tensor(Jax.jnp.less_equal(t.jaxValue, other.jaxValue)) - def >(other: Tensor[T, V]): Tensor[T, Bool] = Tensor(Jax.jnp.greater(t.jaxValue, other.jaxValue)) - def >=(other: Tensor[T, V]): Tensor[T, Bool] = Tensor(Jax.jnp.greater_equal(t.jaxValue, other.jaxValue)) - - /** Checks full array equality, returns true if all elements are equal */ - def ===(other: Tensor[T, V]): Tensor0[Bool] = Tensor0(Jax.jnp.array_equal(t.jaxValue, other.jaxValue)) - - /** Elementwise equality, returns a tensor of bools indicating which elements are equal */ - def elementEquals(other: Tensor[T, V]): Tensor[T, Bool] = - require(t.shape.dimensions == other.shape.dimensions, s"Shape mismatch: ${t.shape.dimensions} vs ${other.shape.dimensions}") - Tensor(jaxValue = Jax.jnp.equal(t.jaxValue, other.jaxValue)) - - def asBool: Tensor[T, Bool] = t.asType(VType[Bool]) - def asBoolean[NewV: IsBoolean](vtype: VType[NewV]): Tensor[T, NewV] = t.asType(vtype) - def asInt32: Tensor[T, Int32] = t.asType(VType[Int32]) - def asInt[NewV: IsInteger](vtype: VType[NewV]): Tensor[T, NewV] = t.asType(vtype) - def asFloat32: Tensor[T, Float32] = t.asType(VType[Float32]) - def asFloat[NewV: IsFloating](vtype: VType[NewV]): Tensor[T, NewV] = t.asType(vtype) - - // --------------------------------------------------------- - // IsNumber operations (IsFloat or IsInt) - // --------------------------------------------------------- - - def add[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], t2: 
Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.add(t1.jaxValue, t2.jaxValue)) - def addScalar[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], t2: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.add(t1.jaxValue, t2.jaxValue)) - - def negate[T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.negative(t.jaxValue)) - def subtract[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], t2: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.subtract(t1.jaxValue, t2.jaxValue)) - def subtractScalar[T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V], t2: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.subtract(t.jaxValue, t2.jaxValue)) - - def multiply[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], t2: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.multiply(t1.jaxValue, t2.jaxValue)) - def multiplyScalar[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], t2: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.multiply(t1.jaxValue, t2.jaxValue)) - - extension [T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V]) - - def +(other: Tensor[T, V]): Tensor[T, V] = add(t, other) - def -(other: Tensor[T, V]): Tensor[T, V] = subtract(t, other) - def *(other: Tensor[T, V]): Tensor[T, V] = multiply(t, other) - - extension [T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V]) - - def +![O <: Tuple](other: Tensor[O, V])(using bc: Broadcast[T, O, V]): Tensor[bc.Out, V] = bc.applyTo(t, other)(add) - - def unary_- : Tensor[T, V] = negate(t) - def -![O <: Tuple](other: Tensor[O, V])(using bc: Broadcast[T, O, V]): Tensor[bc.Out, V] = bc.applyTo(t, other)(subtract) - - def *![O <: Tuple](other: Tensor[O, V])(using bc: Broadcast[T, O, V]): Tensor[bc.Out, V] = bc.applyTo(t, other)(multiply) - def scale(other: Tensor0[V]): Tensor[T, V] = multiplyScalar(t, other) - - def abs: Tensor[T, V] = Tensor(Jax.jnp.abs(t.jaxValue)) - def sign: Tensor[T, V] = Tensor(Jax.jnp.sign(t.jaxValue)) - def clip(min: Tensor0[V], max: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.clip(t.jaxValue, min.jaxValue, max.jaxValue)) - 
def pow(n: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.power(t.jaxValue, n.jaxValue)) - - // --------------------------------------------------------- - // IsFloat operations - // --------------------------------------------------------- - - def divide[T <: Tuple: Labels, V: IsFloating](t1: Tensor[T, V], t2: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.divide(t1.jaxValue, t2.jaxValue)) - def divideScalar[T <: Tuple: Labels, V: IsFloating](t1: Tensor[T, V], t2: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.divide(t1.jaxValue, t2.jaxValue)) - - extension [T <: Tuple: Labels, V: IsFloating](t: Tensor[T, V]) - - def /(other: Tensor[T, V]): Tensor[T, V] = divide(t, other) - def /![O <: Tuple](other: Tensor[O, V])(using join: Broadcast[T, O, V]): Tensor[join.Out, V] = join.applyTo(t, other)(divide) - - def sqrt: Tensor[T, V] = Tensor(Jax.jnp.sqrt(t.jaxValue)) - def exp: Tensor[T, V] = Tensor(Jax.jnp.exp(t.jaxValue)) - def log: Tensor[T, V] = Tensor(Jax.jnp.log(t.jaxValue)) - def sin: Tensor[T, V] = Tensor(Jax.jnp.sin(t.jaxValue)) - def cos: Tensor[T, V] = Tensor(Jax.jnp.cos(t.jaxValue)) - def tanh: Tensor[T, V] = Tensor(Jax.jnp.tanh(t.jaxValue)) - - def approxEquals(other: Tensor[T, V], tolerance: Float = 1e-6f): Tensor0[Bool] = approxElementEquals(other, tolerance).all - def approxElementEquals(other: Tensor[T, V], tolerance: Float = 1e-6f): Tensor[T, Bool] = - Tensor( - Jax.jnp.allclose( - t.jaxValue, - other.jaxValue, - atol = tolerance, - rtol = tolerance - ) - ) - - // --------------------------------------------------------- - // IsBoolean operations - // --------------------------------------------------------- - - extension [T <: Tuple: Labels, V: IsBoolean](t: Tensor[T, V]) - - def all: Tensor0[V] = Tensor0(Jax.jnp.all(t.jaxValue)) - def any: Tensor0[V] = Tensor0(Jax.jnp.any(t.jaxValue)) - - def unary_! : Tensor[T, V] = Tensor(Jax.jnp.logical_not(t.jaxValue)) - - end Elementwise - - // ----------------------------------------------------------- - // 2. 
Reduction Operations (The Monoid) - // Reduces Rank: T -> T - {Axis} - // ----------------------------------------------------------- - object Reduction: - - // --------------------------------------------------------- - // IsNumber operations (IsFloat or IsInt) - // --------------------------------------------------------- - - extension [T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V]) - - // --- Sum --- - def sum: Tensor0[V] = Tensor0(Jax.jnp.sum(t.jaxValue)) - def sum[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.sum(t.jaxValue, axis = ev.index)) - def sum[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.sum(t.jaxValue, axis = ev.indices.toPythonProxy)) - - // --- Max --- - def max: Tensor0[V] = Tensor0(Jax.jnp.max(t.jaxValue)) - def max[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.max(t.jaxValue, axis = ev.index)) - def max[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.max(t.jaxValue, axis = ev.indices.toPythonProxy)) - - // --- Min --- - def min: Tensor0[V] = Tensor0(Jax.jnp.min(t.jaxValue)) - def min[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.min(t.jaxValue, axis = ev.index)) - def min[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.min(t.jaxValue, axis = ev.indices.toPythonProxy)) - - // --- Argmax --- - def argmax: Tensor0[Int32] = Tensor0(Jax.jnp.argmax(t.jaxValue)) - def argmax[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, Int32] = 
Tensor(Jax.jnp.argmax(t.jaxValue, axis = ev.index)) - def argmax[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, Int32] = Tensor(Jax.jnp.argmax(t.jaxValue, axis = ev.indices.toPythonProxy)) - - // --- Argmin --- - def argmin: Tensor0[Int32] = Tensor0(Jax.jnp.argmin(t.jaxValue)) - def argmin[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, Int32] = Tensor(Jax.jnp.argmin(t.jaxValue, axis = ev.index)) - def argmin[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, Int32] = Tensor(Jax.jnp.argmin(t.jaxValue, axis = ev.indices.toPythonProxy)) - - // --- Argsort --- - def argsort: Tensor[T, Int32] = Tensor(Jax.jnp.argsort(t.jaxValue)) - def argsort[L: Label](axis: Axis[L])(using ev: AxisIndex[T, L]): Tensor[T, Int32] = Tensor(Jax.jnp.argsort(t.jaxValue, axis = ev.index)) - def argsort[Inputs <: Tuple](axes: Inputs)(using ev: AxisIndices[T, UnwrapAxes[Inputs]]): Tensor[T, Int32] = Tensor(Jax.jnp.argsort(t.jaxValue, axis = ev.indices.toPythonProxy)) - - // --------------------------------------------------------- - // IsFloat operations (IsFloat or IsInt) - // --------------------------------------------------------- - - extension [T <: Tuple: Labels, V: IsFloating](t: Tensor[T, V]) - - // --- Mean --- - def mean: Tensor0[V] = Tensor0(Jax.jnp.mean(t.jaxValue)) - def mean[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.mean(t.jaxValue, axis = ev.index)) - def mean[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.mean(t.jaxValue, axis = ev.indices.toPythonProxy)) - - // --- Std --- - def std: Tensor0[V] = Tensor0(Jax.jnp.std(t.jaxValue)) - def std[L: Label](axis: 
Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.std(t.jaxValue, axis = ev.index)) - def std[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.std(t.jaxValue, axis = ev.indices.toPythonProxy)) - - // --- Quantile --- - def quantile(q: Float): Tensor0[V] = Tensor0(Jax.jnp.quantile(t.jaxValue, q)) - def quantile[L: Label](q: Float, axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.quantile(t.jaxValue, q, axis = ev.index)) - def quantile[Inputs <: Tuple](q: Float, axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.quantile(t.jaxValue, q, axis = ev.indices.toPythonProxy)) - - // --- Median --- - def median: Tensor0[V] = Tensor0(Jax.jnp.median(t.jaxValue)) - def median[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.median(t.jaxValue, axis = ev.index)) - def median[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.median(t.jaxValue, axis = ev.indices.toPythonProxy)) - - end Reduction - - object Contraction: - - extension [T <: Tuple: Labels, V](tensor: Tensor[T, V]) - - /** Computes the outer product of this tensor with another tensor. - * Automatically primes the labels of the resulting tensor to avoid label collisions. 
- */ - def outerProduct[OtherShape <: Tuple: Labels](other: Tensor[OtherShape, V])(using - primeConcat: PrimeConcat[T, OtherShape], - labels: Labels[primeConcat.Out] - ): Tensor[primeConcat.Out, V] = Tensor( - // Jax outer product flattens, reshape required - Jax.jnp.reshape( - Jax.jnp.outer(tensor.jaxValue, other.jaxValue), - (tensor.shape.dimensions ++ other.shape.dimensions).toPythonProxy - ) - ) - - /** Computes the dot product of this tensor with another tensor along the specified axis. - * The axis must be present in both tensors and will be contracted (removed) from the resulting tensor. - * - * @param axis The axis along which to contract. Must be present in both tensors. - * @param other The other tensor to contract with. - */ - def dot[ - ContractAxis, - OtherShape <: Tuple - ](axis: Axis[ContractAxis])(other: Tensor[OtherShape, V])(using - ev: AxisRemover[T, ContractAxis], - evOther: AxisRemover[OtherShape, ContractAxis] - )(using - primeConcat: PrimeConcat[ev.RemainingAxes, evOther.RemainingAxes], - labelsOut: Labels[primeConcat.Out] - ): Tensor[primeConcat.Out, V] = - val axesTuple1 = Jax.Dynamic.global.tuple(Seq(ev.index).toPythonProxy) - val axesTuple2 = Jax.Dynamic.global.tuple(Seq(evOther.index).toPythonProxy) - val axesPair = Jax.Dynamic.global.tuple(Seq(axesTuple1, axesTuple2).toPythonProxy) - - Tensor(Jax.jnp.tensordot(tensor.jaxValue, other.jaxValue, axes = axesPair)) - - /** Computes the dot product of this tensor with another tensor along the specified pair of axes. - * The axes must be present in their respective tensors and will be contracted (removed) from the resulting tensor. - * - * @param axis The pair of axes along which to contract. Each axis must be present in its respective tensor. - * @param other The other tensor to contract with. - * - * Example usage: - * {{{ - * val t1: Tensor[("A", "B", "C"), Float] = ??? - * val t2: Tensor[("D", "E, "F), Float] = ??? 
- * val result = t1.dot(Axis["B" ~ "F])(t2) - * }}} - */ - @targetName("dotOn") - def dot[ - ContractAxisA, - ContractAxisB, - OtherShape <: Tuple - ](axisPair: (Axis[ContractAxisA], Axis[ContractAxisB]))(other: Tensor[OtherShape, V])(using - ev: AxisRemover[T, ContractAxisA], - evOther: AxisRemover[OtherShape, ContractAxisB] - )(using - primeConcat: PrimeConcat[ev.RemainingAxes, evOther.RemainingAxes], - outLabels: Labels[primeConcat.Out] - ): Tensor[primeConcat.Out, V] = - val axesTuple1 = Jax.Dynamic.global.tuple(Seq(ev.index).toPythonProxy) - val axesTuple2 = Jax.Dynamic.global.tuple(Seq(evOther.index).toPythonProxy) - val axesPair = Jax.Dynamic.global.tuple(Seq(axesTuple1, axesTuple2).toPythonProxy) - - Tensor(Jax.jnp.tensordot(tensor.jaxValue, other.jaxValue, axes = axesPair)) - - end Contraction - - /** Convolution operations. - */ - object Convolution: - - enum Padding: - case SAME, VALID - - type Stride1[S1] = AxisExtent[S1] - - extension [S1: Label, InChannel: Label, V: IsFloating](input: Tensor[S1 *: InChannel *: EmptyTuple, V]) - - def conv1d[OutChannel: Label]( - kernel: Tensor[S1 *: InChannel *: OutChannel *: EmptyTuple, V], - stride: Stride1[S1] | Int = 1, - padding: Padding = Padding.SAME - ): Tensor[S1 *: OutChannel *: EmptyTuple, V] = - require( - input.shape(Axis[InChannel]) == kernel.shape(Axis[InChannel]), - s"Input channels mismatch: input has ${input.shape(Axis[InChannel])} channels, kernel expects ${kernel.shape(Axis[InChannel])} channels" - ) - val strides = stride match - case s: Int => Seq(s) - case ae: AxisExtent[S1] => Seq(ae.size) - // JAX requires input and kernel to have same rank, so we must add (and remove) dummy dim to input. 
- val batchInput = Jax.jnp.expand_dims(input.jaxValue, axis = 0) // add dummy dim - val convResult = Jax.lax.conv_general_dilated( - lhs = batchInput, - rhs = kernel.jaxValue, - window_strides = strides.toPythonProxy, - padding = padding.toString, - dimension_numbers = py.Dynamic.global.tuple(Seq("NHC", "HIO", "NHC").toPythonProxy) - ) - val unbatchedRes = Jax.jnp.squeeze(convResult, axis = 0) // remove dummy dim - Tensor(unbatchedRes) - - extension [S1: Label, OutChannel: Label, V: IsFloating](input: Tensor[S1 *: OutChannel *: EmptyTuple, V]) - - def transposeConv1d[InChannel: Label]( - kernel: Tensor[S1 *: InChannel *: OutChannel *: EmptyTuple, V], - stride: Stride1[S1] | Int = 1, - padding: Padding = Padding.SAME - ): Tensor[S1 *: InChannel *: EmptyTuple, V] = - require( - input.shape(Axis[OutChannel]) == kernel.shape(Axis[OutChannel]), - s"Input channels mismatch: input has ${input.shape(Axis[OutChannel])} channels (OutChannel), kernel expects ${kernel.shape(Axis[OutChannel])}" - ) - val strides = stride match - case s: Int => Seq(s) - case ex: AxisExtent[S1] => Seq(ex.size) - - // kernel -> kernal adjoint: swap in/out channels and flip spatial dims - var kernelAdjoint = kernel.swap(Axis[InChannel], Axis[OutChannel]).jaxValue - kernelAdjoint = Jax.jnp.flip(kernelAdjoint, axis = 0) // flip S1 - - val batchInput = Jax.jnp.expand_dims(input.jaxValue, axis = 0) // add dummy dim - val convResult = Jax.lax.conv_transpose( - lhs = batchInput, - rhs = kernelAdjoint, - strides = strides.toPythonProxy, - padding = padding.toString, - dimension_numbers = py.Dynamic.global.tuple(Seq("NHC", "HIO", "NHC").toPythonProxy) - ) - val unbatchedRes = Jax.jnp.squeeze(convResult, axis = 0) // remove dummy dim - Tensor(unbatchedRes) - - type Stride2[S1, S2] = (AxisExtent[S1], AxisExtent[S2]) - - extension [S1: Label, S2: Label, InChannel: Label, V: IsFloating](input: Tensor[S1 *: S2 *: InChannel *: EmptyTuple, V]) - - def conv2d[OutChannel: Label]( - kernel: Tensor[S1 *: S2 *: 
InChannel *: OutChannel *: EmptyTuple, V], - stride: Stride2[S1, S2] | Int = 1, - padding: Padding = Padding.SAME - ): Tensor[S1 *: S2 *: OutChannel *: EmptyTuple, V] = - require( - input.shape(Axis[InChannel]) == kernel.shape(Axis[InChannel]), - s"Input channels mismatch: input has ${input.shape(Axis[InChannel])} channels, kernel expects ${kernel.shape(Axis[InChannel])} channels" - ) - val strides = stride match - case s: Int => Seq(s, s) - case (ae1, ae2) => Seq(ae1.size, ae2.size) - // JAX requires input and kernel to have same rank, so we must add (and remove) dummy dim to input. - val batchInput = Jax.jnp.expand_dims(input.jaxValue, axis = 0) // add dummy dim - val convResult = Jax.lax.conv_general_dilated( - lhs = batchInput, - rhs = kernel.jaxValue, - window_strides = strides.toPythonProxy, - padding = padding.toString, - dimension_numbers = py.Dynamic.global.tuple(Seq("NHWC", "HWIO", "NHWC").toPythonProxy) - ) - val unbatchedRes = Jax.jnp.squeeze(convResult, axis = 0) // remove dummy dim - Tensor(unbatchedRes) - - extension [S1: Label, S2: Label, OutChannel: Label, V: IsFloating](input: Tensor[S1 *: S2 *: OutChannel *: EmptyTuple, V]) - - def transposeConv2d[InChannel: Label]( - kernel: Tensor[S1 *: S2 *: InChannel *: OutChannel *: EmptyTuple, V], - stride: Stride2[S1, S2] | Int = 1, - padding: Padding = Padding.SAME - ): Tensor[S1 *: S2 *: InChannel *: EmptyTuple, V] = - require( - input.shape(Axis[OutChannel]) == kernel.shape(Axis[OutChannel]), - s"Input channels mismatch: input has ${input.shape(Axis[OutChannel])} channels (OutChannel), kernel expects ${kernel.shape(Axis[OutChannel])}" - ) - - // JAX requires input and kernel to have same rank. Add dummy batch dim if needed. 
- val strides = stride match - case s: Int => Seq(s, s) - case (ae1, ae2) => Seq(ae1.size, ae2.size) - - // kernel -> kernal adjoint: swap in/out channels and flip spatial dims - var kernelAdjoint = kernel.swap(Axis[InChannel], Axis[OutChannel]).jaxValue - kernelAdjoint = Jax.jnp.flip(kernelAdjoint, axis = 0) // flip S1 - kernelAdjoint = Jax.jnp.flip(kernelAdjoint, axis = 1) // flip S2 - - val batchInput = Jax.jnp.expand_dims(input.jaxValue, axis = 0) // add dummy dim - val convResult = Jax.lax.conv_transpose( - lhs = batchInput, - rhs = kernelAdjoint, - strides = strides.toPythonProxy, - padding = padding.toString, - dimension_numbers = py.Dynamic.global.tuple(Seq("NHWC", "HWIO", "NHWC").toPythonProxy) - ) - val unbatchedRes = Jax.jnp.squeeze(convResult, axis = 0) // remove dummy dim - Tensor(unbatchedRes) - - type Stride3[S1, S2, S3] = (AxisExtent[S1], AxisExtent[S2], AxisExtent[S3]) - - extension [S1: Label, S2: Label, S3: Label, InChannel: Label, V: IsFloating](input: Tensor[S1 *: S2 *: S3 *: InChannel *: EmptyTuple, V]) - - def conv3d[OutChannel: Label]( - kernel: Tensor[S1 *: S2 *: S3 *: InChannel *: OutChannel *: EmptyTuple, V], - stride: Stride3[S1, S2, S3] | Int = 1, - padding: Padding = Padding.SAME - ): Tensor[S1 *: S2 *: S3 *: OutChannel *: EmptyTuple, V] = - require( - input.shape(Axis[InChannel]) == kernel.shape(Axis[InChannel]), - s"Input channels mismatch: input has ${input.shape(Axis[InChannel])} channels, kernel expects ${kernel.shape(Axis[InChannel])} channels" - ) - val strides = stride match - case s: Int => Seq(s, s, s) - case (dim1, dim2, dim3) => Seq(dim1.size, dim2.size, dim3.size) - - // JAX requires input and kernel to have same rank, so we must add (and remove) dummy dim to input. 
- // 3D Layout: NDHWC (Batch, Depth, Height, Width, Channel) - val batchInput = Jax.jnp.expand_dims(input.jaxValue, axis = 0) // add dummy dim - val convResult = Jax.lax.conv_general_dilated( - lhs = batchInput, - rhs = kernel.jaxValue, - window_strides = strides.toPythonProxy, - padding = padding.toString, - dimension_numbers = py.Dynamic.global.tuple(Seq("NDHWC", "DHWIO", "NDHWC").toPythonProxy) - ) - val unbatchedRes = Jax.jnp.squeeze(convResult, axis = 0) // remove dummy dim - Tensor(unbatchedRes) - - extension [S1: Label, S2: Label, S3: Label, OutChannel: Label, V: IsFloating](input: Tensor[S1 *: S2 *: S3 *: OutChannel *: EmptyTuple, V]) - - def transposeConv3d[InChannel: Label]( - kernel: Tensor[S1 *: S2 *: S3 *: InChannel *: OutChannel *: EmptyTuple, V], - stride: Stride3[S1, S2, S3] | Int = 1, - padding: Padding = Padding.SAME - ): Tensor[S1 *: S2 *: S3 *: InChannel *: EmptyTuple, V] = - require( - input.shape(Axis[OutChannel]) == kernel.shape(Axis[OutChannel]), - s"Input channels mismatch: input has ${input.shape(Axis[OutChannel])} channels (OutChannel), kernel expects ${kernel.shape(Axis[OutChannel])}" - ) - - val strides = stride match - case s: Int => Seq(s, s, s) - case (ae1, ae2, ae3) => Seq(ae1.size, ae2.size, ae3.size) - - // kernel -> kernel adjoint: swap in/out channels and flip all spatial dims - var kernelAdjoint = kernel.swap(Axis[InChannel], Axis[OutChannel]).jaxValue - kernelAdjoint = Jax.jnp.flip(kernelAdjoint, axis = 0) // flip S1 (Depth) - kernelAdjoint = Jax.jnp.flip(kernelAdjoint, axis = 1) // flip S2 (Height) - kernelAdjoint = Jax.jnp.flip(kernelAdjoint, axis = 2) // flip S3 (Width) - - val batchInput = Jax.jnp.expand_dims(input.jaxValue, axis = 0) // add dummy dim - val convResult = Jax.lax.conv_transpose( - lhs = batchInput, - rhs = kernelAdjoint, - strides = strides.toPythonProxy, - padding = padding.toString, - dimension_numbers = py.Dynamic.global.tuple(Seq("NDHWC", "DHWIO", "NDHWC").toPythonProxy) - ) - val unbatchedRes = 
Jax.jnp.squeeze(convResult, axis = 0) // remove dummy dim - Tensor(unbatchedRes) - end Convolution - - object LinearAlgebra: - - // --------------------------------------------------------- - // General operations - // --------------------------------------------------------- - - extension [T <: Tuple: Labels, V](t: Tensor[T, V]) - - def diagonal[L1: Label, L2: Label](axis1: Axis[L1], axis2: Axis[L2], offset: Int = 0)(using - ev: AxesRemover[T, (L1, L2)], - labels: Labels[ev.RemainingAxes] - ): Tensor[ev.RemainingAxes *: L1 *: EmptyTuple, V] = - Tensor(Jax.jnp.diagonal(t.jaxValue, offset = offset, axis1 = ev.indices(0), axis2 = ev.indices(1))) - - extension [L1: Label, L2: Label, V](t: Tensor2[L1, L2, V]) - - def diagonal: Tensor1[L1, V] = t.diagonal(0) - def diagonal(offset: Int): Tensor1[L1, V] = Tensor(Jax.jnp.diagonal(t.jaxValue, offset = offset)) - - // --------------------------------------------------------- - // IsNumber operations (IsFloat or IsInt) - // --------------------------------------------------------- - - extension [T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V]) - - def trace[L1: Label, L2: Label](axis1: Axis[L1], axis2: Axis[L2], offset: Int = 0)(using - ev: AxesRemover[T, (L1, L2)], - labels: Labels[ev.RemainingAxes] - ): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.trace(t.jaxValue, offset = offset, axis1 = ev.indices(0), axis2 = ev.indices(1))) - - extension [L1: Label, L2: Label, V: IsNumber](t: Tensor2[L1, L2, V]) - - def trace: Tensor0[V] = t.trace(0) - def trace(offset: Int): Tensor0[V] = Tensor0(Jax.jnp.trace(t.jaxValue, offset = offset)) - - // --------------------------------------------------------- - // IsFloat operations - // --------------------------------------------------------- - - extension [T <: Tuple: Labels, V: IsFloating](t: Tensor[T, V]) - - def norm: Tensor0[V] = Tensor0(Jax.jnp.linalg.norm(t.jaxValue)) - def inv: Tensor[T, V] = Tensor(Jax.jnp.linalg.inv(t.jaxValue)) - def det[L1: Label, L2: Label](axis1: Axis[L1], 
axis2: Axis[L2])(using - ev: AxesRemover[T, (L1, L2)], - labels: Labels[ev.RemainingAxes] - ): Tensor[ev.RemainingAxes, V] = - // JAX det only works on the last two axes (-2, -1). We must move the user's selected axes to the end. - val moved = Jax.jnp.moveaxis( - t.jaxValue, - source = ev.indices.toPythonProxy, - destination = Seq(-2, -1).toPythonProxy - ) - Tensor(Jax.jnp.linalg.det(moved)) - - extension [L1: Label, L2: Label, V: IsFloating](t: Tensor2[L1, L2, V]) - - def det: Tensor0[V] = Tensor0(Jax.jnp.linalg.det(t.jaxValue)) - - end LinearAlgebra - - // ----------------------------------------------------------- - // 4. Structural Operations (Isomorphisms) - // Permutations and Views: T1 -> T2 (Size(T1) == Size(T2)) - // ----------------------------------------------------------- - object Structural: - - private object Util: - - type InsertBefore[T <: Tuple, A, B] <: Tuple = T match - case EmptyTuple => B *: EmptyTuple - case A *: tail => B *: A *: tail - case h *: tail => h *: InsertBefore[tail, A, B] - - type InsertAfter[T <: Tuple, A, B] <: Tuple = T match - case EmptyTuple => B *: EmptyTuple - case A *: tail => A *: B *: tail - case h *: tail => h *: InsertAfter[tail, A, B] - - type SliceIndex = Int | List[Int] | Range | Tensor0[Int32] - type ExtractLabel[X] = X match - case AxisAtIndex[l] => l - case AxisAtRange[l] => l - case AxisAtIndices[l] => l - case AxisAtTupleIndices[l, ?] 
=> l - case AxisAtTensorIndex[l] => l - type ExtractLabels[Inputs <: Tuple] = Tuple.Map[Inputs, ExtractLabel] - - trait SliceLabelExtractor[Inputs <: Tuple, Out <: Tuple] - - object SliceLabelExtractor: - - given empty: SliceLabelExtractor[EmptyTuple, EmptyTuple] = - new SliceLabelExtractor[EmptyTuple, EmptyTuple] {} - - // New givens for AxisSelector types - given consAxisAtIndex[L, Tail <: Tuple, TailOut <: Tuple](using - tailExt: SliceLabelExtractor[Tail, TailOut] - ): SliceLabelExtractor[AxisAtIndex[L] *: Tail, L *: TailOut] = - new SliceLabelExtractor[AxisAtIndex[L] *: Tail, L *: TailOut] {} - - given consAxisAtRange[L, Tail <: Tuple, TailOut <: Tuple](using - tailExt: SliceLabelExtractor[Tail, TailOut] - ): SliceLabelExtractor[AxisAtRange[L] *: Tail, TailOut] = - new SliceLabelExtractor[AxisAtRange[L] *: Tail, TailOut] {} - - given consAxisAtIndices[L, Tail <: Tuple, TailOut <: Tuple](using - tailExt: SliceLabelExtractor[Tail, TailOut] - ): SliceLabelExtractor[AxisAtIndices[L] *: Tail, TailOut] = - new SliceLabelExtractor[AxisAtIndices[L] *: Tail, TailOut] {} - - given consAxisAtTupleIndices[L, I <: NonEmptyTuple, Tail <: Tuple, TailOut <: Tuple](using - tailExt: SliceLabelExtractor[Tail, TailOut] - ): SliceLabelExtractor[AxisAtTupleIndices[L, I] *: Tail, TailOut] = - new SliceLabelExtractor[AxisAtTupleIndices[L, I] *: Tail, TailOut] {} - - given consAxisAtTensorIndex[L, Tail <: Tuple, TailOut <: Tuple](using - tailExt: SliceLabelExtractor[Tail, TailOut] - ): SliceLabelExtractor[AxisAtTensorIndex[L] *: Tail, L *: TailOut] = - new SliceLabelExtractor[AxisAtTensorIndex[L] *: Tail, L *: TailOut] {} - - // Keep backward compatibility with tuple syntax - given consInt[L, Tail <: Tuple, TailOut <: Tuple](using - tailExt: SliceLabelExtractor[Tail, TailOut] - ): SliceLabelExtractor[(Axis[L], Int) *: Tail, L *: TailOut] = - new SliceLabelExtractor[(Axis[L], Int) *: Tail, L *: TailOut] {} - - given consTensor0Int[L, Tail <: Tuple, TailOut <: Tuple](using - tailExt: 
SliceLabelExtractor[Tail, TailOut] - ): SliceLabelExtractor[(Axis[L], Tensor0[Int32]) *: Tail, L *: TailOut] = - new SliceLabelExtractor[(Axis[L], Tensor0[Int32]) *: Tail, L *: TailOut] {} - - given consSeq[L, SeqT <: Seq[Int], Tail <: Tuple, TailOut <: Tuple](using - tailExt: SliceLabelExtractor[Tail, TailOut] - ): SliceLabelExtractor[(Axis[L], SeqT) *: Tail, TailOut] = - new SliceLabelExtractor[(Axis[L], SeqT) *: Tail, TailOut] {} - - type Swap[T <: Tuple, A, B] <: Tuple = T match - case EmptyTuple => EmptyTuple - case A *: tail => B *: Swap[tail, A, B] - case B *: tail => A *: Swap[tail, A, B] - case h *: tail => h *: Swap[tail, A, B] - - @implicitNotFound("The axis ${L} is already present in the tensor shape ${T}.") - trait AxisAbsent[T, L] - object AxisAbsent: - given [T <: Tuple, L](using NotGiven[Tuple.Contains[T, L] =:= true]): AxisAbsent[T, L] = new AxisAbsent[T, L] {} - - import Util.* - - object TensorWhere: - def where[T <: Tuple: Labels, V]( - condition: Tensor[T, Bool], - x: Tensor[T, V], - y: Tensor[T, V] - ): Tensor[T, V] = - Tensor(Jax.jnp.where(condition.jaxValue, x.jaxValue, y.jaxValue)) - - export TensorWhere.where - - def triu[T <: Tuple: Labels, V](tensor: Tensor[T, V], kthDiagonal: Int = 0): Tensor[T, V] = - Tensor(Jax.jnp.triu(tensor.jaxValue, k = kthDiagonal)) - - def tril[T <: Tuple: Labels, V](tensor: Tensor[T, V], kthDiagonal: Int = 0): Tensor[T, V] = - Tensor(Jax.jnp.tril(tensor.jaxValue, k = kthDiagonal)) - - def stack[L: Label, T <: Tuple: Labels, V]( - tensors: Seq[Tensor[T, V]], - newAxis: Axis[L] - ): Tensor[L *: T, V] = - require(tensors.nonEmpty, "Cannot stack an empty sequence of tensors") - val jaxValuesSeq = tensors.map(_.jaxValue).toPythonProxy - val stackedJaxValue = Jax.jnp.stack(jaxValuesSeq, axis = 0) - Tensor(stackedJaxValue) - - def stack[NewL, L, T <: Tuple: Labels, V]( - tensors: Seq[Tensor[T, V]], - newAxis: Axis[NewL], - afterAxis: Axis[L] - )(using - newLabel: Label[NewL], - axisIndex: AxisIndex[T, L] - ): 
Tensor[InsertAfter[T, L, NewL], V] = - require(tensors.nonEmpty, "Cannot stack an empty sequence of tensors") - val axisIdx = axisIndex.index + 1 // we are inserting after the given axis, so shift by 1 - val jaxValuesSeq = tensors.map(_.jaxValue).toPythonProxy - val stackedJaxValue = Jax.jnp.stack(jaxValuesSeq, axis = axisIdx) - val names = summon[Labels[T]].names - val newNames = names.take(axisIdx) ++ Seq(newLabel.name) ++ names.drop(axisIdx) - given Labels[InsertAfter[T, L, NewL]] with - val names = newNames.toSeq - Tensor(stackedJaxValue) - - def concatenate[L: Label, T <: Tuple: Labels, V]( - tensors: Seq[Tensor[T, V]], - concatAxis: Axis[L] - )(using - axisIndex: AxisIndex[T, L] - ): Tensor[T, V] = - require(tensors.nonEmpty, "Cannot concatenate an empty sequence of tensors") - val axisIdx = axisIndex.index - val jaxValuesSeq = tensors.map(_.jaxValue).toPythonProxy - val concatenatedJaxValue = Jax.jnp.concatenate(jaxValuesSeq, axis = axisIdx) - Tensor(concatenatedJaxValue) - - def concatenate[L: Label, T <: Tuple: Labels, V]( - t1: Tensor[T, V], - t2: Tensor[T, V], - concatAxis: Axis[L] - )(using - axisIndex: AxisIndex[T, L] - ): Tensor[T, V] = concatenate(Seq(t1, t2), concatAxis) - - trait ValidConcat[T1 <: Tuple, T2 <: Tuple]: - type Out <: Tuple - def index: Int - - object ValidConcat: - type Aux[T1 <: Tuple, T2 <: Tuple, O <: Tuple] = ValidConcat[T1, T2] { type Out = O } - - given recursive[H, T1Tail <: Tuple, T2Tail <: Tuple, OutTail <: Tuple](using - next: ValidConcat.Aux[T1Tail, T2Tail, OutTail] - ): ValidConcat[H *: T1Tail, H *: T2Tail] with - type Out = H *: OutTail - def index: Int = next.index + 1 - - given concatAxis[H1, H2, Tail <: Tuple](using - isDifferent: NotGiven[H1 =:= H2] - ): ValidConcat[H1 *: Tail, H2 *: Tail] with - type Out = (H1 |+| H2) *: Tail - def index: Int = 0 - - def concatenate[T1 <: Tuple, T2 <: Tuple, V, R <: Tuple]( - t1: Tensor[T1, V], - t2: Tensor[T2, V] - )(using - canConcat: ValidConcat.Aux[T1, T2, R], - label: Labels[R] 
- ): Tensor[R, V] = - val jaxValues = List(t1.jaxValue, t2.jaxValue).toPythonProxy - Tensor(Jax.jnp.concatenate(jaxValues, axis = canConcat.index)) - - type SplitComponents[L, I <: Tuple] <: Tuple = I match - case EmptyTuple => L *: EmptyTuple - case _ *: tail => L *: SplitComponents[L, tail] - - trait Deconcatenator[L]: - type Components <: Tuple - def labels: List[Label[?]] - - object Deconcatenator extends DeconcatenatorLowPriority: - type Aux[L, C <: Tuple] = Deconcatenator[L] { type Components = C } - - given recursive[A, B, CA <: Tuple, CB <: Tuple](using - da: Aux[A, CA], - db: Aux[B, CB] - ): Aux[A |+| B, Tuple.Concat[CA, CB]] = - new Deconcatenator[A |+| B]: - type Components = Tuple.Concat[CA, CB] - def labels = da.labels ++ db.labels - - trait DeconcatenatorLowPriority: - given base[L](using l: Label[L]): Deconcatenator.Aux[L, L *: EmptyTuple] = - new Deconcatenator[L]: - type Components = L *: EmptyTuple - def labels = List(l) - - trait TensorTupleMaker[Components <: Tuple, FullShape <: Tuple, SplitAxis, V]: - type Out <: Tuple - def apply(arrays: Seq[Jax.PyDynamic], compLabels: List[Label[?]], originalLabels: Seq[String], splitIndex: Int): Out - - object TensorTupleMaker: - type Aux[C <: Tuple, F <: Tuple, S, V, O <: Tuple] = - TensorTupleMaker[C, F, S, V] { type Out = O } - - given empty[F <: Tuple, S, V]: Aux[EmptyTuple, F, S, V, EmptyTuple] = - new TensorTupleMaker[EmptyTuple, F, S, V]: - type Out = EmptyTuple - def apply(a: Seq[Jax.PyDynamic], c: List[Label[?]], o: Seq[String], i: Int) = EmptyTuple - - given cons[Head, Tail <: Tuple, F <: Tuple, S, V, NewShape <: Tuple](using - replacer: TupleHelpers.Replacer[F, S, Head] { type Out = NewShape }, - tailMaker: TensorTupleMaker[Tail, F, S, V] - ): Aux[Head *: Tail, F, S, V, Tensor[NewShape, V] *: tailMaker.Out] = - - new TensorTupleMaker[Head *: Tail, F, S, V]: - type Out = Tensor[NewShape, V] *: tailMaker.Out - - def apply(arrays: Seq[Jax.PyDynamic], compLabels: List[Label[?]], originalLabels: 
Seq[String], splitIndex: Int): Out = - val currentArr = arrays.head - val currentLabel = compLabels.head - val newNames = originalLabels.updated(splitIndex, currentLabel.name).toList - val newLabelsWitness = new Labels[NewShape]: - val names = newNames - val headTensor = Tensor[NewShape, V](currentArr)(using newLabelsWitness) - headTensor *: tailMaker(arrays.tail, compLabels.tail, originalLabels, splitIndex) - - extension [T <: Tuple, V](tensor: Tensor[T, V]) - - def deconcatenate[L, Dims <: Tuple, Comps <: Tuple, Result]( - axis: Axis[L], - dims: Dims - )(using - labels: Labels[T], - axisIndex: AxisIndex[T, L], - decon: Deconcatenator.Aux[L, Comps], - extractor: DimExtractor[Dims], - maker: TensorTupleMaker[Comps, T, L, V] - ): maker.Out = - val orderedSizes = dims.toList.asInstanceOf[List[Any]].map { - case ae: AxisExtent[?] => ae.size - case _ => throw new IllegalArgumentException("Invalid dims format - expected AxisExtent") - } - - require(orderedSizes.size == decon.labels.size, s"Provided ${orderedSizes.size} sizes but axis has ${decon.labels.size} components") - - val splitIndices = orderedSizes.scanLeft(0)(_ + _).tail.init - val pyIndices = me.shadaj.scalapy.py.Dynamic.global.list(splitIndices.toPythonProxy) - val splitArrays = Jax.jnp.split(tensor.jaxValue, pyIndices, axis = axisIndex.index).as[Seq[Jax.PyDynamic]] - val originalNames = summon[Labels[T]].names.toSeq - - maker.apply(splitArrays, decon.labels, originalNames, axisIndex.index) - - /** Splits the tensor along the specified axis at the given indices, returning a tuple of tensors corresponding to the splits. 
- * - * @param selector of the form Axis[L].at((idx1, idx2, ...)) specifying the axis to split and the indices to split at - * @return the tuple of tensors resulting from the split - */ - def split[L: Label, I <: NonEmptyTuple](selector: AxisAtTupleIndices[L, I])(using - axisIndex: AxisIndex[T, L], - maker: TensorTupleMaker[SplitComponents[L, I], T, L, V], - labels: Labels[T] - ): maker.Out = - val splitList = selector.indices.toList.asInstanceOf[List[Int]] - val pyIndices = me.shadaj.scalapy.py.Dynamic.global.list(splitList.toPythonProxy) - val splitArrays = Jax.jnp.split(tensor.jaxValue, pyIndices, axis = axisIndex.index).as[Seq[Jax.PyDynamic]] - val axisLabelInstance = summon[Label[L]] - val compLabels = List.fill(splitList.size + 1)(axisLabelInstance.asInstanceOf[Label[?]]) - maker.apply(splitArrays, compLabels, labels.names.toSeq, axisIndex.index) - - /** Splits the tensor along the specified axis at the given index, - * returning a tuple of two tensors corresponding to the splits. 
- * - * @param selector of the form Axis[L].at(idx) specifying the axis to split and the index to split at - * @return a tuple of two tensors resulting from the split - */ - def split[L: Label](selector: AxisAtIndex[L])(using - axisIndex: AxisIndex[T, L], - maker: TensorTupleMaker[L *: L *: EmptyTuple, T, L, V], - labels: Labels[T] - ): maker.Out = - split(AxisAtTupleIndices(selector.axis, Tuple1(selector.index))) - - private def calcPyIndices[Inputs <: Tuple]( - inputs: Inputs, - targetDims: List[Int] - ) = - - val PySlice = py.Dynamic.global.slice - val Colon = PySlice(py.None) - val rank = tensor.shape.rank - val indicesBuffer = collection.mutable.ArrayBuffer.fill[py.Any](rank)(Colon) - - val inputList = inputs.toList.asInstanceOf[List[Any]] - - targetDims.zip(inputList).foreach { case (dimIndex, input) => - val dimSize = tensor.shape.dimensions(dimIndex) - input match - // New AxisSelector types - case AxisAtIndex(_, idx) => - indicesBuffer(dimIndex) = py.Any.from(idx) - case AxisAtRange(_, range) => - indicesBuffer(dimIndex) = PySlice(range.head, range.last + 1, range.step) - case AxisAtIndices(_, indices) => - indicesBuffer(dimIndex) = indices.map(py.Any.from).toPythonCopy // TODO find out why Copy is needed here - case AxisAtTupleIndices(_, indices) => - indicesBuffer(dimIndex) = indices.toList.asInstanceOf[List[Int]].map(py.Any.from).toPythonCopy - case AxisAtTensorIndex(_, tensorIdx) => - indicesBuffer(dimIndex) = tensorIdx.jaxValue - // Backward compatibility with tuples - case (_, sliceIndex) => - sliceIndex match - case sliceSeq: List[Int] @unchecked => - indicesBuffer(dimIndex) = sliceSeq.map(py.Any.from).toPythonProxy - case range: Range @unchecked => - indicesBuffer(dimIndex) = PySlice(range.head, range.last + 1, range.step) - case idx: Int => - indicesBuffer(dimIndex) = py.Any.from(idx) - case tensorId: Tensor0[Int32] @unchecked => - indicesBuffer(dimIndex) = tensorId.jaxValue - } - - Jax.Dynamic.global.tuple(indicesBuffer.toSeq.toPythonProxy) - - 
/** Flattens all axes of the tensor into a single axis. - * The resulting tensor will have a single axis named by concatenating the original axis names with "*". - * - * @return a Tensor1 with the merged axis - */ - def flatten(using labels: Labels[T]): Tensor1[MergeLabels[T], V] = - given Labels[Tuple1[MergeLabels[T]]] with - def names = List(summon[Labels[T]].names.mkString("*")) - Tensor(Jax.jnp.ravel(tensor.jaxValue)) - - /** Flattens the specified axes of the tensor into a single axis. - * The resulting tensor will have the specified axes merged into a single axis named by concatenating the original axis names with "*" - * The other axes remain unchanged. - * - * @param axes the axes to flatten, specified as a tuple of Axis (e.g. (Axis[Ax1], Axis[Ax2])) - * @return a Tensor with the specified axes merged into a single axis - */ - def flatten[AxesTuple <: Tuple, R <: Tuple]( - axes: AxesTuple - )(using - merger: AxesMerger.Aux[T, UnwrapAxes[AxesTuple], R], - labels: Labels[R] - ): Tensor[R, V] = - val permuted = Jax.jnp.transpose(tensor.jaxValue, merger.permutation.toPythonProxy) - - val originalDims = tensor.shape.dimensions - val mergedSize = merger.mergeIndices.map(originalDims).product - - val remainingDims = originalDims.zipWithIndex - .filterNot((d, i) => merger.mergeIndices.contains(i)) - .map(_._1) - - val newDimensions = remainingDims.patch(merger.mergedIndex, Seq(mergedSize), 0) - - Tensor(Jax.jnp.reshape(permuted, newDimensions.toPythonProxy)) - - /** Unflattens splitAxis into a new shape specified by newShape. The other axes remain unchanged. - * - * The user must ensure that the size of splitAxis matches the product of the dimensions in newShape, otherwise a runtime error will occur. 
- * - * @param splitAxis the axis to unflatten - * @param newShape the new shape to unflatten into, specified as a Shape - * @return a Tensor with the specified axis unflattened into the new shape - */ - def unflatten[SplitL, NewT <: Tuple, R <: Tuple]( - splitAxis: Axis[SplitL], - newShape: Shape[NewT] - )(using - ev: AxisReplacerAll.Aux[T, SplitL, NewT, R], - labels: Labels[R] - ): Tensor[R, V] = - val before = tensor.shape.dimensions.take(ev.index) - val after = tensor.shape.dimensions.drop(ev.index + 1) - val fullNewShape = before ++ newShape.dimensions ++ after - Tensor( - Jax.jnp.reshape( - tensor.jaxValue, - py.Dynamic.global.tuple( - fullNewShape.map(py.Any.from).toPythonProxy - ) - ) - ) - - /** Unflattens the tensor into a new shape specified by newShape. - * - * The user must ensure that the size of the tensor matches the product of the dimensions in newShape, otherwise a runtime error will occur. - * - * @param newShape the new shape to unflatten into, specified as a Shape - * @return a Tensor with the new shape - */ - def unflatten[NewT <: Tuple: Labels]( - newShape: Shape[NewT] - )(using - @implicitNotFound("unflatten without axis can only be used on Tensor1 types.") - ev: T <:< Tuple1[Any] // <--- Ensures this only works on Tensor1 - ): Tensor[NewT, V] = - val fullNewShape = newShape.dimensions - Tensor( - Jax.jnp.reshape( - tensor.jaxValue, - py.Dynamic.global.tuple( - fullNewShape.map(py.Any.from).toPythonProxy - ) - ) - ) - - def transpose[NewOrder <: Tuple, Status <: ValidationResult](newOrder: NewOrder)(using - ev: AxisIndices[T, UnwrapAxes[NewOrder]], - newLabels: Labels[UnwrapAxes[NewOrder]] - )(using - allAxesEv: IsPermutation[T, UnwrapAxes[NewOrder]] - ): Tensor[UnwrapAxes[NewOrder], V] = - val indices = ev.indices - Tensor(Jax.jnp.transpose(tensor.jaxValue, indices.toPythonProxy)) - - /** Splits the tensor along the specified axis at the given indices, returning a sequence of tensors corresponding to the splits. 
- * - * @param unstackAxis the axis to split, specified as an Axis (e.g. Axis[Ax1]) - * @return a sequence of tensors resulting from the split, each with the specified axis removed - */ - def unstack[L: Label](unstackAxis: Axis[L])(using - labels: Labels[T], - ev: AxisRemover[T, L], - labelR: Labels[ev.RemainingAxes] - ): Seq[Tensor[ev.RemainingAxes, V]] = - val axisIdx = ev.index - val unstacked = Jax.jnp.split(tensor.jaxValue, tensor.shape.dimensions(axisIdx), axis = axisIdx).as[Seq[Jax.PyDynamic]] - unstacked.map(x => Tensor[ev.RemainingAxes, V](x)) - - def chunk[splitL: Label](splitAxis: Axis[splitL], chunkSize: Int)(using - labels: Labels[T], - axisIndex: AxisIndex[T, splitL] - ): Seq[Tensor[T, V]] = - val res = Jax.jnp.split(tensor.jaxValue, chunkSize, axis = axisIndex.index).as[Seq[Jax.PyDynamic]] - res.map(x => Tensor[T, V](x)) - - def slice[Inputs <: Tuple, LabelsToRemove <: Tuple]( - inputs: Inputs - )(using - sliceExtractor: SliceLabelExtractor[Inputs, LabelsToRemove], - ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Inputs]], - labels: Labels[ev.RemainingAxes] - ): Tensor[ev.RemainingAxes, V] = - val pyIndices = tensor.calcPyIndices(inputs, ev.indices) - Tensor(tensor.jaxValue.bracketAccess(pyIndices)) - - // Convenience overload for AxisAtIndex - def slice[L, LabelsToRemove <: Tuple]( - selector: AxisAtIndex[L] - )(using - sliceExtractor: SliceLabelExtractor[Tuple1[AxisAtIndex[L]], LabelsToRemove], - ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtIndex[L]]]], - labels: Labels[ev.RemainingAxes] - ): Tensor[ev.RemainingAxes, V] = slice(Tuple1(selector)) - - // Convenience overload for AxisAtRange - def slice[L, LabelsToRemove <: Tuple]( - selector: AxisAtRange[L] - )(using - sliceExtractor: SliceLabelExtractor[Tuple1[AxisAtRange[L]], LabelsToRemove], - ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtRange[L]]]], - labels: Labels[ev.RemainingAxes] - ): Tensor[ev.RemainingAxes, V] = 
slice(Tuple1(selector)) - - // Convenience overload for AxisAtIndices - def slice[L, LabelsToRemove <: Tuple]( - selector: AxisAtIndices[L] - )(using - sliceExtractor: SliceLabelExtractor[Tuple1[AxisAtIndices[L]], LabelsToRemove], - ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtIndices[L]]]], - labels: Labels[ev.RemainingAxes] - ): Tensor[ev.RemainingAxes, V] = slice(Tuple1(selector)) - - // Convenience overload for AxisAtTensorIndex - def slice[L, LabelsToRemove <: Tuple]( - selector: AxisAtTensorIndex[L] - )(using - sliceExtractor: SliceLabelExtractor[Tuple1[AxisAtTensorIndex[L]], LabelsToRemove], - ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtTensorIndex[L]]]], - labels: Labels[ev.RemainingAxes] - ): Tensor[ev.RemainingAxes, V] = slice(Tuple1(selector)) - - // Convenience overload for AxisAtTupleIndices - def slice[L, U <: NonEmptyTuple, LabelsToRemove <: Tuple]( - selector: AxisAtTupleIndices[L, U] - )(using - sliceExtractor: SliceLabelExtractor[Tuple1[AxisAtTupleIndices[L, U]], LabelsToRemove], - ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtTupleIndices[L, U]]]], - labels: Labels[ev.RemainingAxes] - ): Tensor[ev.RemainingAxes, V] = slice(Tuple1(selector)) - - def take[L1, L2: Label]( - axis: Axis[L1] - )( - indices: Tensor1[L2, Int32] - )(using - ev: AxisRemover[T, L1], - labels: Labels[ev.RemainingAxes] - ): Tensor[Tuple.Concat[Tuple1[L2], ev.RemainingAxes], V] = - val result = Jax.jnp.take(tensor.jaxValue, indices.jaxValue, axis = ev.index) - Tensor(result) - - def set[Inputs <: Tuple, LabelsToRemove <: Tuple]( - inputs: Inputs - )(using - sliceExtractor: SliceLabelExtractor[Inputs, LabelsToRemove], - ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Inputs]], - labels: Labels[T] - )(value: Tensor[ev.RemainingAxes, V]): Tensor[T, V] = - val pyIndices = tensor.calcPyIndices(inputs, ev.indices) - val result = 
tensor.jaxValue.at.bracketAccess(pyIndices).set(value.jaxValue) - Tensor(result) - - // Convenience overload for Float - def set[Inputs <: Tuple, LabelsToRemove <: Tuple]( - inputs: Inputs - )(using - sliceExtractor: SliceLabelExtractor[Inputs, LabelsToRemove], - ev: AxesConditionalRemover.Aux[T, LabelsToRemove, ExtractLabels[Inputs], EmptyTuple], - labels: Labels[T] - )(value: Float): Tensor[T, V] = - val pyIndices = tensor.calcPyIndices(inputs, ev.indices) - val result = tensor.jaxValue.at.bracketAccess(pyIndices).set(value) - Tensor(result) - - // Convenience overload for AxisAtIndex - def set[L, LabelsToRemove <: Tuple]( - selector: AxisAtIndex[L] - )(using - sliceExtractor: SliceLabelExtractor[Tuple1[AxisAtIndex[L]], LabelsToRemove], - ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtIndex[L]]]], - labels: Labels[T] - )(value: Tensor[ev.RemainingAxes, V]): Tensor[T, V] = set(Tuple1(selector))(value) - - // Convenience overload for AxisAtRange - def set[L, LabelsToRemove <: Tuple]( - selector: AxisAtRange[L] - )(using - sliceExtractor: SliceLabelExtractor[Tuple1[AxisAtRange[L]], LabelsToRemove], - ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtRange[L]]]], - labels: Labels[T] - )(value: Tensor[ev.RemainingAxes, V]): Tensor[T, V] = set(Tuple1(selector))(value) - - // Convenience overload for AxisAtIndices - def set[L, LabelsToRemove <: Tuple]( - selector: AxisAtIndices[L] - )(using - sliceExtractor: SliceLabelExtractor[Tuple1[AxisAtIndices[L]], LabelsToRemove], - ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtIndices[L]]]], - labels: Labels[T] - )(value: Tensor[ev.RemainingAxes, V]): Tensor[T, V] = set(Tuple1(selector))(value) - - // Convenience overload for AxisAtTensorIndex - def set[L, LabelsToRemove <: Tuple]( - selector: AxisAtTensorIndex[L] - )(using - sliceExtractor: SliceLabelExtractor[Tuple1[AxisAtTensorIndex[L]], LabelsToRemove], - ev: AxesConditionalRemover[T, LabelsToRemove, 
ExtractLabels[Tuple1[AxisAtTensorIndex[L]]]], - labels: Labels[T] - )(value: Tensor[ev.RemainingAxes, V]): Tensor[T, V] = set(Tuple1(selector))(value) - - def rearrange[Axes <: Tuple, Status <: ValidationResult](newOrder: Axes)(using - Labels[UnwrapAxes[Axes]] - )(using - computer: ComputeMissing[UnwrapAxes[Axes], T, EmptyTuple, Status], - guard: CheckValid[Status] - ): Tensor[UnwrapAxes[Axes], V] = - rearrange[Axes, EmptyTuple, Status](newOrder, EmptyTuple) - - // Convenience overload for 1 dims (to support error messages with single axis) - inline def rearrange[Axes <: Tuple, L1, Status <: ValidationResult](newOrder: Axes, d1: AxisExtent[L1])(using computer: ComputeMissing[UnwrapAxes[Axes], T, UnwrapDims[Tuple1[AxisExtent[L1]]], Status], guard: CheckValid[Status])(using newLabels: Labels[UnwrapAxes[Axes]], extractor: DimExtractor[Tuple1[AxisExtent[L1]]]): Tensor[UnwrapAxes[Axes], V] = - rearrange(newOrder, Tuple1(d1)) - - // Convenience overload for 2 dims - inline def rearrange[Axes <: Tuple, L1, L2, Status <: ValidationResult](newOrder: Axes, d1: AxisExtent[L1], d2: AxisExtent[L2])(using computer: ComputeMissing[UnwrapAxes[Axes], T, UnwrapDims[(AxisExtent[L1], AxisExtent[L2])], Status], guard: CheckValid[Status])(using newLabels: Labels[UnwrapAxes[Axes]], extractor: DimExtractor[(AxisExtent[L1], AxisExtent[L2])]): Tensor[UnwrapAxes[Axes], V] = - rearrange(newOrder, (d1, d2)) - - // Convenience overload for 3 dims - inline def rearrange[Axes <: Tuple, L1, L2, L3, Status <: ValidationResult](newOrder: Axes, d1: AxisExtent[L1], d2: AxisExtent[L2], d3: AxisExtent[L3])(using computer: ComputeMissing[UnwrapAxes[Axes], T, UnwrapDims[(AxisExtent[L1], AxisExtent[L2], AxisExtent[L3])], Status], guard: CheckValid[Status])(using newLabels: Labels[UnwrapAxes[Axes]], extractor: DimExtractor[(AxisExtent[L1], AxisExtent[L2], AxisExtent[L3])]): Tensor[UnwrapAxes[Axes], V] = - rearrange(newOrder, (d1, d2, d3)) - - // Convenience overload for 4 dims - inline def rearrange[Axes <: 
Tuple, L1, L2, L3, L4, Status <: ValidationResult](newOrder: Axes, d1: AxisExtent[L1], d2: AxisExtent[L2], d3: AxisExtent[L3], d4: AxisExtent[L4])(using computer: ComputeMissing[UnwrapAxes[Axes], T, UnwrapDims[(AxisExtent[L1], AxisExtent[L2], AxisExtent[L3], AxisExtent[L4])], Status], guard: CheckValid[Status])(using newLabels: Labels[UnwrapAxes[Axes]], extractor: DimExtractor[(AxisExtent[L1], AxisExtent[L2], AxisExtent[L3], AxisExtent[L4])]): Tensor[UnwrapAxes[Axes], V] = - rearrange(newOrder, (d1, d2, d3, d4)) - - def rearrange[Axes <: Tuple, Dims <: Tuple, Status <: ValidationResult]( - newOrder: Axes, - dims: Dims - )(using - computer: ComputeMissing[UnwrapAxes[Axes], T, UnwrapDims[Dims], Status], - guard: CheckValid[Status] - )(using - newLabels: Labels[UnwrapAxes[Axes]], - extractor: DimExtractor[Dims] - ): Tensor[UnwrapAxes[Axes], V] = - def cleanPatternPrime(pattern: String): String = - // Support dimwit.Prime by replacing ' with "Prime" - pattern.replaceAll( - "'", - "Prime" - ) - def createEinopsPattern(fromPattern: String, toPattern: String): String = - def cleanPatternStar(pattern: String): String = - // to replace all a*b*c in pattern with (a b c), example: - // "a*b*c d e f*g h" -> "(a b c) d e (f g) h" - val regex = raw"([a-zA-Z0-9_]+(\*[a-zA-Z0-9_]+)+)".r - regex.replaceAllIn( - pattern, - _.group(1).split("\\*").mkString("(", " ", ")") - ) - def cleanPatternPlus(pattern: String): String = - // Support dimwit.|+| by replacing + with underlines - val regex = raw"([a-zA-Z0-9_]+(\+[a-zA-Z0-9_]+)+)".r - regex.replaceAllIn( - pattern, - _.group(1).replace("+", "_") - ) - def cleanPattern(pattern: String): String = - cleanPatternPlus(cleanPatternStar(cleanPatternPrime(pattern))) - s"${cleanPattern(fromPattern)} -> ${cleanPattern(toPattern)}" - val fromPattern = tensor.shape.labels.mkString(" ") - val toPattern = newLabels.names.mkString(" ") - val pattern = createEinopsPattern(fromPattern, toPattern) - val dimSizesMap = extractor.extract(dims) - val 
cleanDimSizesMap = dimSizesMap.map { case (k, v) => - val newKey = cleanPatternPrime(k) - (newKey, v) - } - Tensor( - Einops.rearrange( - tensor.jaxValue, - pattern, - kwargsMap = cleanDimSizesMap - ) - ) - - def broadcastTo[O <: Tuple: Labels](newShape: Shape[O])(using - labels: Labels[T], - ev: StrictSubset[T, O] - ): Tensor[O, V] = - /* Disallow implicit broadcasting where an *existing* axis changes size (implicitly). - * dimwit broadcasting only adds missing axes, never changes existing ones. - * - * This is a required check to prevent implicit broadcasting across dimwit. - * If this check is not explicitly present, Jax.jnp.broadcast_to would implicit broadcast.*/ - def disallowImplicitShapeBroadcasting(): Unit = - val tAxesDims = tensor.axes.zip(tensor.shape.dimensions).toMap - val newShapeAxesDims = newShape.labels.zip(newShape.dimensions).toMap - tensor.axes.foreach(axisName => - require( - tAxesDims(axisName) == newShapeAxesDims(axisName), - s"Broadcasting only adds missing axes. Present axes must have the same size. Axis ${axisName} has size ${tAxesDims(axisName)} in the current tensor but size ${newShapeAxesDims(axisName)} in the target shape." 
- ) - ) - - disallowImplicitShapeBroadcasting() // Make dimwit coders, good coders :) - - val t = tensor - - val currentNames = summon[Labels[T]].names - val targetNames = summon[Labels[O]].names - - val targetOrder = targetNames.filter(currentNames.contains) - val permutation = targetOrder.map(n => currentNames.indexOf(n)) - - val alignedJax = - if permutation != currentNames.indices.toList then Jax.jnp.transpose(t.jaxValue, permutation.toPythonProxy) - else t.jaxValue - - val currentShapeMap = currentNames.zip(t.shape.dimensions).toMap - - val intermediateShape = targetNames.map { name => - currentShapeMap.getOrElse(name, 1) - } - - val reshapedJax = Jax.jnp.reshape(alignedJax, intermediateShape.toPythonProxy) - Tensor(Jax.jnp.broadcast_to(reshapedJax, newShape.dimensions.toPythonProxy)) - - def relabel[OldLabel: Label, NewLabel: Label]( - rename: (Axis[OldLabel], Axis[NewLabel]) - )(using - ev: AxisReplacer[T, OldLabel, NewLabel], - newLabels: Labels[ev.NewShape] - ): Tensor[ev.NewShape, V] = Tensor(tensor.jaxValue) - - def retag[newT <: Tuple](using newLabels: Labels[newT]): Tensor[newT, V] = - Tensor(tensor.jaxValue)(using newLabels) - - def relabelAll[newT <: Tuple]( - newAxes: newT - )(using - newLabels: Labels[UnwrapAxes[newT]], - @implicitNotFound("Cannot convert tensor of shape ${T} to shape ${newT} due to size mismatch.") - evSameSize: Tuple.Size[newT] =:= Tuple.Size[T] - ): Tensor[UnwrapAxes[newT], V] = Tensor[UnwrapAxes[newT], V](tensor.jaxValue) - - def swap[L1: Label, L2: Label]( - axis1: Axis[L1], - axis2: Axis[L2] - )(using - labels: Labels[T], - axisIndex1: AxisIndex[T, L1], - axisIndex2: AxisIndex[T, L2] - ): Tensor[Swap[T, L1, L2], V] = - given Labels[Swap[T, L1, L2]] with - def names = - val originalNames = summon[Labels[T]].names - val ax1Name = summon[Label[L1]].name - val ax2Name = summon[Label[L2]].name - originalNames.map { - case n if n == ax1Name => ax2Name - case n if n == ax2Name => ax1Name - case n => n - } - 
Tensor(Jax.jnp.swapaxes(tensor.jaxValue, axisIndex1.index, axisIndex2.index)) - - def appendAxis[L: Label](axis: Axis[L])(using labels: Labels[T], ev: AxisAbsent[T, L]): Tensor[Tuple.Concat[T, Tuple1[L]], V] = - val newShape = tensor.shape.dimensions :+ 1 - Tensor(Jax.jnp.reshape(tensor.jaxValue, newShape.toPythonProxy)) - - def prependAxis[L: Label](axis: Axis[L])(using labels: Labels[T], ev: AxisAbsent[T, L]): Tensor[Tuple.Concat[Tuple1[L], T], V] = - val newShape = 1 +: tensor.shape.dimensions - Tensor(Jax.jnp.reshape(tensor.jaxValue, newShape.toPythonProxy)) - - def squeeze[L: Label](axis: Axis[L])(using - ev: AxisRemover[T, L], - labels: Labels[ev.RemainingAxes] - ): Tensor[ev.RemainingAxes, V] = - require( - tensor.shape.dimensions(ev.index) == 1, - s"Cannot squeeze axis ${summon[Label[L]].name} of size ${tensor.shape.dimensions(ev.index)}" - ) - Tensor(Jax.jnp.squeeze(tensor.jaxValue, axis = ev.index)) - - extension [L: Label, V](tensor: Tensor1[L, V]) - def roll(shift: Int): Tensor1[L, V] = - Tensor(Jax.jnp.roll(tensor.jaxValue, shift = shift, axis = 0)) - - end Structural - - // ----------------------------------------------------------- - // 5. 
Functional Operations (Higher Order) - // Lifting functions over axes - // ----------------------------------------------------------- - object Functional: - - object ZipVmap: - - type TensorsOf[Shapes <: Tuple, Values <: Tuple] <: Tuple = (Shapes, Values) match - case (EmptyTuple, EmptyTuple) => EmptyTuple - case ((shapeHead *: shapeTail), (valueHead *: valueTail)) => Tensor[shapeHead, valueHead] *: TensorsOf[shapeTail, valueTail] - - type ExtractShape[T] = T match - case Tensor[s, v] => s - - type ExtractValue[T] = T match - case Tensor[s, v] => v - - type ShapesOf[Tensors <: Tuple] = Tuple.Map[Tensors, ExtractShape] - type ValuesOf[Tensors <: Tuple] = Tuple.Map[Tensors, ExtractValue] - - def zipvmap[L: Label, Inputs <: Tuple, OutShape <: Tuple: Labels, OutV]( - axis: Axis[L] - )( - tensors: Inputs // This is a Tuple of Tensors - )(using - ev: SharedAxisRemover[ShapesOf[Inputs], L] - )( - f: TensorsOf[ev.RemainingAxes, ValuesOf[Inputs]] => Tensor[OutShape, OutV] - ): Tensor[L *: OutShape, OutV] = - val fpy = (args: py.Dynamic) => - OnError.traceStack: - val tensorList = args.as[Seq[py.Dynamic]].zip(ev.shapesLabels).map: (jaxArr, labels) => - Tensor(jaxArr)(using LabelsImpl(labels)) - - val inputTuple = Tuple.fromArray(tensorList.toArray) - val result = f(inputTuple.asInstanceOf[TensorsOf[ev.RemainingAxes, ValuesOf[Inputs]]]) - result.jaxValue - - val jaxInputs = py.Dynamic.global.tuple(tensors.toArray.map(_.asInstanceOf[Tensor[?, ?]].jaxValue).toPythonProxy) - val indicesAsTuple = py.Dynamic.global.tuple(ev.indices.toPythonProxy) - val jaxResult = Jax.jax_helper.zipvmap( - fpy, - indicesAsTuple - )(jaxInputs) - - Tensor(jaxResult) - - export ZipVmap.zipvmap - - extension [T <: Tuple: Labels, V](t: Tensor[T, V]) - - def vmap[VmapAxis: Label, OuterShape <: Tuple: Labels, V2]( - axis: Axis[VmapAxis] - )(using - ev: AxisRemover[T, VmapAxis] - )( - f: Tensor[ev.RemainingAxes, V] => Tensor[OuterShape, V2] - )(using - labels: Labels[ev.RemainingAxes] - ): 
Tensor[VmapAxis *: OuterShape, V2] = - val fpy = (jxpr: Jax.PyDynamic) => - OnError.traceStack: - val innerTensor = Tensor[ev.RemainingAxes, V](jxpr) - val result = f(innerTensor) - result.jaxValue - - Tensor(Jax.jax_helper.vmap(fpy, ev.index)(t.jaxValue)) - - def vapply[L: Label, NewL, R <: Tuple, NewV]( - axis: Axis[L] - )( - f: Tensor[Tuple1[L], V] => Tensor[Tuple1[NewL], NewV] - )(using - ev: AxisReplacer.Aux[T, L, NewL, R], - labels: Labels[R] - ): Tensor[R, NewV] = - val fpy = (jxpr: Jax.PyDynamic) => - OnError.traceStack: - val inputTensor = Tensor[Tuple1[L], V](jxpr) - val result = f(inputTensor) - result.jaxValue - - Tensor( - Jax.jnp.apply_along_axis( - fpy, - ev.index, - t.jaxValue - ) - ) - - def zipvmap[L: Label, T2 <: Tuple, OutShape <: Tuple: Labels, OutV](axis: Axis[L])( - other: Tensor[T2, V] - )(using - ev: SharedAxisRemover[(T, T2), L] - )( - f: TensorsOf[ev.RemainingAxes, (V, V)] => Tensor[OutShape, OutV] - ): Tensor[L *: OutShape, OutV] = - ZipVmap.zipvmap(axis)(t, other)(f) - - def vreduce[L: Label]( - axis: Axis[L] - )( - f: Tensor[Tuple1[L], V] => Tensor0[V] - )(using - ev: AxisRemover[T, L], - labels: Labels[ev.RemainingAxes] - ): Tensor[ev.RemainingAxes, V] = - val fpy = (jxpr: Jax.PyDynamic) => - OnError.traceStack: - val inputTensor = Tensor[Tuple1[L], V](jxpr) - val result = f(inputTensor) - result.jaxValue - - Tensor( - Jax.jnp.apply_along_axis( - fpy, - ev.index, - t.jaxValue - ) - ) - - end Functional - - export Elementwise.* - export Reduction.* - export Contraction.* - export Convolution.* - export LinearAlgebra.* - export Structural.* - export Functional.* - - // ----------------------------------------------------------- - // Common specialized operation names - // ----------------------------------------------------------- - - object Tensor0Ops: - - private inline def checkTracer[V, R](scalar: Tensor0[V]): Unit = - require( - !scalar.isTracer, - """ - | Cannot convert a JAX Tracer to a scalar value. 
Tensor0 is part of a JAX computation graph (e.g., inside vmap or a jitted function). - | Common mistakes leading to this error: - | - calling .slice(t0.item) rather than .slice(t0); breaking the computation graph unintentionally. - |""".stripMargin - ) - - extension (scalar: Tensor0[Bool]) - def item: Boolean = - checkTracer(scalar) - scalar.jaxValue.item().as[Boolean] - - extension (scalar: Tensor0[Int8]) - def item: Byte = - checkTracer(scalar) - scalar.jaxValue.item().as[Byte] - - extension (scalar: Tensor0[Int16]) - def item: Short = - checkTracer(scalar) - scalar.jaxValue.item().as[Int].toShort - - extension (scalar: Tensor0[Int32]) - def item: Int = - checkTracer(scalar) - scalar.jaxValue.item().as[Int] - - extension (scalar: Tensor0[Int64]) - def item: Long = - checkTracer(scalar) - scalar.jaxValue.item().as[Long] - - extension (scalar: Tensor0[Float32]) - def item: Float = - checkTracer(scalar) - scalar.jaxValue.item().as[Float] - - extension (scalar: Tensor0[Float64]) - def item: Double = - checkTracer(scalar) - scalar.jaxValue.item().as[Double] - - object ValueOps: - - import Elementwise.+! 
- - extension [V: IsNumber](t: Tensor0[V]) - - def +(t2: Tensor0[V]): Tensor0[V] = TensorOps.add(t, t2) - def -(t2: Tensor0[V]): Tensor0[V] = TensorOps.subtract(t, t2) - def *(t2: Tensor0[V]): Tensor0[V] = TensorOps.multiply(t, t2) - - extension [V: IsFloating](t: Tensor0[V]) - - def /(scalar: Tensor0[V]): Tensor0[V] = TensorOps.divide(t, scalar) - - extension (scalar: Float) - - def +[V: IsNumber](t: Tensor0[V]): Tensor0[V] = add(Tensor0.likeDType(t)(scalar), t) - def +![T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V])(using bc: Broadcast[EmptyTuple, T, V]): Tensor[bc.Out, V] = bc.applyTo(Tensor0.likeDType(t)(scalar), t)(add) - - def -[V: IsNumber](t: Tensor0[V]): Tensor0[V] = subtract(Tensor0.likeDType(t)(scalar), t) - def -![T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V])(using bc: Broadcast[EmptyTuple, T, V]): Tensor[bc.Out, V] = bc.applyTo(Tensor0.likeDType(t)(scalar), t)(subtract) - - def *[V: IsNumber](t: Tensor0[V]): Tensor0[V] = multiply(Tensor0.likeDType(t)(scalar), t) - def *![T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V])(using bc: Broadcast[EmptyTuple, T, V]): Tensor[bc.Out, V] = bc.applyTo(Tensor0.likeDType(t)(scalar), t)(multiply) - - extension (scalar: Float) - - def /[V: IsFloating](t: Tensor0[V]): Tensor0[V] = divide(Tensor0.likeDType(t)(scalar), t) - def /![T <: Tuple: Labels, V: IsFloating](t: Tensor[T, V])(using bc: Broadcast[EmptyTuple, T, V]): Tensor[bc.Out, V] = bc.applyTo(Tensor0.likeDType(t)(scalar), t)(divide) - - object Tensor1Ops: - - extension [L, V](t: Tensor1[L, V]) - - def relabelTo[NewL: Label](newAxis: Axis[NewL]): Tensor1[NewL, V] = Tensor[Tuple1[NewL], V](t.jaxValue) - - // TODO generalize to TensorN (like slice) - def dynamicSlice( - dynamicStart: Tensor0[Int32], - staticSize: Int - )(using - label: Label[L] - ): Tensor1[L, V] = - // TODO understand why toPythonCopy is needed and toPythonProxy fails! 
- Tensor(Jax.lax.dynamic_slice(t.jaxValue, Seq(dynamicStart.jaxValue).toPythonCopy, Seq(staticSize).toPythonCopy)) - - extension [L, V, X](t: Tensor1[L, V])(using ev: HasScalar[V, X]) - /** Converts a Tensor1 to a Scala Array. - * The user must ensure that the tensor is not a JAX Tracer - * (i.e., it is not part of a JAX computation graph) before calling this method, - * otherwise a runtime error will occur. - */ - def toArray: Array[X] = - require(!t.isTracer, "Cannot convert a JAX Tracer to an array.") - ev.readFlat(t.jaxValue) - - object Tensor2Ops: - - extension [L1: Label, L2: Label, V](t: Tensor2[L1, L2, V]) - - // Support .transpose without arguments for 2D tensors while keeping (not shadowing) the general .transpose with arguments - def transpose: Tensor2[L2, L1, V] = t.transpose(Axis[L2], Axis[L1]) - def transpose(axis2: Axis[L2], axis1: Axis[L1]): Tensor2[L2, L1, V] = TensorOps.Structural.transpose(t)(axis2, axis1) - - extension [L1, L2, V, X](t: Tensor2[L1, L2, V])(using ev: HasScalar[V, X]) - /** Converts a Tensor2 to a nested Scala Array (Array of Arrays). - * The user must ensure that the tensor is not a JAX Tracer - * (i.e., it is not part of a JAX computation graph) before calling this method, - * otherwise a runtime error will occur. - */ - def toArray: Array[Array[X]] = - require(!t.isTracer, "Cannot convert a JAX Tracer to an array.") - given scala.reflect.ClassTag[X] = ev.classTag - ev.readFlat(t.jaxValue).grouped(t.shape.dimensions(1)).toArray - - object Tensor3Ops: - - extension [L1, L2, L3, V, X](t: Tensor3[L1, L2, L3, V])(using ev: HasScalar[V, X]) - /** Converts a Tensor3 to a nested Scala Array (Array of Arrays of Arrays). - * The user must ensure that the tensor is not a JAX Tracer - * (i.e., it is not part of a JAX computation graph) before calling this method, - * otherwise a runtime error will occur. 
- */ - def toArray: Array[Array[Array[X]]] = - require(!t.isTracer, "Cannot convert a JAX Tracer to an array.") - given scala.reflect.ClassTag[X] = ev.classTag - val d1 = t.shape.dimensions(1); val d2 = t.shape.dimensions(2) - ev.readFlat(t.jaxValue).grouped(d1 * d2).map(_.grouped(d2).toArray).toArray - - export Tensor0Ops.* + export tensorops.Tensor0Ops.* export ValueOps.* - export Tensor1Ops.* - export Tensor2Ops.* - export Tensor3Ops.* + export tensorops.Tensor1Ops.* + export tensorops.Tensor2Ops.* + export tensorops.Tensor3Ops.* end TensorOps - -object TensorOpsUtil: - - import TensorOps.Structural.broadcastTo - - @implicitNotFound("Cannot broadcast tensors of shapes ${T1} and ${T2}. If same shape no broadcasting allowed!") - sealed trait Broadcast[T1 <: Tuple, T2 <: Tuple, V]: - type Out <: Tuple - given labelsOut: Labels[Out] - def broadcast(t1: Tensor[T1, V], t2: Tensor[T2, V]): (Tensor[Out, V], Tensor[Out, V]) - def applyTo[V2](t1: Tensor[T1, V], t2: Tensor[T2, V])(f: (Tensor[Out, V], Tensor[Out, V]) => Tensor[Out, V2]): Tensor[Out, V2] = - val (bt1, bt2) = broadcast(t1, t2) - f(bt1, bt2) - - object Broadcast extends BroadcastLowPriority: - - given broadcastLeft[T1 <: Tuple: Labels, T2 <: Tuple: Labels, V](using - StrictSubset[T2, T1] - ): Broadcast[T1, T2, V] with - type Out = T1 - val labelsOut = summon[Labels[T1]] - def broadcast(t1: Tensor[T1, V], t2: Tensor[T2, V]) = - (t1, t2.broadcastTo[T1](t1.shape)) - - trait BroadcastLowPriority: - given broadcastRight[T1 <: Tuple: Labels, T2 <: Tuple: Labels, V](using - StrictSubset[T1, T2] - ): Broadcast[T1, T2, V] with - type Out = T2 - val labelsOut = summon[Labels[T2]] - def broadcast(t1: Tensor[T1, V], t2: Tensor[T2, V]) = - (t1.broadcastTo[T2](t2.shape), t2) - -end TensorOpsUtil diff --git a/core/src/main/scala/dimwit/tensor/ValueOps.scala b/core/src/main/scala/dimwit/tensor/ValueOps.scala new file mode 100644 index 00000000..f5c23767 --- /dev/null +++ b/core/src/main/scala/dimwit/tensor/ValueOps.scala @@ 
-0,0 +1,37 @@ +package dimwit.tensor + +import dimwit.tensor.TensorOps.IsNumber +import dimwit.tensor.TensorOps.IsFloating +import dimwit.tensor.tensorops.ElementWiseOps.add +import dimwit.tensor.tensorops.ElementWiseOps.subtract +import dimwit.tensor.tensorops.ElementWiseOps.multiply +import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast +import dimwit.tensor.tensorops.ElementWiseOps.divide + +object ValueOps: + + extension [V: IsNumber](t: Tensor0[V]) + + def +(t2: Tensor0[V]): Tensor0[V] = TensorOps.add(t, t2) + def -(t2: Tensor0[V]): Tensor0[V] = TensorOps.subtract(t, t2) + def *(t2: Tensor0[V]): Tensor0[V] = TensorOps.multiply(t, t2) + + extension [V: IsFloating](t: Tensor0[V]) + + def /(scalar: Tensor0[V]): Tensor0[V] = TensorOps.divide(t, scalar) + + extension (scalar: Float) + + def +[V: IsNumber](t: Tensor0[V]): Tensor0[V] = add(Tensor0.likeDType(t)(scalar), t) + def +![T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V])(using bc: Broadcast[EmptyTuple, T, V]): Tensor[bc.Out, V] = bc.applyTo(Tensor0.likeDType(t)(scalar), t)(add) + + def -[V: IsNumber](t: Tensor0[V]): Tensor0[V] = subtract(Tensor0.likeDType(t)(scalar), t) + def -![T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V])(using bc: Broadcast[EmptyTuple, T, V]): Tensor[bc.Out, V] = bc.applyTo(Tensor0.likeDType(t)(scalar), t)(subtract) + + def *[V: IsNumber](t: Tensor0[V]): Tensor0[V] = multiply(Tensor0.likeDType(t)(scalar), t) + def *![T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V])(using bc: Broadcast[EmptyTuple, T, V]): Tensor[bc.Out, V] = bc.applyTo(Tensor0.likeDType(t)(scalar), t)(multiply) + + extension (scalar: Float) + + def /[V: IsFloating](t: Tensor0[V]): Tensor0[V] = divide(Tensor0.likeDType(t)(scalar), t) + def /![T <: Tuple: Labels, V: IsFloating](t: Tensor[T, V])(using bc: Broadcast[EmptyTuple, T, V]): Tensor[bc.Out, V] = bc.applyTo(Tensor0.likeDType(t)(scalar), t)(divide) diff --git a/core/src/main/scala/dimwit/tensor/tensorops/ContractionOps.scala 
b/core/src/main/scala/dimwit/tensor/tensorops/ContractionOps.scala new file mode 100644 index 00000000..5074592f --- /dev/null +++ b/core/src/main/scala/dimwit/tensor/tensorops/ContractionOps.scala @@ -0,0 +1,98 @@ +package dimwit.tensor.tensorops + +import dimwit.tensor.Tensor +import dimwit.tensor.Labels +import dimwit.jax.Jax +import dimwit.tensor.DType.Bool +import dimwit.tensor.Tensor0 +import dimwit.tensor.TensorOps.IsBoolean +import dimwit.tensor.VType +import dimwit.tensor.DType.Int32 +import dimwit.tensor.DType.Float32 +import dimwit.tensor.TensorOps.IsInteger +import dimwit.tensor.TensorOps.IsFloating +import dimwit.tensor.TensorOps.IsNumber +import dimwit.tensor.Label +import dimwit.tensor.ShapeTypeHelpers.AxisRemover +import dimwit.tensor.ShapeTypeHelpers.AxesRemover +import dimwit.tensor.Axis +import dimwit.tensor.ShapeTypeHelpers.UnwrapAxes +import dimwit.tensor.ShapeTypeHelpers.AxisIndex +import dimwit.tensor.ShapeTypeHelpers.AxisIndices + +import me.shadaj.scalapy.py +import me.shadaj.scalapy.py.SeqConverters +import me.shadaj.scalapy.readwrite.{Reader, Writer} +import dimwit.tensor.TupleHelpers.PrimeConcat +import scala.annotation.targetName + +object ContractionOps: + + extension [T <: Tuple: Labels, V](tensor: Tensor[T, V]) + + /** Computes the outer product of this tensor with another tensor. + * Automatically primes the labels of the resulting tensor to avoid label collisions. + */ + def outerProduct[OtherShape <: Tuple: Labels](other: Tensor[OtherShape, V])(using + primeConcat: PrimeConcat[T, OtherShape], + labels: Labels[primeConcat.Out] + ): Tensor[primeConcat.Out, V] = Tensor( + // Jax outer product flattens, reshape required + Jax.jnp.reshape( + Jax.jnp.outer(tensor.jaxValue, other.jaxValue), + (tensor.shape.dimensions ++ other.shape.dimensions).toPythonProxy + ) + ) + + /** Computes the dot product of this tensor with another tensor along the specified axis. 
+      * The axis must be present in both tensors and will be contracted (removed) from the resulting tensor.
+      *
+      * @param axis The axis along which to contract. Must be present in both tensors.
+      * @param other The other tensor to contract with.
+      */
+    def dot[
+        ContractAxis,
+        OtherShape <: Tuple
+    ](axis: Axis[ContractAxis])(other: Tensor[OtherShape, V])(using
+        ev: AxisRemover[T, ContractAxis],
+        evOther: AxisRemover[OtherShape, ContractAxis]
+    )(using
+        primeConcat: PrimeConcat[ev.RemainingAxes, evOther.RemainingAxes],
+        labelsOut: Labels[primeConcat.Out]
+    ): Tensor[primeConcat.Out, V] =
+      val axesTuple1 = Jax.Dynamic.global.tuple(Seq(ev.index).toPythonProxy)
+      val axesTuple2 = Jax.Dynamic.global.tuple(Seq(evOther.index).toPythonProxy)
+      val axesPair = Jax.Dynamic.global.tuple(Seq(axesTuple1, axesTuple2).toPythonProxy)
+
+      Tensor(Jax.jnp.tensordot(tensor.jaxValue, other.jaxValue, axes = axesPair))
+
+    /** Computes the dot product of this tensor with another tensor along the specified pair of axes.
+      * The axes must be present in their respective tensors and will be contracted (removed) from the resulting tensor.
+      *
+      * @param axisPair The pair of axes along which to contract: the first axis must be present in this tensor, the second in `other`.
+      * @param other The other tensor to contract with.
+      *
+      * Example usage:
+      * {{{
+      * val t1: Tensor[("A", "B", "C"), Float] = ???
+      * val t2: Tensor[("D", "E", "F"), Float] = ???
+      * val result = t1.dot((Axis["B"], Axis["F"]))(t2)
+      * }}}
+      */
+    @targetName("dotOn")
+    def dot[
+        ContractAxisA,
+        ContractAxisB,
+        OtherShape <: Tuple
+    ](axisPair: (Axis[ContractAxisA], Axis[ContractAxisB]))(other: Tensor[OtherShape, V])(using
+        ev: AxisRemover[T, ContractAxisA],
+        evOther: AxisRemover[OtherShape, ContractAxisB]
+    )(using
+        primeConcat: PrimeConcat[ev.RemainingAxes, evOther.RemainingAxes],
+        outLabels: Labels[primeConcat.Out]
+    ): Tensor[primeConcat.Out, V] =
+      val axesTuple1 = Jax.Dynamic.global.tuple(Seq(ev.index).toPythonProxy)
+      val axesTuple2 = Jax.Dynamic.global.tuple(Seq(evOther.index).toPythonProxy)
+      val axesPair = Jax.Dynamic.global.tuple(Seq(axesTuple1, axesTuple2).toPythonProxy)
+
+      Tensor(Jax.jnp.tensordot(tensor.jaxValue, other.jaxValue, axes = axesPair))
diff --git a/core/src/main/scala/dimwit/tensor/tensorops/ConvolutionOps.scala b/core/src/main/scala/dimwit/tensor/tensorops/ConvolutionOps.scala
new file mode 100644
index 00000000..9a4f84da
--- /dev/null
+++ b/core/src/main/scala/dimwit/tensor/tensorops/ConvolutionOps.scala
@@ -0,0 +1,214 @@
+package dimwit.tensor.tensorops
+
+import dimwit.tensor.Tensor
+import dimwit.tensor.Labels
+import dimwit.jax.Jax
+import dimwit.tensor.DType.Bool
+import dimwit.tensor.Tensor0
+import dimwit.tensor.TensorOps.IsBoolean
+import dimwit.tensor.VType
+import dimwit.tensor.DType.Int32
+import dimwit.tensor.DType.Float32
+import dimwit.tensor.TensorOps.IsInteger
+import dimwit.tensor.TensorOps.IsFloating
+import dimwit.tensor.TensorOps.IsNumber
+import dimwit.tensor.Label
+import dimwit.tensor.ShapeTypeHelpers.AxisRemover
+import dimwit.tensor.ShapeTypeHelpers.AxesRemover
+import dimwit.tensor.Axis
+import dimwit.tensor.ShapeTypeHelpers.UnwrapAxes
+import dimwit.tensor.ShapeTypeHelpers.AxisIndex
+import dimwit.tensor.ShapeTypeHelpers.AxisIndices
+
+import me.shadaj.scalapy.py
+import me.shadaj.scalapy.py.SeqConverters
+import me.shadaj.scalapy.readwrite.{Reader, Writer}
+import dimwit.tensor.AxisExtent
+import dimwit.tensor.TensorOps.swap
+
+object ConvolutionOps:
+
+  enum Padding:
+    case SAME, VALID
+
+  type Stride1[S1] = AxisExtent[S1]
+
+  extension [S1: Label, InChannel: Label, V: IsFloating](input: Tensor[S1 *: InChannel *: EmptyTuple, V])
+
+    def conv1d[OutChannel: Label](
+        kernel: Tensor[S1 *: InChannel *: OutChannel *: EmptyTuple, V],
+        stride: Stride1[S1] | Int = 1,
+        padding: Padding = Padding.SAME
+    ): Tensor[S1 *: OutChannel *: EmptyTuple, V] =
+      require(
+        input.shape(Axis[InChannel]) == kernel.shape(Axis[InChannel]),
+        s"Input channels mismatch: input has ${input.shape(Axis[InChannel])} channels, kernel expects ${kernel.shape(Axis[InChannel])} channels"
+      )
+      val strides = stride match
+        case s: Int => Seq(s)
+        case ae: AxisExtent[S1] => Seq(ae.size)
+      // JAX requires input and kernel to have same rank, so we must add (and remove) dummy dim to input.
+      val batchInput = Jax.jnp.expand_dims(input.jaxValue, axis = 0) // add dummy dim
+      val convResult = Jax.lax.conv_general_dilated(
+        lhs = batchInput,
+        rhs = kernel.jaxValue,
+        window_strides = strides.toPythonProxy,
+        padding = padding.toString,
+        dimension_numbers = py.Dynamic.global.tuple(Seq("NHC", "HIO", "NHC").toPythonProxy)
+      )
+      val unbatchedRes = Jax.jnp.squeeze(convResult, axis = 0) // remove dummy dim
+      Tensor(unbatchedRes)
+
+  extension [S1: Label, OutChannel: Label, V: IsFloating](input: Tensor[S1 *: OutChannel *: EmptyTuple, V])
+
+    def transposeConv1d[InChannel: Label](
+        kernel: Tensor[S1 *: InChannel *: OutChannel *: EmptyTuple, V],
+        stride: Stride1[S1] | Int = 1,
+        padding: Padding = Padding.SAME
+    ): Tensor[S1 *: InChannel *: EmptyTuple, V] =
+      require(
+        input.shape(Axis[OutChannel]) == kernel.shape(Axis[OutChannel]),
+        s"Input channels mismatch: input has ${input.shape(Axis[OutChannel])} channels (OutChannel), kernel expects ${kernel.shape(Axis[OutChannel])}"
+      )
+      val strides = stride match
+        case s: Int => Seq(s)
+        case ex: AxisExtent[S1] => Seq(ex.size)
+
+      // kernel -> kernel
adjoint: swap in/out channels and flip spatial dims + var kernelAdjoint = kernel.swap(Axis[InChannel], Axis[OutChannel]).jaxValue + kernelAdjoint = Jax.jnp.flip(kernelAdjoint, axis = 0) // flip S1 + + val batchInput = Jax.jnp.expand_dims(input.jaxValue, axis = 0) // add dummy dim + val convResult = Jax.lax.conv_transpose( + lhs = batchInput, + rhs = kernelAdjoint, + strides = strides.toPythonProxy, + padding = padding.toString, + dimension_numbers = py.Dynamic.global.tuple(Seq("NHC", "HIO", "NHC").toPythonProxy) + ) + val unbatchedRes = Jax.jnp.squeeze(convResult, axis = 0) // remove dummy dim + Tensor(unbatchedRes) + + type Stride2[S1, S2] = (AxisExtent[S1], AxisExtent[S2]) + + extension [S1: Label, S2: Label, InChannel: Label, V: IsFloating](input: Tensor[S1 *: S2 *: InChannel *: EmptyTuple, V]) + + def conv2d[OutChannel: Label]( + kernel: Tensor[S1 *: S2 *: InChannel *: OutChannel *: EmptyTuple, V], + stride: Stride2[S1, S2] | Int = 1, + padding: Padding = Padding.SAME + ): Tensor[S1 *: S2 *: OutChannel *: EmptyTuple, V] = + require( + input.shape(Axis[InChannel]) == kernel.shape(Axis[InChannel]), + s"Input channels mismatch: input has ${input.shape(Axis[InChannel])} channels, kernel expects ${kernel.shape(Axis[InChannel])} channels" + ) + val strides = stride match + case s: Int => Seq(s, s) + case (ae1, ae2) => Seq(ae1.size, ae2.size) + // JAX requires input and kernel to have same rank, so we must add (and remove) dummy dim to input. 
+ val batchInput = Jax.jnp.expand_dims(input.jaxValue, axis = 0) // add dummy dim + val convResult = Jax.lax.conv_general_dilated( + lhs = batchInput, + rhs = kernel.jaxValue, + window_strides = strides.toPythonProxy, + padding = padding.toString, + dimension_numbers = py.Dynamic.global.tuple(Seq("NHWC", "HWIO", "NHWC").toPythonProxy) + ) + val unbatchedRes = Jax.jnp.squeeze(convResult, axis = 0) // remove dummy dim + Tensor(unbatchedRes) + + extension [S1: Label, S2: Label, OutChannel: Label, V: IsFloating](input: Tensor[S1 *: S2 *: OutChannel *: EmptyTuple, V]) + + def transposeConv2d[InChannel: Label]( + kernel: Tensor[S1 *: S2 *: InChannel *: OutChannel *: EmptyTuple, V], + stride: Stride2[S1, S2] | Int = 1, + padding: Padding = Padding.SAME + ): Tensor[S1 *: S2 *: InChannel *: EmptyTuple, V] = + require( + input.shape(Axis[OutChannel]) == kernel.shape(Axis[OutChannel]), + s"Input channels mismatch: input has ${input.shape(Axis[OutChannel])} channels (OutChannel), kernel expects ${kernel.shape(Axis[OutChannel])}" + ) + + // JAX requires input and kernel to have same rank. Add dummy batch dim if needed. 
+      val strides = stride match
+        case s: Int => Seq(s, s)
+        case (ae1, ae2) => Seq(ae1.size, ae2.size)
+
+      // kernel -> kernel adjoint: swap in/out channels and flip spatial dims
+      var kernelAdjoint = kernel.swap(Axis[InChannel], Axis[OutChannel]).jaxValue
+      kernelAdjoint = Jax.jnp.flip(kernelAdjoint, axis = 0) // flip S1
+      kernelAdjoint = Jax.jnp.flip(kernelAdjoint, axis = 1) // flip S2
+
+      val batchInput = Jax.jnp.expand_dims(input.jaxValue, axis = 0) // add dummy dim
+      val convResult = Jax.lax.conv_transpose(
+        lhs = batchInput,
+        rhs = kernelAdjoint,
+        strides = strides.toPythonProxy,
+        padding = padding.toString,
+        dimension_numbers = py.Dynamic.global.tuple(Seq("NHWC", "HWIO", "NHWC").toPythonProxy)
+      )
+      val unbatchedRes = Jax.jnp.squeeze(convResult, axis = 0) // remove dummy dim
+      Tensor(unbatchedRes)
+
+  type Stride3[S1, S2, S3] = (AxisExtent[S1], AxisExtent[S2], AxisExtent[S3])
+
+  extension [S1: Label, S2: Label, S3: Label, InChannel: Label, V: IsFloating](input: Tensor[S1 *: S2 *: S3 *: InChannel *: EmptyTuple, V])
+
+    def conv3d[OutChannel: Label](
+        kernel: Tensor[S1 *: S2 *: S3 *: InChannel *: OutChannel *: EmptyTuple, V],
+        stride: Stride3[S1, S2, S3] | Int = 1,
+        padding: Padding = Padding.SAME
+    ): Tensor[S1 *: S2 *: S3 *: OutChannel *: EmptyTuple, V] =
+      require(
+        input.shape(Axis[InChannel]) == kernel.shape(Axis[InChannel]),
+        s"Input channels mismatch: input has ${input.shape(Axis[InChannel])} channels, kernel expects ${kernel.shape(Axis[InChannel])} channels"
+      )
+      val strides = stride match
+        case s: Int => Seq(s, s, s)
+        case (dim1, dim2, dim3) => Seq(dim1.size, dim2.size, dim3.size)
+
+      // JAX requires input and kernel to have same rank, so we must add (and remove) dummy dim to input.
+ // 3D Layout: NDHWC (Batch, Depth, Height, Width, Channel) + val batchInput = Jax.jnp.expand_dims(input.jaxValue, axis = 0) // add dummy dim + val convResult = Jax.lax.conv_general_dilated( + lhs = batchInput, + rhs = kernel.jaxValue, + window_strides = strides.toPythonProxy, + padding = padding.toString, + dimension_numbers = py.Dynamic.global.tuple(Seq("NDHWC", "DHWIO", "NDHWC").toPythonProxy) + ) + val unbatchedRes = Jax.jnp.squeeze(convResult, axis = 0) // remove dummy dim + Tensor(unbatchedRes) + + extension [S1: Label, S2: Label, S3: Label, OutChannel: Label, V: IsFloating](input: Tensor[S1 *: S2 *: S3 *: OutChannel *: EmptyTuple, V]) + + def transposeConv3d[InChannel: Label]( + kernel: Tensor[S1 *: S2 *: S3 *: InChannel *: OutChannel *: EmptyTuple, V], + stride: Stride3[S1, S2, S3] | Int = 1, + padding: Padding = Padding.SAME + ): Tensor[S1 *: S2 *: S3 *: InChannel *: EmptyTuple, V] = + require( + input.shape(Axis[OutChannel]) == kernel.shape(Axis[OutChannel]), + s"Input channels mismatch: input has ${input.shape(Axis[OutChannel])} channels (OutChannel), kernel expects ${kernel.shape(Axis[OutChannel])}" + ) + + val strides = stride match + case s: Int => Seq(s, s, s) + case (ae1, ae2, ae3) => Seq(ae1.size, ae2.size, ae3.size) + + // kernel -> kernel adjoint: swap in/out channels and flip all spatial dims + var kernelAdjoint = kernel.swap(Axis[InChannel], Axis[OutChannel]).jaxValue + kernelAdjoint = Jax.jnp.flip(kernelAdjoint, axis = 0) // flip S1 (Depth) + kernelAdjoint = Jax.jnp.flip(kernelAdjoint, axis = 1) // flip S2 (Height) + kernelAdjoint = Jax.jnp.flip(kernelAdjoint, axis = 2) // flip S3 (Width) + + val batchInput = Jax.jnp.expand_dims(input.jaxValue, axis = 0) // add dummy dim + val convResult = Jax.lax.conv_transpose( + lhs = batchInput, + rhs = kernelAdjoint, + strides = strides.toPythonProxy, + padding = padding.toString, + dimension_numbers = py.Dynamic.global.tuple(Seq("NDHWC", "DHWIO", "NDHWC").toPythonProxy) + ) + val unbatchedRes = 
Jax.jnp.squeeze(convResult, axis = 0) // remove dummy dim + Tensor(unbatchedRes) diff --git a/core/src/main/scala/dimwit/tensor/tensorops/ElementWiseOps.scala b/core/src/main/scala/dimwit/tensor/tensorops/ElementWiseOps.scala new file mode 100644 index 00000000..a7ecdbdd --- /dev/null +++ b/core/src/main/scala/dimwit/tensor/tensorops/ElementWiseOps.scala @@ -0,0 +1,124 @@ +package dimwit.tensor.tensorops + +import dimwit.tensor.Tensor +import dimwit.tensor.Labels +import dimwit.jax.Jax +import dimwit.tensor.DType.Bool +import dimwit.tensor.Tensor0 +import dimwit.tensor.TensorOps.IsBoolean +import dimwit.tensor.VType +import dimwit.tensor.DType.Int32 +import dimwit.tensor.DType.Float32 +import dimwit.tensor.TensorOps.IsInteger +import dimwit.tensor.TensorOps.IsFloating +import dimwit.tensor.TensorOps.IsNumber +import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast + +object ElementWiseOps: + // --------------------------------------------------------- + // General operations + // --------------------------------------------------------- + + /** Elementwise maximum of two tensors. */ + def maximum[T <: Tuple: Labels, V](t1: Tensor[T, V], t2: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.maximum(t1.jaxValue, t2.jaxValue)) + + /** Elementwise minimum of two tensors. 
*/ + def minimum[T <: Tuple: Labels, V](t1: Tensor[T, V], t2: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.minimum(t1.jaxValue, t2.jaxValue)) + + extension [T <: Tuple: Labels, V](t: Tensor[T, V]) + + def <(other: Tensor[T, V]): Tensor[T, Bool] = Tensor(Jax.jnp.less(t.jaxValue, other.jaxValue)) + def <=(other: Tensor[T, V]): Tensor[T, Bool] = Tensor(Jax.jnp.less_equal(t.jaxValue, other.jaxValue)) + def >(other: Tensor[T, V]): Tensor[T, Bool] = Tensor(Jax.jnp.greater(t.jaxValue, other.jaxValue)) + def >=(other: Tensor[T, V]): Tensor[T, Bool] = Tensor(Jax.jnp.greater_equal(t.jaxValue, other.jaxValue)) + + /** Checks full array equality, returns true if all elements are equal */ + def ===(other: Tensor[T, V]): Tensor0[Bool] = Tensor0(Jax.jnp.array_equal(t.jaxValue, other.jaxValue)) + + /** Elementwise equality, returns a tensor of bools indicating which elements are equal */ + def elementEquals(other: Tensor[T, V]): Tensor[T, Bool] = + require(t.shape.dimensions == other.shape.dimensions, s"Shape mismatch: ${t.shape.dimensions} vs ${other.shape.dimensions}") + Tensor(jaxValue = Jax.jnp.equal(t.jaxValue, other.jaxValue)) + + def asBool: Tensor[T, Bool] = t.asType(VType[Bool]) + def asBoolean[NewV: IsBoolean](vtype: VType[NewV]): Tensor[T, NewV] = t.asType(vtype) + def asInt32: Tensor[T, Int32] = t.asType(VType[Int32]) + def asInt[NewV: IsInteger](vtype: VType[NewV]): Tensor[T, NewV] = t.asType(vtype) + def asFloat32: Tensor[T, Float32] = t.asType(VType[Float32]) + def asFloat[NewV: IsFloating](vtype: VType[NewV]): Tensor[T, NewV] = t.asType(vtype) + + // --------------------------------------------------------- + // IsNumber operations (IsFloat or IsInt) + // --------------------------------------------------------- + + def add[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], t2: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.add(t1.jaxValue, t2.jaxValue)) + def addScalar[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], t2: Tensor0[V]): Tensor[T, V] = 
Tensor(Jax.jnp.add(t1.jaxValue, t2.jaxValue)) + + def negate[T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.negative(t.jaxValue)) + def subtract[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], t2: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.subtract(t1.jaxValue, t2.jaxValue)) + def subtractScalar[T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V], t2: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.subtract(t.jaxValue, t2.jaxValue)) + + def multiply[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], t2: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.multiply(t1.jaxValue, t2.jaxValue)) + def multiplyScalar[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], t2: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.multiply(t1.jaxValue, t2.jaxValue)) + + extension [T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V]) + + def +(other: Tensor[T, V]): Tensor[T, V] = add(t, other) + def -(other: Tensor[T, V]): Tensor[T, V] = subtract(t, other) + def *(other: Tensor[T, V]): Tensor[T, V] = multiply(t, other) + + extension [T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V]) + + def +![O <: Tuple](other: Tensor[O, V])(using bc: Broadcast[T, O, V]): Tensor[bc.Out, V] = bc.applyTo(t, other)(add) + + def unary_- : Tensor[T, V] = negate(t) + def -![O <: Tuple](other: Tensor[O, V])(using bc: Broadcast[T, O, V]): Tensor[bc.Out, V] = bc.applyTo(t, other)(subtract) + + def *![O <: Tuple](other: Tensor[O, V])(using bc: Broadcast[T, O, V]): Tensor[bc.Out, V] = bc.applyTo(t, other)(multiply) + def scale(other: Tensor0[V]): Tensor[T, V] = multiplyScalar(t, other) + + def abs: Tensor[T, V] = Tensor(Jax.jnp.abs(t.jaxValue)) + def sign: Tensor[T, V] = Tensor(Jax.jnp.sign(t.jaxValue)) + def clip(min: Tensor0[V], max: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.clip(t.jaxValue, min.jaxValue, max.jaxValue)) + def pow(n: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.power(t.jaxValue, n.jaxValue)) + + // --------------------------------------------------------- + // IsFloat operations 
+  // ---------------------------------------------------------
+
+  def divide[T <: Tuple: Labels, V: IsFloating](t1: Tensor[T, V], t2: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.divide(t1.jaxValue, t2.jaxValue))
+  def divideScalar[T <: Tuple: Labels, V: IsFloating](t1: Tensor[T, V], t2: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.divide(t1.jaxValue, t2.jaxValue))
+
+  extension [T <: Tuple: Labels, V: IsFloating](t: Tensor[T, V])
+
+    def /(other: Tensor[T, V]): Tensor[T, V] = divide(t, other)
+    def /![O <: Tuple](other: Tensor[O, V])(using join: Broadcast[T, O, V]): Tensor[join.Out, V] = join.applyTo(t, other)(divide)
+
+    def sqrt: Tensor[T, V] = Tensor(Jax.jnp.sqrt(t.jaxValue))
+    def exp: Tensor[T, V] = Tensor(Jax.jnp.exp(t.jaxValue))
+    def log: Tensor[T, V] = Tensor(Jax.jnp.log(t.jaxValue))
+    def sin: Tensor[T, V] = Tensor(Jax.jnp.sin(t.jaxValue))
+    def cos: Tensor[T, V] = Tensor(Jax.jnp.cos(t.jaxValue))
+    def tanh: Tensor[T, V] = Tensor(Jax.jnp.tanh(t.jaxValue))
+
+    def approxEquals(other: Tensor[T, V], tolerance: Float = 1e-6f): Tensor0[Bool] = approxElementEquals(other, tolerance).all // scalar: true iff every element is within tolerance
+    def approxElementEquals(other: Tensor[T, V], tolerance: Float = 1e-6f): Tensor[T, Bool] = // per-element closeness mask with the same shape as `t`
+      Tensor(
+        Jax.jnp.isclose( // element-wise test; jnp.allclose would reduce to a single scalar and break the declared Tensor[T, Bool] shape
+          t.jaxValue,
+          other.jaxValue,
+          atol = tolerance, // `tolerance` is used as the absolute bound ...
+          rtol = tolerance // ... and as the relative bound
+        )
+      )
+
+  // ---------------------------------------------------------
+  // IsBoolean operations
+  // ---------------------------------------------------------
+
+  extension [T <: Tuple: Labels, V: IsBoolean](t: Tensor[T, V])
+
+    def all: Tensor0[V] = Tensor0(Jax.jnp.all(t.jaxValue))
+    def any: Tensor0[V] = Tensor0(Jax.jnp.any(t.jaxValue))
+
+    def unary_!
: Tensor[T, V] = Tensor(Jax.jnp.logical_not(t.jaxValue)) diff --git a/core/src/main/scala/dimwit/tensor/tensorops/FunctionalOps.scala b/core/src/main/scala/dimwit/tensor/tensorops/FunctionalOps.scala new file mode 100644 index 00000000..39e0bff9 --- /dev/null +++ b/core/src/main/scala/dimwit/tensor/tensorops/FunctionalOps.scala @@ -0,0 +1,153 @@ +package dimwit.tensor.tensorops + +import dimwit.tensor.Tensor +import dimwit.tensor.Labels +import dimwit.jax.Jax +import dimwit.tensor.DType.Bool +import dimwit.tensor.Tensor0 +import dimwit.tensor.TensorOps.IsBoolean +import dimwit.tensor.VType +import dimwit.tensor.DType.Int32 +import dimwit.tensor.DType.Float32 +import dimwit.tensor.TensorOps.IsInteger +import dimwit.tensor.TensorOps.IsFloating +import dimwit.tensor.TensorOps.IsNumber +import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast +import dimwit.tensor.Label +import dimwit.tensor.ShapeTypeHelpers.AxisRemover +import dimwit.tensor.ShapeTypeHelpers.AxesRemover +import dimwit.tensor.Axis +import dimwit.tensor.ShapeTypeHelpers.UnwrapAxes +import dimwit.tensor.ShapeTypeHelpers.AxisIndex +import dimwit.tensor.ShapeTypeHelpers.AxisIndices + +import me.shadaj.scalapy.py +import me.shadaj.scalapy.py.SeqConverters +import me.shadaj.scalapy.readwrite.{Reader, Writer} +import dimwit.tensor.ShapeTypeHelpers.SharedAxisRemover +import dimwit.OnError +import dimwit.tensor.LabelsImpl +import dimwit.tensor.ShapeTypeHelpers.AxisReplacer +import dimwit.tensor.tensorops.FunctionalOps.ZipVmap.TensorsOf + +object FunctionalOps: + // ----------------------------------------------------------- + // 5. 
Functional Operations (Higher Order) + // Lifting functions over axes + // ----------------------------------------------------------- + + object ZipVmap: + + type TensorsOf[Shapes <: Tuple, Values <: Tuple] <: Tuple = (Shapes, Values) match + case (EmptyTuple, EmptyTuple) => EmptyTuple + case ((shapeHead *: shapeTail), (valueHead *: valueTail)) => Tensor[shapeHead, valueHead] *: TensorsOf[shapeTail, valueTail] + + type ExtractShape[T] = T match + case Tensor[s, v] => s + + type ExtractValue[T] = T match + case Tensor[s, v] => v + + type ShapesOf[Tensors <: Tuple] = Tuple.Map[Tensors, ExtractShape] + type ValuesOf[Tensors <: Tuple] = Tuple.Map[Tensors, ExtractValue] + + def zipvmap[L: Label, Inputs <: Tuple, OutShape <: Tuple: Labels, OutV]( + axis: Axis[L] + )( + tensors: Inputs // This is a Tuple of Tensors + )(using + ev: SharedAxisRemover[ShapesOf[Inputs], L] + )( + f: TensorsOf[ev.RemainingAxes, ValuesOf[Inputs]] => Tensor[OutShape, OutV] + ): Tensor[L *: OutShape, OutV] = + val fpy = (args: py.Dynamic) => + OnError.traceStack: + val tensorList = args.as[Seq[py.Dynamic]].zip(ev.shapesLabels).map: (jaxArr, labels) => + Tensor(jaxArr)(using LabelsImpl(labels)) + + val inputTuple = Tuple.fromArray(tensorList.toArray) + val result = f(inputTuple.asInstanceOf[TensorsOf[ev.RemainingAxes, ValuesOf[Inputs]]]) + result.jaxValue + + val jaxInputs = py.Dynamic.global.tuple(tensors.toArray.map(_.asInstanceOf[Tensor[?, ?]].jaxValue).toPythonProxy) + val indicesAsTuple = py.Dynamic.global.tuple(ev.indices.toPythonProxy) + val jaxResult = Jax.jax_helper.zipvmap( + fpy, + indicesAsTuple + )(jaxInputs) + + Tensor(jaxResult) + + export ZipVmap.zipvmap + + extension [T <: Tuple: Labels, V](t: Tensor[T, V]) + + def vmap[VmapAxis: Label, OuterShape <: Tuple: Labels, V2]( + axis: Axis[VmapAxis] + )(using + ev: AxisRemover[T, VmapAxis] + )( + f: Tensor[ev.RemainingAxes, V] => Tensor[OuterShape, V2] + )(using + labels: Labels[ev.RemainingAxes] + ): Tensor[VmapAxis *: OuterShape, V2] 
= + val fpy = (jxpr: Jax.PyDynamic) => + OnError.traceStack: + val innerTensor = Tensor[ev.RemainingAxes, V](jxpr) + val result = f(innerTensor) + result.jaxValue + + Tensor(Jax.jax_helper.vmap(fpy, ev.index)(t.jaxValue)) + + def vapply[L: Label, NewL, R <: Tuple, NewV]( + axis: Axis[L] + )( + f: Tensor[Tuple1[L], V] => Tensor[Tuple1[NewL], NewV] + )(using + ev: AxisReplacer.Aux[T, L, NewL, R], + labels: Labels[R] + ): Tensor[R, NewV] = + val fpy = (jxpr: Jax.PyDynamic) => + OnError.traceStack: + val inputTensor = Tensor[Tuple1[L], V](jxpr) + val result = f(inputTensor) + result.jaxValue + + Tensor( + Jax.jnp.apply_along_axis( + fpy, + ev.index, + t.jaxValue + ) + ) + + def zipvmap[L: Label, T2 <: Tuple, OutShape <: Tuple: Labels, OutV](axis: Axis[L])( + other: Tensor[T2, V] + )(using + ev: SharedAxisRemover[(T, T2), L] + )( + f: TensorsOf[ev.RemainingAxes, (V, V)] => Tensor[OutShape, OutV] + ): Tensor[L *: OutShape, OutV] = + ZipVmap.zipvmap(axis)(t, other)(f) + + def vreduce[L: Label]( + axis: Axis[L] + )( + f: Tensor[Tuple1[L], V] => Tensor0[V] + )(using + ev: AxisRemover[T, L], + labels: Labels[ev.RemainingAxes] + ): Tensor[ev.RemainingAxes, V] = + val fpy = (jxpr: Jax.PyDynamic) => + OnError.traceStack: + val inputTensor = Tensor[Tuple1[L], V](jxpr) + val result = f(inputTensor) + result.jaxValue + + Tensor( + Jax.jnp.apply_along_axis( + fpy, + ev.index, + t.jaxValue + ) + ) diff --git a/core/src/main/scala/dimwit/tensor/tensorops/LinearAlgebraOps.scala b/core/src/main/scala/dimwit/tensor/tensorops/LinearAlgebraOps.scala new file mode 100644 index 00000000..d6fb6104 --- /dev/null +++ b/core/src/main/scala/dimwit/tensor/tensorops/LinearAlgebraOps.scala @@ -0,0 +1,83 @@ +package dimwit.tensor.tensorops + +import dimwit.tensor.Tensor +import dimwit.tensor.Labels +import dimwit.jax.Jax +import dimwit.tensor.DType.Bool +import dimwit.tensor.Tensor0 +import dimwit.tensor.TensorOps.IsBoolean +import dimwit.tensor.VType +import dimwit.tensor.DType.Int32 +import 
dimwit.tensor.DType.Float32
+import dimwit.tensor.TensorOps.IsInteger
+import dimwit.tensor.TensorOps.IsFloating
+import dimwit.tensor.TensorOps.IsNumber
+import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast
+import dimwit.tensor.Label
+import dimwit.tensor.ShapeTypeHelpers.AxisRemover
+import dimwit.tensor.ShapeTypeHelpers.AxesRemover
+import dimwit.tensor.Axis
+import dimwit.tensor.ShapeTypeHelpers.UnwrapAxes
+import dimwit.tensor.ShapeTypeHelpers.AxisIndex
+import dimwit.tensor.ShapeTypeHelpers.AxisIndices
+
+import me.shadaj.scalapy.py
+import me.shadaj.scalapy.py.SeqConverters
+import me.shadaj.scalapy.readwrite.{Reader, Writer}
+import dimwit.tensor.Tensor2
+import dimwit.tensor.Tensor1
+
+/** Linear-algebra wrappers around `jnp.diagonal`, `jnp.trace` and `jnp.linalg`. */
+object LinearAlgebraOps:
+
+  extension [T <: Tuple: Labels, V](t: Tensor[T, V])
+
+    /** Extracts the diagonal of the planes spanned by `axis1` and `axis2` via `jnp.diagonal`.
+      *
+      * NOTE(review): `*:` is right-associative, so the declared shape `ev.RemainingAxes *: L1 *: EmptyTuple`
+      * is a two-element tuple whose head is the tuple type `ev.RemainingAxes` itself. Presumably
+      * "remaining axes followed by L1" was intended - confirm.
+      *
+      * @param offset diagonal offset forwarded to `jnp.diagonal`
+      */
+    def diagonal[L1: Label, L2: Label](axis1: Axis[L1], axis2: Axis[L2], offset: Int = 0)(using
+        ev: AxesRemover[T, (L1, L2)],
+        labels: Labels[ev.RemainingAxes]
+    ): Tensor[ev.RemainingAxes *: L1 *: EmptyTuple, V] =
+      Tensor(Jax.jnp.diagonal(t.jaxValue, offset = offset, axis1 = ev.indices(0), axis2 = ev.indices(1)))
+
+  extension [L1: Label, L2: Label, V](t: Tensor2[L1, L2, V])
+
+    /** Main diagonal of a matrix (offset 0). */
+    def diagonal: Tensor1[L1, V] = t.diagonal(0)
+
+    /** Diagonal of a matrix at the given offset; the result keeps the label `L1`. */
+    def diagonal(offset: Int): Tensor1[L1, V] = Tensor(Jax.jnp.diagonal(t.jaxValue, offset = offset))
+
+  // ---------------------------------------------------------
+  // IsNumber operations (IsFloat or IsInt)
+  // ---------------------------------------------------------
+
+  extension [T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V])
+
+    /** Sums the diagonal of the planes spanned by `axis1` and `axis2` via `jnp.trace`; both axes are removed from the shape. */
+    def trace[L1: Label, L2: Label](axis1: Axis[L1], axis2: Axis[L2], offset: Int = 0)(using
+        ev: AxesRemover[T, (L1, L2)],
+        labels: Labels[ev.RemainingAxes]
+    ): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.trace(t.jaxValue, offset = offset, axis1 = ev.indices(0), axis2 = ev.indices(1)))
+
+  extension [L1: Label, L2: Label, V: IsNumber](t: Tensor2[L1, L2, V])
+
+    /** Trace of a matrix (offset 0). */
+    def trace: Tensor0[V] = t.trace(0)
+
+    /** Trace of a matrix along the diagonal at the given offset. */
+    def trace(offset: Int): Tensor0[V] = Tensor0(Jax.jnp.trace(t.jaxValue, offset = offset))
+
+  // ---------------------------------------------------------
+  // IsFloat operations
+  // ---------------------------------------------------------
+
+  extension [T <: Tuple: Labels, V: IsFloating](t: Tensor[T, V])
+
+    /** `jnp.linalg.norm` of the whole tensor, called without `ord` or `axis`. */
+    def norm: Tensor0[V] = Tensor0(Jax.jnp.linalg.norm(t.jaxValue))
+
+    /** `jnp.linalg.inv` of the tensor; the shape type is unchanged. */
+    def inv: Tensor[T, V] = Tensor(Jax.jnp.linalg.inv(t.jaxValue))
+
+    /** Determinant over the matrices spanned by `axis1` and `axis2`; both axes are removed from the shape. */
+    def det[L1: Label, L2: Label](axis1: Axis[L1], axis2: Axis[L2])(using
+        ev: AxesRemover[T, (L1, L2)],
+        labels: Labels[ev.RemainingAxes]
+    ): Tensor[ev.RemainingAxes, V] =
+      // JAX det only works on the last two axes (-2, -1). We must move the user's selected axes to the end.
+      val moved = Jax.jnp.moveaxis(
+        t.jaxValue,
+        source = ev.indices.toPythonProxy,
+        destination = Seq(-2, -1).toPythonProxy
+      )
+      Tensor(Jax.jnp.linalg.det(moved))
+
+  extension [L1: Label, L2: Label, V: IsFloating](t: Tensor2[L1, L2, V])
+
+    /** Determinant of a matrix. */
+    def det: Tensor0[V] = Tensor0(Jax.jnp.linalg.det(t.jaxValue))
diff --git a/core/src/main/scala/dimwit/tensor/tensorops/ReductionOps.scala b/core/src/main/scala/dimwit/tensor/tensorops/ReductionOps.scala
new file mode 100644
index 00000000..0056abde
--- /dev/null
+++ b/core/src/main/scala/dimwit/tensor/tensorops/ReductionOps.scala
@@ -0,0 +1,90 @@
+package dimwit.tensor.tensorops
+
+import dimwit.tensor.Tensor
+import dimwit.tensor.Labels
+import dimwit.jax.Jax
+import dimwit.tensor.DType.Bool
+import dimwit.tensor.Tensor0
+import dimwit.tensor.TensorOps.IsBoolean
+import dimwit.tensor.VType
+import dimwit.tensor.DType.Int32
+import dimwit.tensor.DType.Float32
+import dimwit.tensor.TensorOps.IsInteger
+import dimwit.tensor.TensorOps.IsFloating
+import dimwit.tensor.TensorOps.IsNumber
+import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast
+import dimwit.tensor.Label
+import dimwit.tensor.ShapeTypeHelpers.AxisRemover
+import dimwit.tensor.ShapeTypeHelpers.AxesRemover
+import dimwit.tensor.Axis
+import dimwit.tensor.ShapeTypeHelpers.UnwrapAxes
+import
dimwit.tensor.ShapeTypeHelpers.AxisIndex
+import dimwit.tensor.ShapeTypeHelpers.AxisIndices
+
+import me.shadaj.scalapy.py
+import me.shadaj.scalapy.py.SeqConverters
+import me.shadaj.scalapy.readwrite.{Reader, Writer}
+
+/** Reductions over all axes, one axis, or a tuple of axes.
+  *
+  * Each operation comes in three forms: no argument (reduce everything to a `Tensor0`),
+  * a single `Axis[L]`, or a tuple of axes whose positions are passed to JAX as `axis`.
+  */
+object ReductionOps:
+
+  // ---------------------------------------------------------
+  // IsNumber operations (IsFloat or IsInt)
+  // ---------------------------------------------------------
+
+  extension [T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V])
+
+    // --- Sum ---
+    def sum: Tensor0[V] = Tensor0(Jax.jnp.sum(t.jaxValue))
+    def sum[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.sum(t.jaxValue, axis = ev.index))
+    def sum[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.sum(t.jaxValue, axis = ev.indices.toPythonProxy))
+
+    // --- Max ---
+    def max: Tensor0[V] = Tensor0(Jax.jnp.max(t.jaxValue))
+    def max[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.max(t.jaxValue, axis = ev.index))
+    def max[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.max(t.jaxValue, axis = ev.indices.toPythonProxy))
+
+    // --- Min ---
+    def min: Tensor0[V] = Tensor0(Jax.jnp.min(t.jaxValue))
+    def min[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.min(t.jaxValue, axis = ev.index))
+    def min[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.min(t.jaxValue, axis = ev.indices.toPythonProxy))
+
+    // --- Argmax ---
+    // NOTE(review): the tuple-of-axes overload passes a sequence as `axis`; jax.numpy.argmax documents
+    // `axis` as a single int or None - confirm this overload works at runtime.
+    def argmax: Tensor0[Int32] = Tensor0(Jax.jnp.argmax(t.jaxValue))
+    def argmax[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, Int32] = Tensor(Jax.jnp.argmax(t.jaxValue, axis = ev.index))
+    def argmax[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, Int32] = Tensor(Jax.jnp.argmax(t.jaxValue, axis = ev.indices.toPythonProxy))
+
+    // --- Argmin ---
+    // NOTE(review): same concern as argmax for the tuple-of-axes overload.
+    def argmin: Tensor0[Int32] = Tensor0(Jax.jnp.argmin(t.jaxValue))
+    def argmin[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, Int32] = Tensor(Jax.jnp.argmin(t.jaxValue, axis = ev.index))
+    def argmin[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, Int32] = Tensor(Jax.jnp.argmin(t.jaxValue, axis = ev.indices.toPythonProxy))
+
+    // --- Argsort ---
+    // The shape type is unchanged. NOTE(review): the no-argument form calls jnp.argsort without `axis`
+    // (presumably sorting along the last axis), and the tuple form passes a sequence as `axis` - confirm both.
+    def argsort: Tensor[T, Int32] = Tensor(Jax.jnp.argsort(t.jaxValue))
+    def argsort[L: Label](axis: Axis[L])(using ev: AxisIndex[T, L]): Tensor[T, Int32] = Tensor(Jax.jnp.argsort(t.jaxValue, axis = ev.index))
+    def argsort[Inputs <: Tuple](axes: Inputs)(using ev: AxisIndices[T, UnwrapAxes[Inputs]]): Tensor[T, Int32] = Tensor(Jax.jnp.argsort(t.jaxValue, axis = ev.indices.toPythonProxy))
+
+  // ---------------------------------------------------------
+  // IsFloating operations
+  // ---------------------------------------------------------
+
+  extension [T <: Tuple: Labels, V: IsFloating](t: Tensor[T, V])
+
+    // --- Mean ---
+    def mean: Tensor0[V] = Tensor0(Jax.jnp.mean(t.jaxValue))
+    def mean[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.mean(t.jaxValue, axis = ev.index))
+    def mean[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.mean(t.jaxValue, axis = ev.indices.toPythonProxy))
+
+    // --- Std ---
+    def std: Tensor0[V] = Tensor0(Jax.jnp.std(t.jaxValue))
+    def std[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.std(t.jaxValue, axis = ev.index))
+    def std[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.std(t.jaxValue, axis = ev.indices.toPythonProxy))
+
+    // --- Quantile ---
+    // `q` is forwarded unchanged as the second positional argument of jnp.quantile.
+    def quantile(q: Float): Tensor0[V] = Tensor0(Jax.jnp.quantile(t.jaxValue, q))
+    def quantile[L: Label](q: Float, axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.quantile(t.jaxValue, q, axis = ev.index))
+    def quantile[Inputs <: Tuple](q: Float, axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.quantile(t.jaxValue, q, axis = ev.indices.toPythonProxy))
+
+    // --- Median ---
+    def median: Tensor0[V] = Tensor0(Jax.jnp.median(t.jaxValue))
+    def median[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.median(t.jaxValue, axis = ev.index))
+    def median[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.median(t.jaxValue, axis = ev.indices.toPythonProxy))
diff --git a/core/src/main/scala/dimwit/tensor/tensorops/StructuralOps.scala b/core/src/main/scala/dimwit/tensor/tensorops/StructuralOps.scala
new file mode 100644
index 00000000..de0abcbd
--- /dev/null
+++ b/core/src/main/scala/dimwit/tensor/tensorops/StructuralOps.scala
@@ -0,0 +1,804 @@
+package dimwit.tensor.tensorops
+
+import dimwit.tensor.Tensor
+import dimwit.tensor.Labels
+import dimwit.jax.Jax
+import dimwit.tensor.DType.Bool
+import dimwit.tensor.Tensor0
+import dimwit.tensor.TensorOps.IsBoolean
+import dimwit.tensor.VType
+import
dimwit.tensor.DType.Int32
+import dimwit.tensor.DType.Float32
+import dimwit.tensor.TensorOps.IsInteger
+import dimwit.tensor.TensorOps.IsFloating
+import dimwit.tensor.TensorOps.IsNumber
+import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast
+import dimwit.tensor.Label
+import dimwit.tensor.ShapeTypeHelpers.AxisRemover
+import dimwit.tensor.ShapeTypeHelpers.AxesRemover
+import dimwit.tensor.Axis
+import dimwit.tensor.ShapeTypeHelpers.UnwrapAxes
+import dimwit.tensor.ShapeTypeHelpers.AxisIndex
+import dimwit.tensor.ShapeTypeHelpers.AxisIndices
+import dimwit.{`|*|`, `|+|`}
+
+import me.shadaj.scalapy.py
+import me.shadaj.scalapy.py.SeqConverters
+import me.shadaj.scalapy.readwrite.{Reader, Writer}
+import dimwit.tensor.AxisAtIndex
+import dimwit.tensor.AxisAtRange
+import dimwit.tensor.AxisAtIndices
+import dimwit.tensor.AxisAtTupleIndices
+import dimwit.tensor.AxisAtTensorIndex
+import scala.util.NotGiven
+import scala.annotation.implicitNotFound
+import dimwit.tensor.TupleHelpers
+import dimwit.tensor.ShapeTypeHelpers.DimExtractor
+import dimwit.tensor.AxisExtent
+import dimwit.tensor.Tensor1
+import dimwit.tensor.ShapeTypeHelpers.MergeLabels
+import dimwit.tensor.ShapeTypeHelpers.AxesMerger
+import dimwit.tensor.Shape
+import dimwit.tensor.ShapeTypeHelpers.AxisReplacerAll
+import dimwit.tensor.TupleHelpers.TensorEvidence.IsPermutation
+import dimwit.tensor.TupleHelpers.TensorEvidence.ValidationResult
+import dimwit.tensor.ShapeTypeHelpers.AxesConditionalRemover
+import dimwit.tensor.TupleHelpers.TensorEvidence.ComputeMissing
+import dimwit.tensor.TupleHelpers.TensorEvidence.CheckValid
+import dimwit.tensor.ShapeTypeHelpers.UnwrapDims
+
+import dimwit.jax.Einops
+import dimwit.tensor.TupleHelpers.StrictSubset
+import dimwit.tensor.ShapeTypeHelpers.AxisReplacer
+
+// -----------------------------------------------------------
+// 4. Structural Operations (Isomorphisms)
+// Permutations and Views: T1 -> T2 (Size(T1) == Size(T2))
+// -----------------------------------------------------------
+
+object StructuralOps:
+
+  /** Type-level helpers used by the signatures in this object. */
+  private object Util:
+
+    /** Inserts `B` directly before the first `A` in `T`; appends `B` when `A` does not occur. */
+    type InsertBefore[T <: Tuple, A, B] <: Tuple = T match
+      case EmptyTuple => B *: EmptyTuple
+      case A *: tail => B *: A *: tail
+      case h *: tail => h *: InsertBefore[tail, A, B]
+
+    /** Inserts `B` directly after the first `A` in `T`; appends `B` when `A` does not occur. */
+    type InsertAfter[T <: Tuple, A, B] <: Tuple = T match
+      case EmptyTuple => B *: EmptyTuple
+      case A *: tail => A *: B *: tail
+      case h *: tail => h *: InsertAfter[tail, A, B]
+
+    type SliceIndex = Int | List[Int] | Range | Tensor0[Int32]
+
+    /** Label carried by an axis selector. */
+    type ExtractLabel[X] = X match
+      case AxisAtIndex[l] => l
+      case AxisAtRange[l] => l
+      case AxisAtIndices[l] => l
+      case AxisAtTupleIndices[l, ?] => l
+      case AxisAtTensorIndex[l] => l
+    type ExtractLabels[Inputs <: Tuple] = Tuple.Map[Inputs, ExtractLabel]
+
+    /** Evidence relating a tuple of slice inputs to the labels `Out`.
+      *
+      * As encoded by the givens below, selectors that pick a single position (an `Int` index or a
+      * `Tensor0[Int32]` index) contribute their label to `Out`; range and index-list selectors contribute nothing.
+      */
+    trait SliceLabelExtractor[Inputs <: Tuple, Out <: Tuple]
+
+    object SliceLabelExtractor:
+
+      given empty: SliceLabelExtractor[EmptyTuple, EmptyTuple] =
+        new SliceLabelExtractor[EmptyTuple, EmptyTuple] {}
+
+      // New givens for AxisSelector types
+      given consAxisAtIndex[L, Tail <: Tuple, TailOut <: Tuple](using
+          tailExt: SliceLabelExtractor[Tail, TailOut]
+      ): SliceLabelExtractor[AxisAtIndex[L] *: Tail, L *: TailOut] =
+        new SliceLabelExtractor[AxisAtIndex[L] *: Tail, L *: TailOut] {}
+
+      given consAxisAtRange[L, Tail <: Tuple, TailOut <: Tuple](using
+          tailExt: SliceLabelExtractor[Tail, TailOut]
+      ): SliceLabelExtractor[AxisAtRange[L] *: Tail, TailOut] =
+        new SliceLabelExtractor[AxisAtRange[L] *: Tail, TailOut] {}
+
+      given consAxisAtIndices[L, Tail <: Tuple, TailOut <: Tuple](using
+          tailExt: SliceLabelExtractor[Tail, TailOut]
+      ): SliceLabelExtractor[AxisAtIndices[L] *: Tail, TailOut] =
+        new SliceLabelExtractor[AxisAtIndices[L] *: Tail, TailOut] {}
+
+      given consAxisAtTupleIndices[L, I <: NonEmptyTuple, Tail <: Tuple, TailOut <: Tuple](using
+          tailExt: SliceLabelExtractor[Tail, TailOut]
+      ): SliceLabelExtractor[AxisAtTupleIndices[L, I] *: Tail, TailOut] =
+        new SliceLabelExtractor[AxisAtTupleIndices[L, I] *: Tail, TailOut] {}
+
+      given consAxisAtTensorIndex[L, Tail <: Tuple, TailOut <: Tuple](using
+          tailExt: SliceLabelExtractor[Tail, TailOut]
+      ): SliceLabelExtractor[AxisAtTensorIndex[L] *: Tail, L *: TailOut] =
+        new SliceLabelExtractor[AxisAtTensorIndex[L] *: Tail, L *: TailOut] {}
+
+      // Keep backward compatibility with tuple syntax
+      given consInt[L, Tail <: Tuple, TailOut <: Tuple](using
+          tailExt: SliceLabelExtractor[Tail, TailOut]
+      ): SliceLabelExtractor[(Axis[L], Int) *: Tail, L *: TailOut] =
+        new SliceLabelExtractor[(Axis[L], Int) *: Tail, L *: TailOut] {}
+
+      given consTensor0Int[L, Tail <: Tuple, TailOut <: Tuple](using
+          tailExt: SliceLabelExtractor[Tail, TailOut]
+      ): SliceLabelExtractor[(Axis[L], Tensor0[Int32]) *: Tail, L *: TailOut] =
+        new SliceLabelExtractor[(Axis[L], Tensor0[Int32]) *: Tail, L *: TailOut] {}
+
+      given consSeq[L, SeqT <: Seq[Int], Tail <: Tuple, TailOut <: Tuple](using
+          tailExt: SliceLabelExtractor[Tail, TailOut]
+      ): SliceLabelExtractor[(Axis[L], SeqT) *: Tail, TailOut] =
+        new SliceLabelExtractor[(Axis[L], SeqT) *: Tail, TailOut] {}
+
+    /** Exchanges every `A` with `B` and every `B` with `A` in `T`. */
+    type Swap[T <: Tuple, A, B] <: Tuple = T match
+      case EmptyTuple => EmptyTuple
+      case A *: tail => B *: Swap[tail, A, B]
+      case B *: tail => A *: Swap[tail, A, B]
+      case h *: tail => h *: Swap[tail, A, B]
+
+    /** Evidence that label `L` does not occur in shape `T`. */
+    @implicitNotFound("The axis ${L} is already present in the tensor shape ${T}.")
+    trait AxisAbsent[T, L]
+    object AxisAbsent:
+      given [T <: Tuple, L](using NotGiven[Tuple.Contains[T, L] =:= true]): AxisAbsent[T, L] = new AxisAbsent[T, L] {}
+
+  import Util.*
+
+  object TensorWhere:
+    /** Element-wise selection via `jnp.where`; all three tensors share the shape type `T`. */
+    def where[T <: Tuple: Labels, V](
+        condition: Tensor[T, Bool],
+        x: Tensor[T, V],
+        y: Tensor[T, V]
+    ): Tensor[T, V] =
+      Tensor(Jax.jnp.where(condition.jaxValue, x.jaxValue, y.jaxValue))
+
+  export TensorWhere.where
+
+  def triu[T <: Tuple: Labels, V](tensor: Tensor[T, V],
kthDiagonal: Int = 0): Tensor[T, V] =
+    // `kthDiagonal` is forwarded to jnp.triu as `k`.
+    val upper = Jax.jnp.triu(tensor.jaxValue, k = kthDiagonal)
+    Tensor(upper)
+
+  /** Counterpart of `triu` built on `jnp.tril`; `kthDiagonal` is forwarded as `k`. */
+  def tril[T <: Tuple: Labels, V](tensor: Tensor[T, V], kthDiagonal: Int = 0): Tensor[T, V] =
+    val lower = Jax.jnp.tril(tensor.jaxValue, k = kthDiagonal)
+    Tensor(lower)
+
+  /** Stacks the tensors along a new leading axis labelled `L`.
+    *
+    * @throws IllegalArgumentException if `tensors` is empty
+    */
+  def stack[L: Label, T <: Tuple: Labels, V](
+      tensors: Seq[Tensor[T, V]],
+      newAxis: Axis[L]
+  ): Tensor[L *: T, V] =
+    require(tensors.nonEmpty, "Cannot stack an empty sequence of tensors")
+    Tensor(Jax.jnp.stack(tensors.map(_.jaxValue).toPythonProxy, axis = 0))
+
+  /** Stacks the tensors along a new axis labelled `NewL`, placed directly after `afterAxis`.
+    *
+    * @throws IllegalArgumentException if `tensors` is empty
+    */
+  def stack[NewL, L, T <: Tuple: Labels, V](
+      tensors: Seq[Tensor[T, V]],
+      newAxis: Axis[NewL],
+      afterAxis: Axis[L]
+  )(using
+      newLabel: Label[NewL],
+      axisIndex: AxisIndex[T, L]
+  ): Tensor[InsertAfter[T, L, NewL], V] =
+    require(tensors.nonEmpty, "Cannot stack an empty sequence of tensors")
+    // The new axis goes right behind `afterAxis`, hence the shift by one.
+    val insertAt = axisIndex.index + 1
+    val stacked = Jax.jnp.stack(tensors.map(_.jaxValue).toPythonProxy, axis = insertAt)
+    // Splice the new label name into the existing names at the same position.
+    val (leading, trailing) = summon[Labels[T]].names.splitAt(insertAt)
+    val relabelled = leading ++ Seq(newLabel.name) ++ trailing
+    given Labels[InsertAfter[T, L, NewL]] with
+      val names = relabelled.toSeq
+    Tensor(stacked)
+
+  /** Concatenates the tensors along the existing axis `L`; the shape type is unchanged.
+    *
+    * @throws IllegalArgumentException if `tensors` is empty
+    */
+  def concatenate[L: Label, T <: Tuple: Labels, V](
+      tensors: Seq[Tensor[T, V]],
+      concatAxis: Axis[L]
+  )(using
+      axisIndex: AxisIndex[T, L]
+  ): Tensor[T, V] =
+    require(tensors.nonEmpty, "Cannot concatenate an empty sequence of tensors")
+    val arrays = tensors.map(_.jaxValue).toPythonProxy
+    Tensor(Jax.jnp.concatenate(arrays, axis = axisIndex.index))
+
+  /** Two-tensor convenience form of `concatenate`. */
+  def concatenate[L: Label, T <: Tuple: Labels, V](
+      t1: Tensor[T, V],
+      t2: Tensor[T, V],
+      concatAxis: Axis[L]
+  )(using
+      axisIndex: AxisIndex[T, L]
+  ): Tensor[T, V] =
+    val pair = Seq(t1, t2)
+    concatenate(pair, concatAxis)
+
+  trait
ValidConcat[T1 <: Tuple, T2 <: Tuple]:
+    // Evidence that T1 and T2 can be concatenated: `Out` is the result shape, `index` the concatenation axis.
+    type Out <: Tuple
+    def index: Int
+
+  object ValidConcat:
+    type Aux[T1 <: Tuple, T2 <: Tuple, O <: Tuple] = ValidConcat[T1, T2] { type Out = O }
+
+    // Same head label on both sides: keep it and look further down the shape.
+    given recursive[H, T1Tail <: Tuple, T2Tail <: Tuple, OutTail <: Tuple](using
+        next: ValidConcat.Aux[T1Tail, T2Tail, OutTail]
+    ): ValidConcat[H *: T1Tail, H *: T2Tail] with
+      type Out = H *: OutTail
+      def index: Int = next.index + 1
+
+    // Differing head labels over an identical tail: this is the concatenation axis, labelled `H1 |+| H2`.
+    given concatAxis[H1, H2, Tail <: Tuple](using
+        isDifferent: NotGiven[H1 =:= H2]
+    ): ValidConcat[H1 *: Tail, H2 *: Tail] with
+      type Out = (H1 |+| H2) *: Tail
+      def index: Int = 0
+
+  /** Concatenates two tensors along the axis located by `ValidConcat`; the result shape is `R`. */
+  def concatenate[T1 <: Tuple, T2 <: Tuple, V, R <: Tuple](
+      t1: Tensor[T1, V],
+      t2: Tensor[T2, V]
+  )(using
+      canConcat: ValidConcat.Aux[T1, T2, R],
+      label: Labels[R]
+  ): Tensor[R, V] =
+    val jaxValues = List(t1.jaxValue, t2.jaxValue).toPythonProxy
+    Tensor(Jax.jnp.concatenate(jaxValues, axis = canConcat.index))
+
+  /** A tuple of `L` with one more element than `I` (n split points give n + 1 parts). */
+  type SplitComponents[L, I <: Tuple] <: Tuple = I match
+    case EmptyTuple => L *: EmptyTuple
+    case _ *: tail => L *: SplitComponents[L, tail]
+
+  /** Decomposes a label built with `|+|` into its component labels. */
+  trait Deconcatenator[L]:
+    type Components <: Tuple
+    def labels: List[Label[?]]
+
+  object Deconcatenator extends DeconcatenatorLowPriority:
+    type Aux[L, C <: Tuple] = Deconcatenator[L] { type Components = C }
+
+    // `A |+| B` decomposes into the components of A followed by those of B.
+    given recursive[A, B, CA <: Tuple, CB <: Tuple](using
+        da: Aux[A, CA],
+        db: Aux[B, CB]
+    ): Aux[A |+| B, Tuple.Concat[CA, CB]] =
+      new Deconcatenator[A |+| B]:
+        type Components = Tuple.Concat[CA, CB]
+        def labels = da.labels ++ db.labels
+
+  trait DeconcatenatorLowPriority:
+    // Fallback: any label with a `Label` instance is its own single component.
+    given base[L](using l: Label[L]): Deconcatenator.Aux[L, L *: EmptyTuple] =
+      new Deconcatenator[L]:
+        type Components = L *: EmptyTuple
+        def labels = List(l)
+
+  /** Builds a tuple of tensors from split arrays, one per component label.
+    *
+    * For each component the label `SplitAxis` in `FullShape` is replaced by that component.
+    */
+  trait TensorTupleMaker[Components <: Tuple, FullShape <: Tuple, SplitAxis, V]:
+    type Out <: Tuple
+    def apply(arrays: Seq[Jax.PyDynamic], compLabels: List[Label[?]], originalLabels: Seq[String], splitIndex: Int): Out
+
+  object TensorTupleMaker:
+    type Aux[C <: Tuple, F <: Tuple, S, V, O <: Tuple] =
+      TensorTupleMaker[C, F, S, V] { type Out = O }
+
+    given empty[F <: Tuple, S, V]: Aux[EmptyTuple, F, S, V, EmptyTuple] =
+      new TensorTupleMaker[EmptyTuple, F, S, V]:
+        type Out = EmptyTuple
+        def apply(a: Seq[Jax.PyDynamic], c: List[Label[?]], o: Seq[String], i: Int) = EmptyTuple
+
+    given cons[Head, Tail <: Tuple, F <: Tuple, S, V, NewShape <: Tuple](using
+        replacer: TupleHelpers.Replacer[F, S, Head] { type Out = NewShape },
+        tailMaker: TensorTupleMaker[Tail, F, S, V]
+    ): Aux[Head *: Tail, F, S, V, Tensor[NewShape, V] *: tailMaker.Out] =
+
+      new TensorTupleMaker[Head *: Tail, F, S, V]:
+        type Out = Tensor[NewShape, V] *: tailMaker.Out
+
+        def apply(arrays: Seq[Jax.PyDynamic], compLabels: List[Label[?]], originalLabels: Seq[String], splitIndex: Int): Out =
+          // `arrays` and `compLabels` are consumed in lock-step, one element per component.
+          val currentArr = arrays.head
+          val currentLabel = compLabels.head
+          // Rename only the split axis; all other label names are kept.
+          val newNames = originalLabels.updated(splitIndex, currentLabel.name).toList
+          val newLabelsWitness = new Labels[NewShape]:
+            val names = newNames
+          val headTensor = Tensor[NewShape, V](currentArr)(using newLabelsWitness)
+          headTensor *: tailMaker(arrays.tail, compLabels.tail, originalLabels, splitIndex)
+
+  extension [T <: Tuple, V](tensor: Tensor[T, V])
+
+    /** Splits a concatenated axis (label built with `|+|`) back into one tensor per component.
+      *
+      * @param axis the concatenated axis to split
+      * @param dims a tuple of `AxisExtent` values giving the size of each component, in order
+      * @return a tuple of tensors, one per component label
+      * @throws IllegalArgumentException if `dims` contains a non-`AxisExtent` element, or if the number
+      *                                  of sizes differs from the number of components
+      */
+    def deconcatenate[L, Dims <: Tuple, Comps <: Tuple, Result](
+        axis: Axis[L],
+        dims: Dims
+    )(using
+        labels: Labels[T],
+        axisIndex: AxisIndex[T, L],
+        decon: Deconcatenator.Aux[L, Comps],
+        extractor: DimExtractor[Dims],
+        maker: TensorTupleMaker[Comps, T, L, V]
+    ): maker.Out =
+      val orderedSizes = dims.toList.asInstanceOf[List[Any]].map {
+        case ae: AxisExtent[?] => ae.size
+        case _ => throw new IllegalArgumentException("Invalid dims format - expected AxisExtent")
+      }
+
+      require(orderedSizes.size == decon.labels.size, s"Provided ${orderedSizes.size} sizes but axis has ${decon.labels.size} components")
+
+      // Cumulative sizes without the leading 0 and the final total are the interior split points.
+      val splitIndices = orderedSizes.scanLeft(0)(_ + _).tail.init
+      val pyIndices = me.shadaj.scalapy.py.Dynamic.global.list(splitIndices.toPythonProxy)
+      val splitArrays = Jax.jnp.split(tensor.jaxValue, pyIndices, axis = axisIndex.index).as[Seq[Jax.PyDynamic]]
+      val originalNames = summon[Labels[T]].names.toSeq
+
+      maker.apply(splitArrays, decon.labels, originalNames, axisIndex.index)
+
+    /** Splits the tensor along the specified axis at the given indices, returning a tuple of tensors corresponding to the splits.
+      *
+      * @param selector of the form Axis[L].at((idx1, idx2, ...)) specifying the axis to split and the indices to split at
+      * @return the tuple of tensors resulting from the split
+      */
+    def split[L: Label, I <: NonEmptyTuple](selector: AxisAtTupleIndices[L, I])(using
+        axisIndex: AxisIndex[T, L],
+        maker: TensorTupleMaker[SplitComponents[L, I], T, L, V],
+        labels: Labels[T]
+    ): maker.Out =
+      val splitList = selector.indices.toList.asInstanceOf[List[Int]]
+      val pyIndices = me.shadaj.scalapy.py.Dynamic.global.list(splitList.toPythonProxy)
+      val splitArrays = Jax.jnp.split(tensor.jaxValue, pyIndices, axis = axisIndex.index).as[Seq[Jax.PyDynamic]]
+      val axisLabelInstance = summon[Label[L]]
+      // Every part keeps the original axis label; n split points give n + 1 parts.
+      val compLabels = List.fill(splitList.size + 1)(axisLabelInstance.asInstanceOf[Label[?]])
+      maker.apply(splitArrays, compLabels, labels.names.toSeq, axisIndex.index)
+
+    /** Splits the tensor along the specified axis at the given index,
+      * returning a tuple of two tensors corresponding to the splits.
+      *
+      * @param selector of the form Axis[L].at(idx) specifying the axis to split and the index to split at
+      * @return a tuple of two tensors resulting from the split
+      */
+    def split[L: Label](selector: AxisAtIndex[L])(using
+        axisIndex: AxisIndex[T, L],
+        maker: TensorTupleMaker[L *: L *: EmptyTuple, T, L, V],
+        labels: Labels[T]
+    ): maker.Out =
+      split(AxisAtTupleIndices(selector.axis, Tuple1(selector.index)))
+
+    /** Builds the Python index tuple used by `slice` and `set`.
+      *
+      * Starts from a full slice (`:`) for every dimension and overwrites the entries named in
+      * `targetDims` with the index derived from the matching element of `inputs`.
+      *
+      * @param inputs     the selectors (or legacy `(Axis, index)` pairs), in the same order as `targetDims`
+      * @param targetDims the dimension position each selector applies to
+      * @return a Python tuple with one index object per tensor dimension
+      */
+    private def calcPyIndices[Inputs <: Tuple](
+        inputs: Inputs,
+        targetDims: List[Int]
+    ) =
+
+      val PySlice = py.Dynamic.global.slice
+      val Colon = PySlice(py.None)
+      val rank = tensor.shape.rank
+      val indicesBuffer = collection.mutable.ArrayBuffer.fill[py.Any](rank)(Colon)
+
+      // Translates a Scala Range into an equivalent Python slice.
+      def toPySlice(range: Range): py.Any =
+        if range.isEmpty then PySlice(0, 0) // `head`/`last` would throw on an empty range
+        else
+          // Python's stop is exclusive: one step past the last element, in the direction of travel.
+          val pyStop: py.Any =
+            if range.step > 0 then py.Any.from(range.last + 1)
+            else if range.last == 0 then py.None // a stop of -1 would mean "last element" in Python
+            else py.Any.from(range.last - 1)
+          PySlice(py.Any.from(range.head), pyStop, py.Any.from(range.step))
+
+      val inputList = inputs.toList.asInstanceOf[List[Any]]
+
+      targetDims.zip(inputList).foreach { case (dimIndex, input) =>
+        input match
+          // New AxisSelector types
+          case AxisAtIndex(_, idx) =>
+            indicesBuffer(dimIndex) = py.Any.from(idx)
+          case AxisAtRange(_, range) =>
+            indicesBuffer(dimIndex) = toPySlice(range)
+          case AxisAtIndices(_, indices) =>
+            indicesBuffer(dimIndex) = indices.map(py.Any.from).toPythonCopy // TODO find out why Copy is needed here
+          case AxisAtTupleIndices(_, indices) =>
+            indicesBuffer(dimIndex) = indices.toList.asInstanceOf[List[Int]].map(py.Any.from).toPythonCopy
+          case AxisAtTensorIndex(_, tensorIdx) =>
+            indicesBuffer(dimIndex) = tensorIdx.jaxValue
+          // Backward compatibility with tuples
+          case (_, sliceIndex) =>
+            sliceIndex match
+              case sliceSeq: List[Int] @unchecked =>
+                indicesBuffer(dimIndex) = sliceSeq.map(py.Any.from).toPythonProxy
+              case range: Range @unchecked =>
+                indicesBuffer(dimIndex) = toPySlice(range)
+              case idx: Int =>
+                indicesBuffer(dimIndex) = py.Any.from(idx)
+              case tensorId: Tensor0[Int32] @unchecked =>
+                indicesBuffer(dimIndex) = tensorId.jaxValue
+      }
+
+      Jax.Dynamic.global.tuple(indicesBuffer.toSeq.toPythonProxy)
+
+
/** Flattens all axes of the tensor into a single axis.
+      * The resulting tensor will have a single axis named by concatenating the original axis names with "*".
+      *
+      * @return a Tensor1 with the merged axis
+      */
+    def flatten(using labels: Labels[T]): Tensor1[MergeLabels[T], V] =
+      // The merged label has no derived instance, so its name is supplied by hand.
+      given Labels[Tuple1[MergeLabels[T]]] with
+        def names = List(summon[Labels[T]].names.mkString("*"))
+      Tensor(Jax.jnp.ravel(tensor.jaxValue))
+
+    /** Flattens the specified axes of the tensor into a single axis.
+      * The resulting tensor will have the specified axes merged into a single axis named by concatenating the original axis names with "*".
+      * The other axes remain unchanged.
+      *
+      * @param axes the axes to flatten, specified as a tuple of Axis (e.g. (Axis[Ax1], Axis[Ax2]))
+      * @return a Tensor with the specified axes merged into a single axis
+      */
+    def flatten[AxesTuple <: Tuple, R <: Tuple](
+        axes: AxesTuple
+    )(using
+        merger: AxesMerger.Aux[T, UnwrapAxes[AxesTuple], R],
+        labels: Labels[R]
+    ): Tensor[R, V] =
+      // Transpose with the permutation supplied by `merger` before reshaping
+      // (presumably it makes the merged axes adjacent - confirm against AxesMerger).
+      val permuted = Jax.jnp.transpose(tensor.jaxValue, merger.permutation.toPythonProxy)
+
+      val originalDims = tensor.shape.dimensions
+      val mergedSize = merger.mergeIndices.map(originalDims).product
+
+      val remainingDims = originalDims.zipWithIndex
+        .filterNot((d, i) => merger.mergeIndices.contains(i))
+        .map(_._1)
+
+      // Insert the merged size at `mergedIndex` without replacing any remaining dimension.
+      val newDimensions = remainingDims.patch(merger.mergedIndex, Seq(mergedSize), 0)
+
+      Tensor(Jax.jnp.reshape(permuted, newDimensions.toPythonProxy))
+
+    /** Unflattens splitAxis into a new shape specified by newShape. The other axes remain unchanged.
+      *
+      * The user must ensure that the size of splitAxis matches the product of the dimensions in newShape, otherwise a runtime error will occur.
+      *
+      * @param splitAxis the axis to unflatten
+      * @param newShape the new shape to unflatten into, specified as a Shape
+      * @return a Tensor with the specified axis unflattened into the new shape
+      */
+    def unflatten[SplitL, NewT <: Tuple, R <: Tuple](
+        splitAxis: Axis[SplitL],
+        newShape: Shape[NewT]
+    )(using
+        ev: AxisReplacerAll.Aux[T, SplitL, NewT, R],
+        labels: Labels[R]
+    ): Tensor[R, V] =
+      // Keep the dimensions around the split axis and splice the new ones in its place.
+      val before = tensor.shape.dimensions.take(ev.index)
+      val after = tensor.shape.dimensions.drop(ev.index + 1)
+      val fullNewShape = before ++ newShape.dimensions ++ after
+      Tensor(
+        Jax.jnp.reshape(
+          tensor.jaxValue,
+          py.Dynamic.global.tuple(
+            fullNewShape.map(py.Any.from).toPythonProxy
+          )
+        )
+      )
+
+    /** Unflattens the tensor into a new shape specified by newShape.
+      *
+      * The user must ensure that the size of the tensor matches the product of the dimensions in newShape, otherwise a runtime error will occur.
+      *
+      * @param newShape the new shape to unflatten into, specified as a Shape
+      * @return a Tensor with the new shape
+      */
+    def unflatten[NewT <: Tuple: Labels](
+        newShape: Shape[NewT]
+    )(using
+        @implicitNotFound("unflatten without axis can only be used on Tensor1 types.")
+        ev: T <:< Tuple1[Any] // <--- Ensures this only works on Tensor1
+    ): Tensor[NewT, V] =
+      val fullNewShape = newShape.dimensions
+      Tensor(
+        Jax.jnp.reshape(
+          tensor.jaxValue,
+          py.Dynamic.global.tuple(
+            fullNewShape.map(py.Any.from).toPythonProxy
+          )
+        )
+      )
+
+    /** Reorders the axes of the tensor via `jnp.transpose`.
+      *
+      * @param newOrder a tuple of axes giving the desired order; `IsPermutation` evidence is
+      *                 required between the current shape and the requested axes
+      * @return a tensor whose shape type is the requested order
+      */
+    def transpose[NewOrder <: Tuple, Status <: ValidationResult](newOrder: NewOrder)(using
+        ev: AxisIndices[T, UnwrapAxes[NewOrder]],
+        newLabels: Labels[UnwrapAxes[NewOrder]]
+    )(using
+        allAxesEv: IsPermutation[T, UnwrapAxes[NewOrder]]
+    ): Tensor[UnwrapAxes[NewOrder], V] =
+      val indices = ev.indices
+      Tensor(Jax.jnp.transpose(tensor.jaxValue, indices.toPythonProxy))
+
+    /** Splits the tensor along the specified axis into one tensor per position on that axis, returning them as a sequence.
+      *
+      * @param unstackAxis the axis to split, specified as an Axis (e.g. Axis[Ax1])
+      * @return a sequence of tensors resulting from the split, each with the specified axis removed
+      */
+    def unstack[L: Label](unstackAxis: Axis[L])(using
+        labels: Labels[T],
+        ev: AxisRemover[T, L],
+        labelR: Labels[ev.RemainingAxes]
+    ): Seq[Tensor[ev.RemainingAxes, V]] =
+      val axisIdx = ev.index
+      val count = tensor.shape.dimensions(axisIdx)
+      // Nothing to split on a zero-length axis.
+      if count == 0 then Seq.empty
+      else
+        Jax.jnp
+          .split(tensor.jaxValue, count, axis = axisIdx)
+          .as[Seq[Jax.PyDynamic]]
+          // jnp.split keeps the split axis with size 1; drop it so the array matches `ev.RemainingAxes`.
+          .map(part => Tensor[ev.RemainingAxes, V](Jax.jnp.squeeze(part, axis = axisIdx)))
+
+    /** Splits the tensor along `splitAxis` using `jnp.split`; every part keeps the shape type `T`.
+      *
+      * NOTE(review): `chunkSize` is passed as the second argument of `jnp.split`, where an integer is
+      * documented as the number of equal sections rather than the size of each section - confirm
+      * which meaning callers rely on.
+      */
+    def chunk[splitL: Label](splitAxis: Axis[splitL], chunkSize: Int)(using
+        labels: Labels[T],
+        axisIndex: AxisIndex[T, splitL]
+    ): Seq[Tensor[T, V]] =
+      val res = Jax.jnp.split(tensor.jaxValue, chunkSize, axis = axisIndex.index).as[Seq[Jax.PyDynamic]]
+      res.map(x => Tensor[T, V](x))
+
+    /** Indexes the tensor with a tuple of selectors.
+      *
+      * The labels in `LabelsToRemove` are dropped from the result shape; the rest stay.
+      *
+      * @param inputs a tuple of axis selectors (or legacy `(Axis, index)` pairs)
+      */
+    def slice[Inputs <: Tuple, LabelsToRemove <: Tuple](
+        inputs: Inputs
+    )(using
+        sliceExtractor: SliceLabelExtractor[Inputs, LabelsToRemove],
+        ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Inputs]],
+        labels: Labels[ev.RemainingAxes]
+    ): Tensor[ev.RemainingAxes, V] =
+      val pyIndices = tensor.calcPyIndices(inputs, ev.indices)
+      Tensor(tensor.jaxValue.bracketAccess(pyIndices))
+
+    // Convenience overload for AxisAtIndex
+    def slice[L, LabelsToRemove <: Tuple](
+        selector: AxisAtIndex[L]
+    )(using
+        sliceExtractor: SliceLabelExtractor[Tuple1[AxisAtIndex[L]], LabelsToRemove],
+        ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtIndex[L]]]],
+        labels: Labels[ev.RemainingAxes]
+    ): Tensor[ev.RemainingAxes, V] =
slice(Tuple1(selector))
+
+    // Convenience overload for AxisAtRange
+    def slice[L, LabelsToRemove <: Tuple](
+        selector: AxisAtRange[L]
+    )(using
+        sliceExtractor: SliceLabelExtractor[Tuple1[AxisAtRange[L]], LabelsToRemove],
+        ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtRange[L]]]],
+        labels: Labels[ev.RemainingAxes]
+    ): Tensor[ev.RemainingAxes, V] = slice(Tuple1(selector))
+
+    // Convenience overload for AxisAtIndices
+    def slice[L, LabelsToRemove <: Tuple](
+        selector: AxisAtIndices[L]
+    )(using
+        sliceExtractor: SliceLabelExtractor[Tuple1[AxisAtIndices[L]], LabelsToRemove],
+        ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtIndices[L]]]],
+        labels: Labels[ev.RemainingAxes]
+    ): Tensor[ev.RemainingAxes, V] = slice(Tuple1(selector))
+
+    // Convenience overload for AxisAtTensorIndex
+    def slice[L, LabelsToRemove <: Tuple](
+        selector: AxisAtTensorIndex[L]
+    )(using
+        sliceExtractor: SliceLabelExtractor[Tuple1[AxisAtTensorIndex[L]], LabelsToRemove],
+        ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtTensorIndex[L]]]],
+        labels: Labels[ev.RemainingAxes]
+    ): Tensor[ev.RemainingAxes, V] = slice(Tuple1(selector))
+
+    // Convenience overload for AxisAtTupleIndices
+    def slice[L, U <: NonEmptyTuple, LabelsToRemove <: Tuple](
+        selector: AxisAtTupleIndices[L, U]
+    )(using
+        sliceExtractor: SliceLabelExtractor[Tuple1[AxisAtTupleIndices[L, U]], LabelsToRemove],
+        ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtTupleIndices[L, U]]]],
+        labels: Labels[ev.RemainingAxes]
+    ): Tensor[ev.RemainingAxes, V] = slice(Tuple1(selector))
+
+    /** Gathers entries along `axis` at the positions given by `indices`.
+      *
+      * @param axis    the axis to gather from; it is replaced by the index axis `L2`
+      * @param indices a 1-D tensor of positions along `axis`
+      * @return a tensor whose first axis is `L2`, followed by the remaining axes of this tensor
+      */
+    def take[L1, L2: Label](
+        axis: Axis[L1]
+    )(
+        indices: Tensor1[L2, Int32]
+    )(using
+        ev: AxisRemover[T, L1],
+        labels: Labels[ev.RemainingAxes]
+    ): Tensor[Tuple.Concat[Tuple1[L2], ev.RemainingAxes], V] =
+      val gathered = Jax.jnp.take(tensor.jaxValue, indices.jaxValue, axis = ev.index)
+      // jnp.take leaves the gathered dimension at position `ev.index`, while the declared shape
+      // puts `L2` first; move it to the front so data and labels agree.
+      Tensor(Jax.jnp.moveaxis(gathered, ev.index, 0))
+
+    /** Returns a copy of the tensor with the selected region replaced by `value`.
+      *
+      * @param inputs a tuple of axis selectors (or legacy `(Axis, index)` pairs)
+      * @param value  a tensor typed with the axes that remain after the selection
+      */
+    def set[Inputs <: Tuple, LabelsToRemove <: Tuple](
+        inputs: Inputs
+    )(using
+        sliceExtractor: SliceLabelExtractor[Inputs, LabelsToRemove],
+        ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Inputs]],
+        labels: Labels[T]
+    )(value: Tensor[ev.RemainingAxes, V]): Tensor[T, V] =
+      val pyIndices = tensor.calcPyIndices(inputs, ev.indices)
+      val result =
tensor.jaxValue.at.bracketAccess(pyIndices).set(value.jaxValue) + Tensor(result) + + // Convenience overload for Float + def set[Inputs <: Tuple, LabelsToRemove <: Tuple]( + inputs: Inputs + )(using + sliceExtractor: SliceLabelExtractor[Inputs, LabelsToRemove], + ev: AxesConditionalRemover.Aux[T, LabelsToRemove, ExtractLabels[Inputs], EmptyTuple], + labels: Labels[T] + )(value: Float): Tensor[T, V] = + val pyIndices = tensor.calcPyIndices(inputs, ev.indices) + val result = tensor.jaxValue.at.bracketAccess(pyIndices).set(value) + Tensor(result) + + // Convenience overload for AxisAtIndex + def set[L, LabelsToRemove <: Tuple]( + selector: AxisAtIndex[L] + )(using + sliceExtractor: SliceLabelExtractor[Tuple1[AxisAtIndex[L]], LabelsToRemove], + ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtIndex[L]]]], + labels: Labels[T] + )(value: Tensor[ev.RemainingAxes, V]): Tensor[T, V] = set(Tuple1(selector))(value) + + // Convenience overload for AxisAtRange + def set[L, LabelsToRemove <: Tuple]( + selector: AxisAtRange[L] + )(using + sliceExtractor: SliceLabelExtractor[Tuple1[AxisAtRange[L]], LabelsToRemove], + ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtRange[L]]]], + labels: Labels[T] + )(value: Tensor[ev.RemainingAxes, V]): Tensor[T, V] = set(Tuple1(selector))(value) + + // Convenience overload for AxisAtIndices + def set[L, LabelsToRemove <: Tuple]( + selector: AxisAtIndices[L] + )(using + sliceExtractor: SliceLabelExtractor[Tuple1[AxisAtIndices[L]], LabelsToRemove], + ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtIndices[L]]]], + labels: Labels[T] + )(value: Tensor[ev.RemainingAxes, V]): Tensor[T, V] = set(Tuple1(selector))(value) + + // Convenience overload for AxisAtTensorIndex + def set[L, LabelsToRemove <: Tuple]( + selector: AxisAtTensorIndex[L] + )(using + sliceExtractor: SliceLabelExtractor[Tuple1[AxisAtTensorIndex[L]], LabelsToRemove], + ev: AxesConditionalRemover[T, LabelsToRemove, 
ExtractLabels[Tuple1[AxisAtTensorIndex[L]]]], + labels: Labels[T] + )(value: Tensor[ev.RemainingAxes, V]): Tensor[T, V] = set(Tuple1(selector))(value) + + def rearrange[Axes <: Tuple, Status <: ValidationResult](newOrder: Axes)(using + Labels[UnwrapAxes[Axes]] + )(using + computer: ComputeMissing[UnwrapAxes[Axes], T, EmptyTuple, Status], + guard: CheckValid[Status] + ): Tensor[UnwrapAxes[Axes], V] = + rearrange[Axes, EmptyTuple, Status](newOrder, EmptyTuple) + + // Convenience overload for 1 dims (to support error messages with single axis) + inline def rearrange[Axes <: Tuple, L1, Status <: ValidationResult](newOrder: Axes, d1: AxisExtent[L1])(using computer: ComputeMissing[UnwrapAxes[Axes], T, UnwrapDims[Tuple1[AxisExtent[L1]]], Status], guard: CheckValid[Status])(using newLabels: Labels[UnwrapAxes[Axes]], extractor: DimExtractor[Tuple1[AxisExtent[L1]]]): Tensor[UnwrapAxes[Axes], V] = + rearrange(newOrder, Tuple1(d1)) + + // Convenience overload for 2 dims + inline def rearrange[Axes <: Tuple, L1, L2, Status <: ValidationResult](newOrder: Axes, d1: AxisExtent[L1], d2: AxisExtent[L2])(using computer: ComputeMissing[UnwrapAxes[Axes], T, UnwrapDims[(AxisExtent[L1], AxisExtent[L2])], Status], guard: CheckValid[Status])(using newLabels: Labels[UnwrapAxes[Axes]], extractor: DimExtractor[(AxisExtent[L1], AxisExtent[L2])]): Tensor[UnwrapAxes[Axes], V] = + rearrange(newOrder, (d1, d2)) + + // Convenience overload for 3 dims + inline def rearrange[Axes <: Tuple, L1, L2, L3, Status <: ValidationResult](newOrder: Axes, d1: AxisExtent[L1], d2: AxisExtent[L2], d3: AxisExtent[L3])(using computer: ComputeMissing[UnwrapAxes[Axes], T, UnwrapDims[(AxisExtent[L1], AxisExtent[L2], AxisExtent[L3])], Status], guard: CheckValid[Status])(using newLabels: Labels[UnwrapAxes[Axes]], extractor: DimExtractor[(AxisExtent[L1], AxisExtent[L2], AxisExtent[L3])]): Tensor[UnwrapAxes[Axes], V] = + rearrange(newOrder, (d1, d2, d3)) + + // Convenience overload for 4 dims + inline def rearrange[Axes <: 
Tuple, L1, L2, L3, L4, Status <: ValidationResult](newOrder: Axes, d1: AxisExtent[L1], d2: AxisExtent[L2], d3: AxisExtent[L3], d4: AxisExtent[L4])(using computer: ComputeMissing[UnwrapAxes[Axes], T, UnwrapDims[(AxisExtent[L1], AxisExtent[L2], AxisExtent[L3], AxisExtent[L4])], Status], guard: CheckValid[Status])(using newLabels: Labels[UnwrapAxes[Axes]], extractor: DimExtractor[(AxisExtent[L1], AxisExtent[L2], AxisExtent[L3], AxisExtent[L4])]): Tensor[UnwrapAxes[Axes], V] = + rearrange(newOrder, (d1, d2, d3, d4)) + + def rearrange[Axes <: Tuple, Dims <: Tuple, Status <: ValidationResult]( + newOrder: Axes, + dims: Dims + )(using + computer: ComputeMissing[UnwrapAxes[Axes], T, UnwrapDims[Dims], Status], + guard: CheckValid[Status] + )(using + newLabels: Labels[UnwrapAxes[Axes]], + extractor: DimExtractor[Dims] + ): Tensor[UnwrapAxes[Axes], V] = + def cleanPatternPrime(pattern: String): String = + // Support dimwit.Prime by replacing ' with "Prime" + pattern.replaceAll( + "'", + "Prime" + ) + def createEinopsPattern(fromPattern: String, toPattern: String): String = + def cleanPatternStar(pattern: String): String = + // to replace all a*b*c in pattern with (a b c), example: + // "a*b*c d e f*g h" -> "(a b c) d e (f g) h" + val regex = raw"([a-zA-Z0-9_]+(\*[a-zA-Z0-9_]+)+)".r + regex.replaceAllIn( + pattern, + _.group(1).split("\\*").mkString("(", " ", ")") + ) + def cleanPatternPlus(pattern: String): String = + // Support dimwit.|+| by replacing + with underlines + val regex = raw"([a-zA-Z0-9_]+(\+[a-zA-Z0-9_]+)+)".r + regex.replaceAllIn( + pattern, + _.group(1).replace("+", "_") + ) + def cleanPattern(pattern: String): String = + cleanPatternPlus(cleanPatternStar(cleanPatternPrime(pattern))) + s"${cleanPattern(fromPattern)} -> ${cleanPattern(toPattern)}" + val fromPattern = tensor.shape.labels.mkString(" ") + val toPattern = newLabels.names.mkString(" ") + val pattern = createEinopsPattern(fromPattern, toPattern) + val dimSizesMap = extractor.extract(dims) + val 
cleanDimSizesMap = dimSizesMap.map { case (k, v) => + val newKey = cleanPatternPrime(k) + (newKey, v) + } + Tensor( + Einops.rearrange( + tensor.jaxValue, + pattern, + kwargsMap = cleanDimSizesMap + ) + ) + + def broadcastTo[O <: Tuple: Labels](newShape: Shape[O])(using + labels: Labels[T], + ev: StrictSubset[T, O] + ): Tensor[O, V] = + /* Disallow implicit broadcasting where an *existing* axis changes size (implicitly). + * dimwit broadcasting only adds missing axes, never changes existing ones. + * + * This is a required check to prevent implicit broadcasting across dimwit. + * If this check is not explicitly present, Jax.jnp.broadcast_to would implicit broadcast.*/ + def disallowImplicitShapeBroadcasting(): Unit = + val tAxesDims = tensor.axes.zip(tensor.shape.dimensions).toMap + val newShapeAxesDims = newShape.labels.zip(newShape.dimensions).toMap + tensor.axes.foreach(axisName => + require( + tAxesDims(axisName) == newShapeAxesDims(axisName), + s"Broadcasting only adds missing axes. Present axes must have the same size. Axis ${axisName} has size ${tAxesDims(axisName)} in the current tensor but size ${newShapeAxesDims(axisName)} in the target shape." 
+ ) + ) + + disallowImplicitShapeBroadcasting() // Make dimwit coders, good coders :) + + val t = tensor + + val currentNames = summon[Labels[T]].names + val targetNames = summon[Labels[O]].names + + val targetOrder = targetNames.filter(currentNames.contains) + val permutation = targetOrder.map(n => currentNames.indexOf(n)) + + val alignedJax = + if permutation != currentNames.indices.toList then Jax.jnp.transpose(t.jaxValue, permutation.toPythonProxy) + else t.jaxValue + + val currentShapeMap = currentNames.zip(t.shape.dimensions).toMap + + val intermediateShape = targetNames.map { name => + currentShapeMap.getOrElse(name, 1) + } + + val reshapedJax = Jax.jnp.reshape(alignedJax, intermediateShape.toPythonProxy) + Tensor(Jax.jnp.broadcast_to(reshapedJax, newShape.dimensions.toPythonProxy)) + + def relabel[OldLabel: Label, NewLabel: Label]( + rename: (Axis[OldLabel], Axis[NewLabel]) + )(using + ev: AxisReplacer[T, OldLabel, NewLabel], + newLabels: Labels[ev.NewShape] + ): Tensor[ev.NewShape, V] = Tensor(tensor.jaxValue) + + def retag[newT <: Tuple](using newLabels: Labels[newT]): Tensor[newT, V] = + Tensor(tensor.jaxValue)(using newLabels) + + def relabelAll[newT <: Tuple]( + newAxes: newT + )(using + newLabels: Labels[UnwrapAxes[newT]], + @implicitNotFound("Cannot convert tensor of shape ${T} to shape ${newT} due to size mismatch.") + evSameSize: Tuple.Size[newT] =:= Tuple.Size[T] + ): Tensor[UnwrapAxes[newT], V] = Tensor[UnwrapAxes[newT], V](tensor.jaxValue) + + def swap[L1: Label, L2: Label]( + axis1: Axis[L1], + axis2: Axis[L2] + )(using + labels: Labels[T], + axisIndex1: AxisIndex[T, L1], + axisIndex2: AxisIndex[T, L2] + ): Tensor[Swap[T, L1, L2], V] = + given Labels[Swap[T, L1, L2]] with + def names = + val originalNames = summon[Labels[T]].names + val ax1Name = summon[Label[L1]].name + val ax2Name = summon[Label[L2]].name + originalNames.map { + case n if n == ax1Name => ax2Name + case n if n == ax2Name => ax1Name + case n => n + } + 
Tensor(Jax.jnp.swapaxes(tensor.jaxValue, axisIndex1.index, axisIndex2.index)) + + def appendAxis[L: Label](axis: Axis[L])(using labels: Labels[T], ev: AxisAbsent[T, L]): Tensor[Tuple.Concat[T, Tuple1[L]], V] = + val newShape = tensor.shape.dimensions :+ 1 + Tensor(Jax.jnp.reshape(tensor.jaxValue, newShape.toPythonProxy)) + + def prependAxis[L: Label](axis: Axis[L])(using labels: Labels[T], ev: AxisAbsent[T, L]): Tensor[Tuple.Concat[Tuple1[L], T], V] = + val newShape = 1 +: tensor.shape.dimensions + Tensor(Jax.jnp.reshape(tensor.jaxValue, newShape.toPythonProxy)) + + def squeeze[L: Label](axis: Axis[L])(using + ev: AxisRemover[T, L], + labels: Labels[ev.RemainingAxes] + ): Tensor[ev.RemainingAxes, V] = + require( + tensor.shape.dimensions(ev.index) == 1, + s"Cannot squeeze axis ${summon[Label[L]].name} of size ${tensor.shape.dimensions(ev.index)}" + ) + Tensor(Jax.jnp.squeeze(tensor.jaxValue, axis = ev.index)) + + extension [L: Label, V](tensor: Tensor1[L, V]) + def roll(shift: Int): Tensor1[L, V] = + Tensor(Jax.jnp.roll(tensor.jaxValue, shift = shift, axis = 0)) diff --git a/core/src/main/scala/dimwit/tensor/tensorops/Tensor0Ops.scala b/core/src/main/scala/dimwit/tensor/tensorops/Tensor0Ops.scala new file mode 100644 index 00000000..5f8d3d0d --- /dev/null +++ b/core/src/main/scala/dimwit/tensor/tensorops/Tensor0Ops.scala @@ -0,0 +1,61 @@ +package dimwit.tensor.tensorops + +import dimwit.tensor.Tensor0 +import dimwit.tensor.DType.* +import dimwit.tensor.TensorOps +import dimwit.tensor.TensorOps.IsNumber +import dimwit.tensor.Labels +import dimwit.tensor.tensorops.ElementWiseOps.add +import dimwit.tensor.Tensor +import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast +import dimwit.tensor.tensorops.ElementWiseOps.subtract +import dimwit.tensor.tensorops.ElementWiseOps.multiply +import dimwit.tensor.tensorops.ElementWiseOps.divide +import dimwit.tensor.TensorOps.IsFloating + +object Tensor0Ops: + + private inline def checkTracer[V, R](scalar: Tensor0[V]): Unit = + 
require( + !scalar.isTracer, + """ + | Cannot convert a JAX Tracer to a scalar value. Tensor0 is part of a JAX computation graph (e.g., inside vmap or a jitted function). + | Common mistakes leading to this error: + | - calling .slice(t0.item) rather than .slice(t0); breaking the computation graph unintentionally. + |""".stripMargin + ) + + extension (scalar: Tensor0[Bool]) + def item: Boolean = + checkTracer(scalar) + scalar.jaxValue.item().as[Boolean] + + extension (scalar: Tensor0[Int8]) + def item: Byte = + checkTracer(scalar) + scalar.jaxValue.item().as[Byte] + + extension (scalar: Tensor0[Int16]) + def item: Short = + checkTracer(scalar) + scalar.jaxValue.item().as[Int].toShort + + extension (scalar: Tensor0[Int32]) + def item: Int = + checkTracer(scalar) + scalar.jaxValue.item().as[Int] + + extension (scalar: Tensor0[Int64]) + def item: Long = + checkTracer(scalar) + scalar.jaxValue.item().as[Long] + + extension (scalar: Tensor0[Float32]) + def item: Float = + checkTracer(scalar) + scalar.jaxValue.item().as[Float] + + extension (scalar: Tensor0[Float64]) + def item: Double = + checkTracer(scalar) + scalar.jaxValue.item().as[Double] diff --git a/core/src/main/scala/dimwit/tensor/tensorops/Tensor1Ops.scala b/core/src/main/scala/dimwit/tensor/tensorops/Tensor1Ops.scala new file mode 100644 index 00000000..bef2a6c3 --- /dev/null +++ b/core/src/main/scala/dimwit/tensor/tensorops/Tensor1Ops.scala @@ -0,0 +1,49 @@ +package dimwit.tensor.tensorops + +import dimwit.tensor.Tensor0 +import dimwit.tensor.DType.* +import dimwit.tensor.TensorOps +import dimwit.tensor.TensorOps.IsNumber +import dimwit.tensor.Labels +import dimwit.tensor.tensorops.ElementWiseOps.add +import dimwit.tensor.Tensor +import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast +import dimwit.tensor.tensorops.ElementWiseOps.subtract +import dimwit.tensor.tensorops.ElementWiseOps.multiply +import dimwit.tensor.tensorops.ElementWiseOps.divide +import dimwit.tensor.TensorOps.IsFloating +import 
dimwit.tensor.Axis +import dimwit.tensor.Label +import dimwit.tensor.Tensor1 +import dimwit.tensor.HasScalar +import dimwit.jax.Jax + +import me.shadaj.scalapy.py +import me.shadaj.scalapy.py.SeqConverters +import me.shadaj.scalapy.readwrite.{Reader, Writer} + +object Tensor1Ops: + + extension [L, V](t: Tensor1[L, V]) + + def relabelTo[NewL: Label](newAxis: Axis[NewL]): Tensor1[NewL, V] = Tensor[Tuple1[NewL], V](t.jaxValue) + + // TODO generalize to TensorN (like slice) + def dynamicSlice( + dynamicStart: Tensor0[Int32], + staticSize: Int + )(using + label: Label[L] + ): Tensor1[L, V] = + // TODO understand why toPythonCopy is needed and toPythonProxy fails! + Tensor(Jax.lax.dynamic_slice(t.jaxValue, Seq(dynamicStart.jaxValue).toPythonCopy, Seq(staticSize).toPythonCopy)) + + extension [L, V, X](t: Tensor1[L, V])(using ev: HasScalar[V, X]) + /** Converts a Tensor1 to a Scala Array. + * The user must ensure that the tensor is not a JAX Tracer + * (i.e., it is not part of a JAX computation graph) before calling this method, + * otherwise a runtime error will occur. 
+ */ + def toArray: Array[X] = + require(!t.isTracer, "Cannot convert a JAX Tracer to an array.") + ev.readFlat(t.jaxValue) diff --git a/core/src/main/scala/dimwit/tensor/tensorops/Tensor2Ops.scala b/core/src/main/scala/dimwit/tensor/tensorops/Tensor2Ops.scala new file mode 100644 index 00000000..c33e94f8 --- /dev/null +++ b/core/src/main/scala/dimwit/tensor/tensorops/Tensor2Ops.scala @@ -0,0 +1,43 @@ +package dimwit.tensor.tensorops + +import dimwit.tensor.Tensor0 +import dimwit.tensor.DType.* +import dimwit.tensor.TensorOps +import dimwit.tensor.TensorOps.IsNumber +import dimwit.tensor.Labels +import dimwit.tensor.tensorops.ElementWiseOps.add +import dimwit.tensor.Tensor +import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast +import dimwit.tensor.tensorops.ElementWiseOps.subtract +import dimwit.tensor.tensorops.ElementWiseOps.multiply +import dimwit.tensor.tensorops.ElementWiseOps.divide +import dimwit.tensor.TensorOps.IsFloating +import dimwit.tensor.Axis +import dimwit.tensor.Label +import dimwit.tensor.Tensor1 +import dimwit.tensor.HasScalar +import dimwit.jax.Jax + +import me.shadaj.scalapy.py +import me.shadaj.scalapy.py.SeqConverters +import me.shadaj.scalapy.readwrite.{Reader, Writer} +import dimwit.tensor.Tensor2 + +object Tensor2Ops: + + extension [L1: Label, L2: Label, V](t: Tensor2[L1, L2, V]) + + // Support .transpose without arguments for 2D tensors while keeping (not shadowing) the general .transpose with arguments + def transpose: Tensor2[L2, L1, V] = t.transpose(Axis[L2], Axis[L1]) + def transpose(axis2: Axis[L2], axis1: Axis[L1]): Tensor2[L2, L1, V] = StructuralOps.transpose(t)(axis2, axis1) + + extension [L1, L2, V, X](t: Tensor2[L1, L2, V])(using ev: HasScalar[V, X]) + /** Converts a Tensor2 to a nested Scala Array (Array of Arrays). + * The user must ensure that the tensor is not a JAX Tracer + * (i.e., it is not part of a JAX computation graph) before calling this method, + * otherwise a runtime error will occur. 
+ */ + def toArray: Array[Array[X]] = + require(!t.isTracer, "Cannot convert a JAX Tracer to an array.") + given scala.reflect.ClassTag[X] = ev.classTag + ev.readFlat(t.jaxValue).grouped(t.shape.dimensions(1)).toArray diff --git a/core/src/main/scala/dimwit/tensor/tensorops/Tensor3Ops.scala b/core/src/main/scala/dimwit/tensor/tensorops/Tensor3Ops.scala new file mode 100644 index 00000000..222afab6 --- /dev/null +++ b/core/src/main/scala/dimwit/tensor/tensorops/Tensor3Ops.scala @@ -0,0 +1,39 @@ +package dimwit.tensor.tensorops + +import dimwit.tensor.Tensor0 +import dimwit.tensor.DType.* +import dimwit.tensor.TensorOps +import dimwit.tensor.TensorOps.IsNumber +import dimwit.tensor.Labels +import dimwit.tensor.tensorops.ElementWiseOps.add +import dimwit.tensor.Tensor +import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast +import dimwit.tensor.tensorops.ElementWiseOps.subtract +import dimwit.tensor.tensorops.ElementWiseOps.multiply +import dimwit.tensor.tensorops.ElementWiseOps.divide +import dimwit.tensor.TensorOps.IsFloating +import dimwit.tensor.Axis +import dimwit.tensor.Label +import dimwit.tensor.Tensor1 +import dimwit.tensor.HasScalar +import dimwit.jax.Jax + +import me.shadaj.scalapy.py +import me.shadaj.scalapy.py.SeqConverters +import me.shadaj.scalapy.readwrite.{Reader, Writer} +import dimwit.tensor.Tensor2 +import dimwit.tensor.Tensor3 + +object Tensor3Ops: + + extension [L1, L2, L3, V, X](t: Tensor3[L1, L2, L3, V])(using ev: HasScalar[V, X]) + /** Converts a Tensor3 to a nested Scala Array (Array of Arrays of Arrays). + * The user must ensure that the tensor is not a JAX Tracer + * (i.e., it is not part of a JAX computation graph) before calling this method, + * otherwise a runtime error will occur. 
+ */ + def toArray: Array[Array[Array[X]]] = + require(!t.isTracer, "Cannot convert a JAX Tracer to an array.") + given scala.reflect.ClassTag[X] = ev.classTag + val d1 = t.shape.dimensions(1); val d2 = t.shape.dimensions(2) + ev.readFlat(t.jaxValue).grouped(d1 * d2).map(_.grouped(d2).toArray).toArray diff --git a/core/src/main/scala/dimwit/tensor/tensorops/TensorOpsUtils.scala b/core/src/main/scala/dimwit/tensor/tensorops/TensorOpsUtils.scala new file mode 100644 index 00000000..fa9b577b --- /dev/null +++ b/core/src/main/scala/dimwit/tensor/tensorops/TensorOpsUtils.scala @@ -0,0 +1,40 @@ +package dimwit.tensor.tensorops + +import scala.annotation.implicitNotFound +import dimwit.tensor.Labels +import dimwit.tensor.Tensor +import dimwit.tensor.TupleHelpers.StrictSubset + +object TensorOpsUtil: + + import dimwit.tensor.TensorOps.broadcastTo + + @implicitNotFound("Cannot broadcast tensors of shapes ${T1} and ${T2}. If same shape no broadcasting allowed!") + sealed trait Broadcast[T1 <: Tuple, T2 <: Tuple, V]: + type Out <: Tuple + given labelsOut: Labels[Out] + def broadcast(t1: Tensor[T1, V], t2: Tensor[T2, V]): (Tensor[Out, V], Tensor[Out, V]) + def applyTo[V2](t1: Tensor[T1, V], t2: Tensor[T2, V])(f: (Tensor[Out, V], Tensor[Out, V]) => Tensor[Out, V2]): Tensor[Out, V2] = + val (bt1, bt2) = broadcast(t1, t2) + f(bt1, bt2) + + object Broadcast extends BroadcastLowPriority: + + given broadcastLeft[T1 <: Tuple: Labels, T2 <: Tuple: Labels, V](using + StrictSubset[T2, T1] + ): Broadcast[T1, T2, V] with + type Out = T1 + val labelsOut = summon[Labels[T1]] + def broadcast(t1: Tensor[T1, V], t2: Tensor[T2, V]) = + (t1, t2.broadcastTo[T1](t1.shape)) + + trait BroadcastLowPriority: + given broadcastRight[T1 <: Tuple: Labels, T2 <: Tuple: Labels, V](using + StrictSubset[T1, T2] + ): Broadcast[T1, T2, V] with + type Out = T2 + val labelsOut = summon[Labels[T2]] + def broadcast(t1: Tensor[T1, V], t2: Tensor[T2, V]) = + (t1.broadcastTo[T2](t2.shape), t2) + +end TensorOpsUtil 
diff --git a/core/src/test/scala/dimwit/tensor/TensorOpsConvolutionSuite.scala b/core/src/test/scala/dimwit/tensor/TensorOpsConvolutionSuite.scala index 26a25f58..ce9f7e95 100644 --- a/core/src/test/scala/dimwit/tensor/TensorOpsConvolutionSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorOpsConvolutionSuite.scala @@ -1,7 +1,7 @@ package dimwit.tensor import dimwit.* -import dimwit.tensor.TensorOps.Convolution.Padding +import dimwit.tensor.TensorOps.Padding import dimwit.stats.Normal class TensorOpsConvolutionSuite extends DimwitTest: From 899261f9268bf7780be15341d16b3a42c61c4d77 Mon Sep 17 00:00:00 2001 From: Marcel Luethi Date: Tue, 23 Jun 2026 19:07:08 +0200 Subject: [PATCH 2/4] remove valueops export from tensorops --- core/src/main/scala/dimwit/package.scala | 1 + .../src/main/scala/dimwit/tensor/TensorOps.scala | 16 +++++++++------- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/dimwit/package.scala b/core/src/main/scala/dimwit/package.scala index 3a9fef85..168f26b8 100644 --- a/core/src/main/scala/dimwit/package.scala +++ b/core/src/main/scala/dimwit/package.scala @@ -73,6 +73,7 @@ package object dimwit: // Export operations export dimwit.tensor.TensorOps.* + export dimwit.tensor.ValueOps.* // Export devices export dimwit.hardware.Device diff --git a/core/src/main/scala/dimwit/tensor/TensorOps.scala b/core/src/main/scala/dimwit/tensor/TensorOps.scala index a9cc74c1..1bdbe260 100644 --- a/core/src/main/scala/dimwit/tensor/TensorOps.scala +++ b/core/src/main/scala/dimwit/tensor/TensorOps.scala @@ -5,15 +5,19 @@ import dimwit.DType.given import dimwit.OnError import dimwit.jax.Jax import dimwit.tensor.HasScalar -import dimwit.tensor.{Label, Labels} +import dimwit.tensor.Label +import dimwit.tensor.Labels import dimwit.tensor.ShapeTypeHelpers.* -import dimwit.tensor.TensorOps.ZipVmap.{ShapesOf, TensorsOf} +import dimwit.tensor.TensorOps.ZipVmap.ShapesOf +import dimwit.tensor.TensorOps.ZipVmap.TensorsOf import 
dimwit.tensor.TupleHelpers.* -import dimwit.{`|*|`, `|+|`} - +import dimwit.tensor.tensorops.StructuralOps +import dimwit.|*| +import dimwit.|+| import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters -import me.shadaj.scalapy.readwrite.{Reader, Writer} +import me.shadaj.scalapy.readwrite.Reader +import me.shadaj.scalapy.readwrite.Writer import scala.annotation.implicitNotFound import scala.annotation.targetName @@ -22,7 +26,6 @@ import scala.util.NotGiven import Tuple.:* import Tuple.++ -import dimwit.tensor.tensorops.StructuralOps object TensorOps: @@ -77,7 +80,6 @@ object TensorOps: export tensorops.FunctionalOps.* export tensorops.Tensor0Ops.* - export ValueOps.* export tensorops.Tensor1Ops.* export tensorops.Tensor2Ops.* export tensorops.Tensor3Ops.* From 614bd139085ca33c4e924516300f5fba9929a79e Mon Sep 17 00:00:00 2001 From: Marcel Luethi Date: Tue, 23 Jun 2026 20:56:34 +0200 Subject: [PATCH 3/4] add back nanmedian and nanmean --- .../main/scala/dimwit/tensor/tensorops/ReductionOps.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/core/src/main/scala/dimwit/tensor/tensorops/ReductionOps.scala b/core/src/main/scala/dimwit/tensor/tensorops/ReductionOps.scala index 0056abde..a529e2d4 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/ReductionOps.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/ReductionOps.scala @@ -88,3 +88,11 @@ object ReductionOps: def median: Tensor0[V] = Tensor0(Jax.jnp.median(t.jaxValue)) def median[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.median(t.jaxValue, axis = ev.index)) def median[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.median(t.jaxValue, axis = ev.indices.toPythonProxy)) + + def nanmean: Tensor0[V] = Tensor0(Jax.jnp.nanmean(t.jaxValue)) + def nanmean[L: Label](axis: Axis[L])(using ev: 
AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.nanmean(t.jaxValue, axis = ev.index)) + def nanmean[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.nanmean(t.jaxValue, axis = ev.indices.toPythonProxy)) + + def nanmedian: Tensor0[V] = Tensor0(Jax.jnp.nanmedian(t.jaxValue)) + def nanmedian[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.nanmedian(t.jaxValue, axis = ev.index)) + def nanmedian[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.nanmedian(t.jaxValue, axis = ev.indices.toPythonProxy)) From 4ab8a6a3e39963e60d6a0504ba5cdf5cea373c08 Mon Sep 17 00:00:00 2001 From: Marcel Luethi Date: Tue, 23 Jun 2026 21:16:05 +0200 Subject: [PATCH 4/4] recreate agents file --- AGENTS.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index fe9dd134..312be6a6 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1166,9 +1166,9 @@ val wrong = t1 +! t2 // Cannot broadcast tensors of shapes Tuple1[MdocApp11.this.A] and Tuple1[MdocApp11.this.A]. If same shape no broadcasting allowed!. // I found: // -// dimwit.tensor.TensorOpsUtil.Broadcast.broadcastLeft[Tuple1[MdocApp11.this.A], -// Tuple1[MdocApp11.this.A], (dimwit.tensor.DType.Float32 : dimwit.tensor.DType)] -// ( +// dimwit.tensor.tensorops.TensorOpsUtil.Broadcast.broadcastLeft[ +// Tuple1[MdocApp11.this.A], Tuple1[MdocApp11.this.A], +// (dimwit.tensor.DType.Float32 : dimwit.tensor.DType)]( // dimwit.tensor.Labels.concat[MdocApp11.this.A, EmptyTuple.type]( // this.A.derived$Label, dimwit.tensor.Labels.namesOfEmpty), // dimwit.tensor.Labels.concat[MdocApp11.this.A, EmptyTuple.type](