diff --git a/AGENTS.md b/AGENTS.md index 01369f6..fe9dd13 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -103,7 +103,7 @@ val wrong = t.sum(Axis[C]) // // dimwit.tensor.ShapeTypeHelpers.AxisRemover.bridge[ // (repl.MdocSession.MdocApp.A, repl.MdocSession.MdocApp.B), -// repl.MdocSession.MdocApp.C, Tuple]( +// repl.MdocSession.MdocApp.C, R]( // dimwit.tensor.ShapeTypeHelpers.AxisIndex.tail[repl.MdocSession.MdocApp.A, // repl.MdocSession.MdocApp.B *: EmptyTuple.type, repl.MdocSession.MdocApp.C]( // dimwit.tensor.ShapeTypeHelpers.AxisIndex.tail[repl.MdocSession.MdocApp.B, @@ -208,7 +208,7 @@ val summed = wrongAxis.sum(Axis[B]) // B not in shape! // I found: // // dimwit.tensor.ShapeTypeHelpers.AxisRemover.bridge[ -// Tuple1[repl.MdocSession.MdocApp.A], repl.MdocSession.MdocApp.B, Tuple]( +// Tuple1[repl.MdocSession.MdocApp.A], repl.MdocSession.MdocApp.B, R]( // dimwit.tensor.ShapeTypeHelpers.AxisIndex.tail[repl.MdocSession.MdocApp.A, // EmptyTuple.type, repl.MdocSession.MdocApp.B]( // dimwit.tensor.ShapeTypeHelpers.AxisIndex.concatRight[A², B², L]), @@ -315,7 +315,7 @@ val wrong = t.sum(Axis[C]) // I found: // // dimwit.tensor.ShapeTypeHelpers.AxisRemover.bridge[ -// (MdocApp0.this.A, MdocApp0.this.B), MdocApp0.this.C, Tuple]( +// (MdocApp0.this.A, MdocApp0.this.B), MdocApp0.this.C, R]( // dimwit.tensor.ShapeTypeHelpers.AxisIndex.tail[MdocApp0.this.A, // MdocApp0.this.B *: EmptyTuple.type, MdocApp0.this.C]( // dimwit.tensor.ShapeTypeHelpers.AxisIndex.tail[MdocApp0.this.B, @@ -436,7 +436,7 @@ val wrong = m1.dot(Axis[B])(m2) // I found: // // dimwit.tensor.ShapeTypeHelpers.AxisRemover.bridge[ -// (MdocApp1.this.C, MdocApp1.this.D), MdocApp1.this.B, Tuple]( +// (MdocApp1.this.C, MdocApp1.this.D), MdocApp1.this.B, R]( // dimwit.tensor.ShapeTypeHelpers.AxisIndex.tail[MdocApp1.this.C, // MdocApp1.this.D *: EmptyTuple.type, MdocApp1.this.B]( // dimwit.tensor.ShapeTypeHelpers.AxisIndex.tail[MdocApp1.this.D, @@ -455,7 +455,7 @@ val wrong = m1.dot(Axis[B])(m2) // I found: // // 
dimwit.tensor.ShapeTypeHelpers.AxisRemover.bridge[ -// (MdocApp1.this.C, MdocApp1.this.D), MdocApp1.this.B, Tuple]( +// (MdocApp1.this.C, MdocApp1.this.D), MdocApp1.this.B, R]( // dimwit.tensor.ShapeTypeHelpers.AxisIndex.tail[MdocApp1.this.C, // MdocApp1.this.D *: EmptyTuple.type, MdocApp1.this.B]( // dimwit.tensor.ShapeTypeHelpers.AxisIndex.tail[MdocApp1.this.D, @@ -1128,7 +1128,7 @@ val wrong = m1.dot(Axis[B])(m2) // Axis[B] not in m2 // I found: // // dimwit.tensor.ShapeTypeHelpers.AxisRemover.bridge[ -// (MdocApp11.this.C, MdocApp11.this.D), MdocApp11.this.B, Tuple]( +// (MdocApp11.this.C, MdocApp11.this.D), MdocApp11.this.B, R]( // dimwit.tensor.ShapeTypeHelpers.AxisIndex.tail[MdocApp11.this.C, // MdocApp11.this.D *: EmptyTuple.type, MdocApp11.this.B]( // dimwit.tensor.ShapeTypeHelpers.AxisIndex.tail[MdocApp11.this.D, @@ -1204,7 +1204,7 @@ val wrong = t.sum(Axis[C]) // Axis[C] not in tensor // I found: // // dimwit.tensor.ShapeTypeHelpers.AxisRemover.bridge[ -// (MdocApp11.this.A, MdocApp11.this.B), MdocApp11.this.C, Tuple]( +// (MdocApp11.this.A, MdocApp11.this.B), MdocApp11.this.C, R]( // dimwit.tensor.ShapeTypeHelpers.AxisIndex.tail[MdocApp11.this.A, // MdocApp11.this.B *: EmptyTuple.type, MdocApp11.this.C]( // dimwit.tensor.ShapeTypeHelpers.AxisIndex.tail[MdocApp11.this.B, diff --git a/core/src/main/scala/dimwit/tensor/ShapeTypeHelpers.scala b/core/src/main/scala/dimwit/tensor/ShapeTypeHelpers.scala index 27abb2c..9093a66 100644 --- a/core/src/main/scala/dimwit/tensor/ShapeTypeHelpers.scala +++ b/core/src/main/scala/dimwit/tensor/ShapeTypeHelpers.scala @@ -9,23 +9,28 @@ object ShapeTypeHelpers: import TupleHelpers.* + /** Wraps each element of a tuple in an Axis */ type WrapAxes[T <: Tuple] <: Tuple = T match case EmptyTuple => EmptyTuple case a *: tail => Axis[a] *: WrapAxes[tail] + /** Unwraps each Axis in a tuple to get the label types */ type UnwrapAxes[T <: Tuple] <: Tuple = T match case EmptyTuple => EmptyTuple case Axis[a] *: tail => a *: 
UnwrapAxes[tail] case h *: tail => h *: UnwrapAxes[tail] + /** Unwrap each AxisExtent in a tuple to get the label types */ type UnwrapDims[T <: Tuple] <: Tuple = T match case EmptyTuple => EmptyTuple case AxisExtent[a] *: tail => a *: UnwrapDims[tail] + /** Base trait for tracking an axis in a tensor shape */ @implicitNotFound("Axis[${Axis}] not found in Tensor[${TensorShape}]") trait AxisInTensor[TensorShape <: Tuple, Axis]: def index: Int + /** Finds the index of an axis in a tensor shape */ trait AxisIndex[Shape <: Tuple, Axis] extends AxisInTensor[Shape, Axis] object AxisIndex: @@ -49,16 +54,27 @@ object ShapeTypeHelpers: given concatEnd[A <: Tuple, L]: AxisIndex[Tuple.Concat[A, Tuple1[L]], L] with val index = -1 - trait AxisRemover[TensorShape <: Tuple, Axis, RemainingShape <: Tuple] extends AxisInTensor[TensorShape, Axis] + /** Removing an axis from a tensor shape. + * + * RemainingAxes is the resulting shape after removing the axis. + */ + trait AxisRemover[TensorShape <: Tuple, Axis] extends AxisInTensor[TensorShape, Axis]: + type RemainingAxes <: Tuple object AxisRemover: + type Aux[S <: Tuple, A, R <: Tuple] = AxisRemover[S, A] { type RemainingAxes = R } + given bridge[S <: Tuple, A, R <: Tuple](using axisIndex: AxisIndex[S, A], ev: RemoverAll.Aux[S, A *: EmptyTuple, R] - ): AxisRemover[S, A, R] with + ): AxisRemover.Aux[S, A, R] = new AxisRemover[S, A]: + type RemainingAxes = R def index: Int = axisIndex.index - // Replace single axis with single axis + /** Replaces an axis in a tensor shape with another axis. + * + * NewShape is the resulting shape after replacement. 
+ */ trait AxisReplacer[TensorShape <: Tuple, Axis, AxisReplacement] extends AxisInTensor[TensorShape, Axis]: type NewShape <: Tuple @@ -72,7 +88,7 @@ object ShapeTypeHelpers: def index: Int = idx.index type NewShape = O - // Replace single axis with multiple axes + /** Replace Axis in given Tuple with the Axes in the AxisReplacements tuple */ trait AxisReplacerAll[TensorShape <: Tuple, Axis, AxisReplacements <: Tuple] extends AxisInTensor[TensorShape, Axis]: type NewShape <: Tuple @@ -105,10 +121,12 @@ object ShapeTypeHelpers: type NewShape = O def index: Int = s.index + /** Base trait for tracking multiple axes in a tensor shape */ @implicitNotFound("Axes [${Axes}] not all found in Tensor shape [${TensorShape}]") trait AxesInTensor[TensorShape <: Tuple, Axes <: Tuple]: def indices: List[Int] + /** Finds the indices of multiple axes in a tensor shape */ sealed trait AxisIndices[T <: Tuple, Axes <: Tuple] extends AxesInTensor[T, Axes] object AxisIndices: @@ -124,44 +142,65 @@ object ShapeTypeHelpers: inline given [T <: Tuple, ToFind <: Tuple]: AxisIndices[T, ToFind] = AxisIndicesImpl[T, ToFind](indicesOfList[T, ToFind]) end AxisIndices - trait AxesRemover[TensorShape <: Tuple, Axes <: Tuple, RemainingShape <: Tuple] extends AxesInTensor[TensorShape, Axes] + + /** Removes multiple axes from a tensor shape. 
*/ + trait AxesRemover[TensorShape <: Tuple, Axes <: Tuple] extends AxesInTensor[TensorShape, Axes]: + type RemainingAxes <: Tuple object AxesRemover: + type Aux[T <: Tuple, Axes <: Tuple, R <: Tuple] = AxesRemover[T, Axes] { type RemainingAxes = R } + given bridge[T <: Tuple, Axes <: Tuple, R <: Tuple](using idx: AxisIndices[T, Axes], ev: RemoverAll.Aux[T, Axes, R] - ): AxesRemover[T, Axes, R] with + ): AxesRemover.Aux[T, Axes, R] = new AxesRemover[T, Axes]: + type RemainingAxes = R def indices: List[Int] = idx.indices - trait AxesConditionalRemover[TensorShape <: Tuple, RemovedAxis <: Tuple, IndexAxes <: Tuple, RemainingShape <: Tuple] extends AxesInTensor[TensorShape, IndexAxes] + /** Removes [[RemovedAxis]] from a tensor shape while computing runtime indices + * for [[IndexAxes]]. + */ + trait AxesConditionalRemover[TensorShape <: Tuple, RemovedAxis <: Tuple, IndexAxes <: Tuple] extends AxesInTensor[TensorShape, IndexAxes]: + type RemainingAxes <: Tuple object AxesConditionalRemover: + type Aux[T <: Tuple, RA <: Tuple, IA <: Tuple, R <: Tuple] = AxesConditionalRemover[T, RA, IA] { type RemainingAxes = R } + given bridge[T <: Tuple, RemovedAxis <: Tuple, IndexAxes <: Tuple, R <: Tuple](using idx: AxisIndices[T, IndexAxes], ev: RemoverAll.Aux[T, RemovedAxis, R] - ): AxesConditionalRemover[T, RemovedAxis, IndexAxes, R] with + ): AxesConditionalRemover.Aux[T, RemovedAxis, IndexAxes, R] = new AxesConditionalRemover[T, RemovedAxis, IndexAxes]: + type RemainingAxes = R def indices = idx.indices + /** Removes a shared axis from multiple tensor shapes while computing runtime indices + * for the remaining axes. 
+ */ @implicitNotFound("Axis[${Axis}] not found in ${Shapes}}") - trait SharedAxisRemover[Shapes <: Tuple, Axis, Sliced <: Tuple]: + trait SharedAxisRemover[Shapes <: Tuple, Axis]: + type RemainingAxes <: Tuple def indices: List[Int] def shapesLabels: List[List[String]] object SharedAxisRemover: + type Aux[S <: Tuple, A, O <: Tuple] = SharedAxisRemover[S, A] { type RemainingAxes = O } - given empty[Axis]: SharedAxisRemover[EmptyTuple, Axis, EmptyTuple] with + given empty[Axis]: SharedAxisRemover.Aux[EmptyTuple, Axis, EmptyTuple] = new SharedAxisRemover[EmptyTuple, Axis]: + type RemainingAxes = EmptyTuple def indices = Nil def shapesLabels = Nil - type Sliced = EmptyTuple given cons[H <: Tuple, T <: Tuple, Axis, R <: Tuple, TailOut <: Tuple](using - evH: AxisRemover[H, Axis, R], - evT: SharedAxisRemover[T, Axis, TailOut], + evH: AxisRemover.Aux[H, Axis, R], + evT: SharedAxisRemover.Aux[T, Axis, TailOut], rLabels: Labels[R] - ): SharedAxisRemover[H *: T, Axis, R *: TailOut] with + ): SharedAxisRemover.Aux[H *: T, Axis, R *: TailOut] = new SharedAxisRemover[H *: T, Axis]: + type RemainingAxes = R *: TailOut def indices = evH.index :: evT.indices def shapesLabels = List(rLabels.names) ++ evT.shapesLabels + /** Extracts the dimensions of a tensor shape into a Map of label names to sizes. + */ trait DimExtractor[T]: def extract(t: T): Map[String, Int] @@ -181,6 +220,10 @@ object ShapeTypeHelpers: def extract(t: AxisExtent[L]) = Map(label.name -> t.size) + /** Merges multiple axes in a tensor shape into a single axis. + * + * NewShape is the resulting shape after merging the axes. + */ @implicitNotFound("Cannot merge axes ${ToMerge} in shape ${S}. 
Ensure all axes exist.") trait AxesMerger[S <: Tuple, ToMerge <: Tuple]: type NewShape <: Tuple diff --git a/core/src/main/scala/dimwit/tensor/Tensor.scala b/core/src/main/scala/dimwit/tensor/Tensor.scala index 58e9118..ca9dfa0 100644 --- a/core/src/main/scala/dimwit/tensor/Tensor.scala +++ b/core/src/main/scala/dimwit/tensor/Tensor.scala @@ -21,19 +21,34 @@ import me.shadaj.scalapy.readwrite.Writer.stringWriter.given import dimwit.tensor.TensorOps.{IsBoolean, IsInteger, IsFloating} import DType.* +/** A tensor with a fixed shape and data type. + * + * @tparam T The shape of the tensor, represented as a tuple of axis labels. + * @tparam V The data type of the tensor elements. + */ class Tensor[T <: Tuple: Labels, V] private[dimwit] ( private[dimwit] val jaxValue: Jax.PyDynamic ): + /** The labels of the tensor's axes. */ lazy val axes: List[String] = shape.labels - lazy val dtype: DType = JaxDType.fromJaxDtype(jaxValue.dtype) + + /** The shape of the tensor. */ lazy val shape: Shape[T] = Shape.fromSeq[T](jaxValue.shape.as[Seq[Int]]) + + /** The data type of the underlying Jax tensor. */ + lazy val dtype: DType = JaxDType.fromJaxDtype(jaxValue.dtype) + + /** The value type of the tensor (static type information) */ lazy val vtype: VType[V] = VType(this) + /** The device on which the tensor is stored. */ lazy val device: Device = Device(jaxValue.device) + /** Converts the tensor to a different value type, if compatible. */ def asType[V2](vtype: VType[V2]): Tensor[T, V2] = new Tensor(Jax.jnp.astype(jaxValue, JaxDType.jaxDtype(vtype.dtype))) + /** Moves the tensor to a different device. */ def toDevice(newDevice: Device): Tensor[T, V] = new Tensor(jaxValue = Jax.device_put(jaxValue, newDevice.toJaxDevice)) override def equals(other: Any): Boolean = @@ -51,6 +66,7 @@ class Tensor[T <: Tuple: Labels, V] private[dimwit] ( s"TracerTensor(${shape.toString})" case _ => jaxValue.toString() + /** Returns the [[AxisExtent]] of the specified axis in the tensor's shape.
*/ def extent[L: Label](axis: Axis[L])(using ev: AxisIndex[T, L]): AxisExtent[L] = shape.extent(axis) @@ -63,25 +79,36 @@ object Tensor: type IndicesOf[T <: Tuple] = Tuple.Map[T, [_] =>> Int] + /** Factory for creating tensors with a specific shape and default value type. + * + * @param shape The shape of the tensor to create. + */ case class DefaultsFactory[T <: Tuple: Labels](shape: Shape[T]): + /** Creates a tensor filled with the specified value. */ def fill(value: Float): Tensor[T, Float32] = Tensor(shape, VType[Float32]).fill(value) def fill(value: Double): Tensor[T, Float64] = Tensor(shape, VType[Float64]).fill(value) - def fromArray(values: Array[Float]): Tensor[T, Float32] = Tensor(shape, VType[Float32]).fromArray(values) - def fromArray(values: Array[Double]): Tensor[T, Float64] = Tensor(shape, VType[Float64]).fromArray(values) - def fill(value: Byte): Tensor[T, Int8] = Tensor(shape, VType[Int8]).fill(value) def fill(value: Short): Tensor[T, Int16] = Tensor(shape, VType[Int16]).fill(value) def fill(value: Int): Tensor[T, Int32] = Tensor(shape, VType[Int32]).fill(value) def fill(value: Long): Tensor[T, Int64] = Tensor(shape, VType[Int64]).fill(value) + def fill(value: Boolean): Tensor[T, Bool] = Tensor(shape, VType[Bool]).fill(value) + + /** Creates a tensor from an array of values. + * The array must have the same number of elements + * as the product of the dimensions in the shape.
+ */ + def fromArray(values: Array[Float]): Tensor[T, Float32] = Tensor(shape, VType[Float32]).fromArray(values) + def fromArray(values: Array[Double]): Tensor[T, Float64] = Tensor(shape, VType[Float64]).fromArray(values) def fromArray(values: Array[Byte]): Tensor[T, Int8] = Tensor(shape, VType[Int8]).fromArray(values) def fromArray(values: Array[Short]): Tensor[T, Int16] = Tensor(shape, VType[Int16]).fromArray(values) def fromArray(values: Array[Int]): Tensor[T, Int32] = Tensor(shape, VType[Int32]).fromArray(values) def fromArray(values: Array[Long]): Tensor[T, Int64] = Tensor(shape, VType[Int64]).fromArray(values) - - def fill(value: Boolean): Tensor[T, Bool] = Tensor(shape, VType[Bool]).fill(value) def fromArray(values: Array[Boolean]): Tensor[T, Bool] = Tensor(shape, VType[Bool]).fromArray(values) + /** Creates a tensor by computing each element + * using the provided function. + */ @targetName("fromFunctionFloat") def fromFunction(f: TypedIndex[T] => Float): Tensor[T, Float32] = Tensor(shape, VType[Float32]).fromArray(Tensor.tabulate(shape.dimensions, f)) @@ -104,44 +131,52 @@ object Tensor: def fromFunction(f: TypedIndex[T] => Boolean): Tensor[T, Bool] = Tensor(shape, VType[Bool]).fromArray(Tensor.tabulate(shape.dimensions, f)) + /** Factory for creating tensors with a specific shape and a given value type + * + * @param shape The shape of the tensor to create + * @param vtype The value type of the tensor to create + */ case class TypedFactory[T <: Tuple: Labels, V](shape: Shape[T], vtype: VType[V]): - // --- Boolean --- + /** @see [[DefaultsFactory.fill]] */ def fill(value: Boolean)(using IsBoolean[V]): Tensor[T, V] = Tensor(Jax.jnp.full(shape.dimensions.toPythonProxy, value, dtype = vtype.dtype.jaxType)) - def fromArray(values: Array[Boolean])(using IsBoolean[V]): Tensor[T, V] = ArrayWriter.fromArray[T, V](shape, values) - - // --- Integer --- def fill(value: Byte)(using IsInteger[V]): Tensor[T, V] = Tensor(Jax.jnp.full(shape.dimensions.toPythonProxy, 
value, dtype = vtype.dtype.jaxType)) def fill(value: Short)(using IsInteger[V]): Tensor[T, V] = Tensor(Jax.jnp.full(shape.dimensions.toPythonProxy, value.toInt, dtype = vtype.dtype.jaxType)) def fill(value: Int)(using IsInteger[V]): Tensor[T, V] = Tensor(Jax.jnp.full(shape.dimensions.toPythonProxy, value, dtype = vtype.dtype.jaxType)) def fill(value: Long)(using IsInteger[V]): Tensor[T, V] = Tensor(Jax.jnp.full(shape.dimensions.toPythonProxy, value, dtype = vtype.dtype.jaxType)) + def fill(value: Float)(using IsFloating[V]): Tensor[T, V] = Tensor(Jax.jnp.full(shape.dimensions.toPythonProxy, value, dtype = vtype.dtype.jaxType)) + def fill(value: Double)(using IsFloating[V]): Tensor[T, V] = Tensor(Jax.jnp.full(shape.dimensions.toPythonProxy, value, dtype = vtype.dtype.jaxType)) + + /** @see [[DefaultsFactory.fromArray]] */ + def fromArray(values: Array[Boolean])(using IsBoolean[V]): Tensor[T, V] = ArrayWriter.fromArray[T, V](shape, values) def fromArray(values: Array[Byte])(using IsInteger[V]): Tensor[T, V] = ArrayWriter.fromArray[T, V](shape, values) def fromArray(values: Array[Short])(using IsInteger[V]): Tensor[T, V] = ArrayWriter.fromArray[T, V](shape, values) def fromArray(values: Array[Int])(using IsInteger[V]): Tensor[T, V] = ArrayWriter.fromArray[T, V](shape, values) def fromArray(values: Array[Long])(using IsInteger[V]): Tensor[T, V] = ArrayWriter.fromArray[T, V](shape, values) - - // --- Floating --- - def fill(value: Float)(using IsFloating[V]): Tensor[T, V] = Tensor(Jax.jnp.full(shape.dimensions.toPythonProxy, value, dtype = vtype.dtype.jaxType)) - def fill(value: Double)(using IsFloating[V]): Tensor[T, V] = Tensor(Jax.jnp.full(shape.dimensions.toPythonProxy, value, dtype = vtype.dtype.jaxType)) def fromArray(values: Array[Float])(using IsFloating[V]): Tensor[T, V] = ArrayWriter.fromArray[T, V](shape, values) def fromArray(values: Array[Double])(using IsFloating[V]): Tensor[T, V] = ArrayWriter.fromArray[T, V](shape, values) + /** Factory for creating 
tensors with the same shape and value type as another tensor. + * + * @param other The tensor to use as a template for the new tensor. + */ case class LikeFactory[T <: Tuple: Labels, V](val other: Tensor[T, V]): + /** @see [[DefaultsFactory.fill]] */ def fill(value: Boolean): Tensor[T, V] = Tensor(Jax.jnp.full(other.shape.dimensions.toPythonProxy, value, dtype = other.dtype.jaxType)) - def fromArray(values: Array[Boolean])(using IsBoolean[V]): Tensor[T, V] = ArrayWriter.fromArray[T, V](other.shape, values) - def fill(value: Byte): Tensor[T, V] = Tensor(Jax.jnp.full(other.shape.dimensions.toPythonProxy, value, dtype = other.dtype.jaxType)) def fill(value: Short): Tensor[T, V] = Tensor(Jax.jnp.full(other.shape.dimensions.toPythonProxy, value.toInt, dtype = other.dtype.jaxType)) def fill(value: Int): Tensor[T, V] = Tensor(Jax.jnp.full(other.shape.dimensions.toPythonProxy, value, dtype = other.dtype.jaxType)) def fill(value: Long): Tensor[T, V] = Tensor(Jax.jnp.full(other.shape.dimensions.toPythonProxy, value, dtype = other.dtype.jaxType)) + def fill(value: Float): Tensor[T, V] = Tensor(Jax.jnp.full(other.shape.dimensions.toPythonProxy, value, dtype = other.dtype.jaxType)) + def fill(value: Double): Tensor[T, V] = Tensor(Jax.jnp.full(other.shape.dimensions.toPythonProxy, value, dtype = other.dtype.jaxType)) + + /** @see [[DefaultsFactory.fromArray]] */ + def fromArray(values: Array[Boolean])(using IsBoolean[V]): Tensor[T, V] = ArrayWriter.fromArray[T, V](other.shape, values) def fromArray(values: Array[Byte])(using IsInteger[V]): Tensor[T, V] = ArrayWriter.fromArray[T, V](other.shape, values) def fromArray(values: Array[Short])(using IsInteger[V]): Tensor[T, V] = ArrayWriter.fromArray[T, V](other.shape, values) def fromArray(values: Array[Int])(using IsInteger[V]): Tensor[T, V] = ArrayWriter.fromArray[T, V](other.shape, values) def fromArray(values: Array[Long])(using IsInteger[V]): Tensor[T, V] = ArrayWriter.fromArray[T, V](other.shape, values) - - def fill(value: 
Float): Tensor[T, V] = Tensor(Jax.jnp.full(other.shape.dimensions.toPythonProxy, value, dtype = other.dtype.jaxType)) - def fill(value: Double): Tensor[T, V] = Tensor(Jax.jnp.full(other.shape.dimensions.toPythonProxy, value, dtype = other.dtype.jaxType)) def fromArray(values: Array[Float])(using IsFloating[V]): Tensor[T, V] = ArrayWriter.fromArray[T, V](other.shape, values) def fromArray(values: Array[Double])(using IsFloating[V]): Tensor[T, V] = ArrayWriter.fromArray[T, V](other.shape, values) @@ -154,16 +189,25 @@ object Tensor: private[dimwit] def apply[T <: Tuple: Labels, V](jaxValue: Jax.PyDynamic): Tensor[T, V] = new Tensor(jaxValue) + /** Use the [[DefaultsFactory]] to create a tensor */ def apply[T <: Tuple: Labels](shape: Shape[T]): DefaultsFactory[T] = DefaultsFactory(shape) + + /** Use the [[TypedFactory]] to create a tensor */ def apply[T <: Tuple: Labels, V](shape: Shape[T], vtype: VType[V]): TypedFactory[T, V] = TypedFactory(shape, vtype) + + /** Use the [[LikeFactory]] to create a tensor */ def like[T <: Tuple: Labels, V](template: Tensor[T, V]): Tensor.LikeFactory[T, V] = Tensor.LikeFactory(template) +/** Type aliases for tensors of different ranks. */ type Tensor0[V] = Tensor[EmptyTuple, V] type Tensor1[L, V] = Tensor[Tuple1[L], V] type Tensor2[L1, L2, V] = Tensor[(L1, L2), V] type Tensor3[L1, L2, L3, V] = Tensor[(L1, L2, L3), V] type Tensor4[L1, L2, L3, L4, V] = Tensor[(L1, L2, L3, L4), V] +/** Companion object for Tensors of rank 0 (scalars). + * Provides factory methods for creating tensors of rank 0 with various value types.
+ */ object Tensor0: given boolean2BooleanTensor[V: IsBoolean]: Conversion[Boolean, Tensor0[V]] with @@ -193,27 +237,20 @@ object Tensor0: object DefaultsFactory: def apply(value: Boolean): Tensor0[Bool] = Tensor0(VType[Bool])(value) - def apply(value: Byte): Tensor0[Int8] = Tensor0(VType[Int8])(value) def apply(value: Short): Tensor0[Int16] = Tensor0(VType[Int16])(value) def apply(value: Int): Tensor0[Int32] = Tensor0(VType[Int32])(value) def apply(value: Long): Tensor0[Int64] = Tensor0(VType[Int64])(value) - def apply(value: Float): Tensor0[Float32] = Tensor0(VType[Float32])(value) def apply(value: Double): Tensor0[Float64] = Tensor0(VType[Float64])(value) case class TypedFactory[V](vtype: VType[V]): - // --- Boolean --- def apply(value: Boolean)(using IsBoolean[V]): Tensor0[V] = Tensor(Jax.jnp.array(value, dtype = vtype.dtype.jaxType)) - - // --- Integer --- def apply(value: Byte)(using IsInteger[V]): Tensor0[V] = Tensor(Jax.jnp.array(value, dtype = vtype.dtype.jaxType)) def apply(value: Short)(using IsInteger[V]): Tensor0[V] = Tensor(Jax.jnp.array(value.toInt, dtype = vtype.dtype.jaxType)) def apply(value: Int)(using IsInteger[V]): Tensor0[V] = Tensor(Jax.jnp.array(value, dtype = vtype.dtype.jaxType)) def apply(value: Long)(using IsInteger[V]): Tensor0[V] = Tensor(Jax.jnp.array(value, dtype = vtype.dtype.jaxType)) - - // --- Floating --- def apply(value: Float)(using IsFloating[V]): Tensor0[V] = Tensor(Jax.jnp.array(value, dtype = vtype.dtype.jaxType)) def apply(value: Double)(using IsFloating[V]): Tensor0[V] = Tensor(Jax.jnp.array(value, dtype = vtype.dtype.jaxType)) @@ -225,41 +262,37 @@ object Tensor0: def apply[V](jaxValue: Jax.PyDynamic): Tensor0[V] = Tensor(jaxValue) +/** Companion object for Tensors of rank 1 (vectors). + * Provides factory methods for creating tensors of rank 1 with various value types. 
+ */ object Tensor1: case class DefaultsFactory[L: Label](axis: Axis[L]): - // --- Boolean --- def fromArray(values: Array[Boolean]): Tensor1[L, Bool] = Tensor1(axis, VType[Bool]).fromArray(values) - - // --- Integer --- def fromArray(values: Array[Byte]): Tensor1[L, Int8] = Tensor1(axis, VType[Int8]).fromArray(values) def fromArray(values: Array[Short]): Tensor1[L, Int16] = Tensor1(axis, VType[Int16]).fromArray(values) def fromArray(values: Array[Int]): Tensor1[L, Int32] = Tensor1(axis, VType[Int32]).fromArray(values) def fromArray(values: Array[Long]): Tensor1[L, Int64] = Tensor1(axis, VType[Int64]).fromArray(values) - - // --- Floating --- def fromArray(values: Array[Float]): Tensor1[L, Float32] = Tensor1(axis, VType[Float32]).fromArray(values) def fromArray(values: Array[Double]): Tensor1[L, Float64] = Tensor1(axis, VType[Float64]).fromArray(values) case class TypedFactory[L: Label, V](axis: Axis[L], vtype: VType[V]): - // --- Boolean --- def fromArray(values: Array[Boolean])(using IsBoolean[V]): Tensor1[L, V] = ArrayWriter.fromArray[Tuple1[L], V](Shape1(axis -> values.length), values) - - // --- Integer --- def fromArray(values: Array[Byte])(using IsInteger[V]): Tensor1[L, V] = ArrayWriter.fromArray[Tuple1[L], V](Shape1(axis -> values.length), values) def fromArray(values: Array[Short])(using IsInteger[V]): Tensor1[L, V] = ArrayWriter.fromArray[Tuple1[L], V](Shape1(axis -> values.length), values) def fromArray(values: Array[Int])(using IsInteger[V]): Tensor1[L, V] = ArrayWriter.fromArray[Tuple1[L], V](Shape1(axis -> values.length), values) def fromArray(values: Array[Long])(using IsInteger[V]): Tensor1[L, V] = ArrayWriter.fromArray[Tuple1[L], V](Shape1(axis -> values.length), values) - - // --- Floating --- def fromArray(values: Array[Float])(using IsFloating[V]): Tensor1[L, V] = ArrayWriter.fromArray[Tuple1[L], V](Shape1(axis -> values.length), values) def fromArray(values: Array[Double])(using IsFloating[V]): Tensor1[L, V] = ArrayWriter.fromArray[Tuple1[L], 
V](Shape1(axis -> values.length), values) def apply[L: Label](axis: Axis[L]): DefaultsFactory[L] = DefaultsFactory(axis) def apply[L: Label, V](axis: Axis[L], vtype: VType[V]): TypedFactory[L, V] = TypedFactory(axis, vtype) +/** Companion object for Tensors of rank 2 (matrices). + * Provides factory methods for creating tensors of rank 2 with various value types. + */ object Tensor2: type Array2D[V] = Array[Array[V]] @@ -267,12 +300,10 @@ object Tensor2: case class DefaultsFactory[L1: Label, L2: Label](axis1: Axis[L1], axis2: Axis[L2]): def fromArray(values: Array2D[Boolean]): Tensor2[L1, L2, Bool] = Tensor2(axis1, axis2, VType[Bool]).fromArray(values) - def fromArray(values: Array2D[Byte]): Tensor2[L1, L2, Int8] = Tensor2(axis1, axis2, VType[Int8]).fromArray(values) def fromArray(values: Array2D[Short]): Tensor2[L1, L2, Int16] = Tensor2(axis1, axis2, VType[Int16]).fromArray(values) def fromArray(values: Array2D[Int]): Tensor2[L1, L2, Int32] = Tensor2(axis1, axis2, VType[Int32]).fromArray(values) def fromArray(values: Array2D[Long]): Tensor2[L1, L2, Int64] = Tensor2(axis1, axis2, VType[Int64]).fromArray(values) - def fromArray(values: Array2D[Float]): Tensor2[L1, L2, Float32] = Tensor2(axis1, axis2, VType[Float32]).fromArray(values) def fromArray(values: Array2D[Double]): Tensor2[L1, L2, Float64] = Tensor2(axis1, axis2, VType[Float64]).fromArray(values) @@ -280,16 +311,11 @@ object Tensor2: private def createShape[V](values: Array2D[V]): Shape2[L1, L2] = Shape2(AxisExtent(axis1, values.length), AxisExtent(axis2, values.head.length)) - // --- Boolean --- def fromArray(values: Array2D[Boolean])(using IsBoolean[V]): Tensor2[L1, L2, V] = Tensor(createShape(values), VType[V]).fromArray(values.flatten) - - // --- Integer --- def fromArray(values: Array2D[Byte])(using IsInteger[V]): Tensor2[L1, L2, V] = Tensor(createShape(values), VType[V]).fromArray(values.flatten) def fromArray(values: Array2D[Short])(using IsInteger[V]): Tensor2[L1, L2, V] = Tensor(createShape(values),
VType[V]).fromArray(values.flatten) def fromArray(values: Array2D[Int])(using IsInteger[V]): Tensor2[L1, L2, V] = Tensor(createShape(values), VType[V]).fromArray(values.flatten) def fromArray(values: Array2D[Long])(using IsInteger[V]): Tensor2[L1, L2, V] = Tensor(createShape(values), VType[V]).fromArray(values.flatten) - - // --- Floating --- def fromArray(values: Array2D[Float])(using IsFloating[V]): Tensor2[L1, L2, V] = Tensor(createShape(values), VType[V]).fromArray(values.flatten) def fromArray(values: Array2D[Double])(using IsFloating[V]): Tensor2[L1, L2, V] = Tensor(createShape(values), VType[V]).fromArray(values.flatten) @@ -301,6 +327,9 @@ object Tensor2: def eye[L: Label, V](dim: AxisExtent[L], vtype: VType[V]): Tensor2[L, Prime[L], V] = eyeImpl(dim, vtype) def diag[L: Label, V](diag: Tensor1[L, V]): Tensor2[L, Prime[L], V] = Tensor(Jax.jnp.diag(diag.jaxValue)) +/** Companion object for Tensors of rank 3. + * Provides factory methods for creating tensors of rank 3 with various value types. 
+ */ object Tensor3: type Array3D[V] = Array[Array[Array[V]]] @@ -308,29 +337,21 @@ object Tensor3: case class DefaultsFactory[L1: Label, L2: Label, L3: Label](axis1: Axis[L1], axis2: Axis[L2], axis3: Axis[L3]): def fromArray(values: Array3D[Boolean]): Tensor3[L1, L2, L3, Bool] = Tensor3(axis1, axis2, axis3, VType[Bool]).fromArray(values) - def fromArray(values: Array3D[Byte]): Tensor3[L1, L2, L3, Int8] = Tensor3(axis1, axis2, axis3, VType[Int8]).fromArray(values) def fromArray(values: Array3D[Short]): Tensor3[L1, L2, L3, Int16] = Tensor3(axis1, axis2, axis3, VType[Int16]).fromArray(values) def fromArray(values: Array3D[Int]): Tensor3[L1, L2, L3, Int32] = Tensor3(axis1, axis2, axis3, VType[Int32]).fromArray(values) def fromArray(values: Array3D[Long]): Tensor3[L1, L2, L3, Int64] = Tensor3(axis1, axis2, axis3, VType[Int64]).fromArray(values) - def fromArray(values: Array3D[Float]): Tensor3[L1, L2, L3, Float32] = Tensor3(axis1, axis2, axis3, VType[Float32]).fromArray(values) def fromArray(values: Array3D[Double]): Tensor3[L1, L2, L3, Float64] = Tensor3(axis1, axis2, axis3, VType[Float64]).fromArray(values) case class TypedFactory[L1: Label, L2: Label, L3: Label, V](axis1: Axis[L1], axis2: Axis[L2], axis3: Axis[L3], vtype: VType[V]): private def createShape[V](values: Array3D[V]): Shape3[L1, L2, L3] = Shape3(AxisExtent(axis1, values.length), AxisExtent(axis2, values.head.length), AxisExtent(axis3, values.head.head.length)) - - // --- Boolean --- def fromArray(values: Array3D[Boolean])(using IsBoolean[V]): Tensor3[L1, L2, L3, V] = Tensor(createShape(values), VType[V]).fromArray(values.flatten.flatten) - - // --- Integer --- def fromArray(values: Array3D[Byte])(using IsInteger[V]): Tensor3[L1, L2, L3, V] = Tensor(createShape(values), VType[V]).fromArray(values.flatten.flatten) def fromArray(values: Array3D[Short])(using IsInteger[V]): Tensor3[L1, L2, L3, V] = Tensor(createShape(values), VType[V]).fromArray(values.flatten.flatten) def fromArray(values: 
Array3D[Int])(using IsInteger[V]): Tensor3[L1, L2, L3, V] = Tensor(createShape(values), VType[V]).fromArray(values.flatten.flatten) def fromArray(values: Array3D[Long])(using IsInteger[V]): Tensor3[L1, L2, L3, V] = Tensor(createShape(values), VType[V]).fromArray(values.flatten.flatten) - - // --- Floating --- def fromArray(values: Array3D[Float])(using IsFloating[V]): Tensor3[L1, L2, L3, V] = Tensor(createShape(values), VType[V]).fromArray(values.flatten.flatten) def fromArray(values: Array3D[Double])(using IsFloating[V]): Tensor3[L1, L2, L3, V] = Tensor(createShape(values), VType[V]).fromArray(values.flatten.flatten) diff --git a/core/src/main/scala/dimwit/tensor/TensorOps.scala b/core/src/main/scala/dimwit/tensor/TensorOps.scala index 28ae0a8..d16d665 100644 --- a/core/src/main/scala/dimwit/tensor/TensorOps.scala +++ b/core/src/main/scala/dimwit/tensor/TensorOps.scala @@ -1,80 +1,73 @@ package dimwit.tensor -import scala.annotation.targetName -import scala.annotation.implicitNotFound -import scala.util.NotGiven - -import dimwit.jax.{Jax, Einops} +import dimwit.DType.* +import dimwit.DType.given +import dimwit.OnError +import dimwit.jax.Einops +import dimwit.jax.Jax +import dimwit.tensor.HasScalar import dimwit.tensor.{Label, Labels} -import dimwit.tensor.TupleHelpers.{Subset, StrictSubset, PrimeConcat} -import dimwit.tensor.ShapeTypeHelpers.{AxisRemover, AxesRemover, SharedAxisRemover, AxisReplacer, AxesConditionalRemover, WrapAxes, UnwrapAxes} +import dimwit.tensor.ShapeTypeHelpers.* +import dimwit.tensor.TensorOps.Functional.ZipVmap.{ShapesOf, TensorsOf} +import dimwit.tensor.TupleHelpers.* import dimwit.{~, `|*|`, `|+|`} import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters +import me.shadaj.scalapy.readwrite.{Reader, Writer} -import me.shadaj.scalapy.readwrite.Writer -import me.shadaj.scalapy.readwrite.Reader - +import scala.annotation.implicitNotFound +import scala.annotation.targetName import scala.compiletime.ops.int.<= -import 
dimwit.tensor.TupleHelpers.{ValidationResult, CanForm, IsPermutation, ComputeMissing, CheckValid, AllOk, MissingAxis} -import dimwit.tensor.ShapeTypeHelpers.UnwrapDims -import dimwit.tensor.ShapeTypeHelpers.DimExtractor -import dimwit.tensor.ShapeTypeHelpers.AxisReplacerAll -import dimwit.tensor.ShapeTypeHelpers.AxisIndex -import dimwit.tensor.ShapeTypeHelpers.AxisIndices -import dimwit.tensor.ShapeTypeHelpers.AxesMerger -import dimwit.OnError -import dimwit.DType.* -import dimwit.DType.given -import dimwit.tensor.HasScalar +import scala.util.NotGiven import Tuple.:* import Tuple.++ -import dimwit.tensor.ShapeTypeHelpers.MergeLabels -import dimwit.tensor.TensorOps.Functional.ZipVmap.ShapesOf -import dimwit.tensor.TensorOps.Functional.ZipVmap.TensorsOf object TensorOps: import TensorOpsUtil.* + /** Typeclass to map a type V to its corresponding DType. + */ sealed trait HasDType[V]: def dtype: DType + /** Typeclass to indicate that a type V is a numeric type + */ @implicitNotFound("Operation only valid for Numeric (Int or Float) tensors.") sealed trait IsNumber[V] - // ----------------------------------------------------------- - // Typeclasses to steer operation availability to prevent runtime errors - // ----------------------------------------------------------- - @implicitNotFound("Operation only valid for Int or Float tensors.") object IsNumber: given [V](using ev1: IsFloating[V]): IsNumber[V] = ev1 given [V](using ev2: IsInteger[V]): IsNumber[V] = ev2 @implicitNotFound("Operation only valid for Floating tensors.") + + /** Type class marker for floating point types (Float32, Float64, etc.). 
*/ trait IsFloating[V] extends IsNumber[V], HasDType[V]: def dtype: DType object IsFloating: def apply[V](using ev: IsFloating[V]): IsFloating[V] = ev - object IsInteger: - def apply[V](using ev: IsInteger[V]): IsInteger[V] = ev - - object IsBoolean: - def apply[V](using ev: IsBoolean[V]): IsBoolean[V] = ev - + /** Type class marker for integer types */ @implicitNotFound("Operation only valid for Integer tensors.") trait IsInteger[V] extends IsNumber[V], HasDType[V]: def dtype: DType + object IsInteger: + def apply[V](using ev: IsInteger[V]): IsInteger[V] = ev + + /** Type class marker for Boolean types */ @implicitNotFound("Operation only valid for Boolean tensors.") trait IsBoolean[V] extends HasDType[V]: def dtype: DType + object IsBoolean: + def apply[V](using ev: IsBoolean[V]): IsBoolean[V] = ev + // ----------------------------------------------------------- // 1. Elementwise Operations (The Field) // Preserves Shape: T -> T @@ -85,18 +78,23 @@ object TensorOps: // General operations // --------------------------------------------------------- + /** Elementwise maximum of two tensors. */ def maximum[T <: Tuple: Labels, V](t1: Tensor[T, V], t2: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.maximum(t1.jaxValue, t2.jaxValue)) + + /** Elementwise minimum of two tensors. 
*/ def minimum[T <: Tuple: Labels, V](t1: Tensor[T, V], t2: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.minimum(t1.jaxValue, t2.jaxValue)) extension [T <: Tuple: Labels, V](t: Tensor[T, V]) - // --- Comparison --- def <(other: Tensor[T, V]): Tensor[T, Bool] = Tensor(Jax.jnp.less(t.jaxValue, other.jaxValue)) def <=(other: Tensor[T, V]): Tensor[T, Bool] = Tensor(Jax.jnp.less_equal(t.jaxValue, other.jaxValue)) def >(other: Tensor[T, V]): Tensor[T, Bool] = Tensor(Jax.jnp.greater(t.jaxValue, other.jaxValue)) def >=(other: Tensor[T, V]): Tensor[T, Bool] = Tensor(Jax.jnp.greater_equal(t.jaxValue, other.jaxValue)) + + /** Checks full array equality, returns true if all elements are equal */ def ===(other: Tensor[T, V]): Tensor0[Bool] = Tensor0(Jax.jnp.array_equal(t.jaxValue, other.jaxValue)) + /** Elementwise equality, returns a tensor of bools indicating which elements are equal */ def elementEquals(other: Tensor[T, V]): Tensor[T, Bool] = require(t.shape.dimensions == other.shape.dimensions, s"Shape mismatch: ${t.shape.dimensions} vs ${other.shape.dimensions}") Tensor(jaxValue = Jax.jnp.equal(t.jaxValue, other.jaxValue)) @@ -200,28 +198,28 @@ object TensorOps: // --- Sum --- def sum: Tensor0[V] = Tensor0(Jax.jnp.sum(t.jaxValue)) - def sum[L: Label, R <: Tuple](axis: Axis[L])(using ev: AxisRemover[T, L, R], l: Labels[R]): Tensor[R, V] = Tensor(Jax.jnp.sum(t.jaxValue, axis = ev.index)) - def sum[Inputs <: Tuple, R <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs], R], l: Labels[R]): Tensor[R, V] = Tensor(Jax.jnp.sum(t.jaxValue, axis = ev.indices.toPythonProxy)) + def sum[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.sum(t.jaxValue, axis = ev.index)) + def sum[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.sum(t.jaxValue, axis = ev.indices.toPythonProxy)) // --- Max --- 
def max: Tensor0[V] = Tensor0(Jax.jnp.max(t.jaxValue)) - def max[L: Label, R <: Tuple](axis: Axis[L])(using ev: AxisRemover[T, L, R], l: Labels[R]): Tensor[R, V] = Tensor(Jax.jnp.max(t.jaxValue, axis = ev.index)) - def max[Inputs <: Tuple, R <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs], R], l: Labels[R]): Tensor[R, V] = Tensor(Jax.jnp.max(t.jaxValue, axis = ev.indices.toPythonProxy)) + def max[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.max(t.jaxValue, axis = ev.index)) + def max[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.max(t.jaxValue, axis = ev.indices.toPythonProxy)) // --- Min --- def min: Tensor0[V] = Tensor0(Jax.jnp.min(t.jaxValue)) - def min[L: Label, R <: Tuple](axis: Axis[L])(using ev: AxisRemover[T, L, R], l: Labels[R]): Tensor[R, V] = Tensor(Jax.jnp.min(t.jaxValue, axis = ev.index)) - def min[Inputs <: Tuple, R <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs], R], l: Labels[R]): Tensor[R, V] = Tensor(Jax.jnp.min(t.jaxValue, axis = ev.indices.toPythonProxy)) + def min[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.min(t.jaxValue, axis = ev.index)) + def min[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.min(t.jaxValue, axis = ev.indices.toPythonProxy)) // --- Argmax --- def argmax: Tensor0[Int32] = Tensor0(Jax.jnp.argmax(t.jaxValue)) - def argmax[L: Label, R <: Tuple](axis: Axis[L])(using ev: AxisRemover[T, L, R], l: Labels[R]): Tensor[R, Int32] = Tensor(Jax.jnp.argmax(t.jaxValue, axis = ev.index)) - def argmax[Inputs <: Tuple, R <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs], R], l: Labels[R]): Tensor[R, Int32] = 
Tensor(Jax.jnp.argmax(t.jaxValue, axis = ev.indices.toPythonProxy)) + def argmax[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, Int32] = Tensor(Jax.jnp.argmax(t.jaxValue, axis = ev.index)) + def argmax[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, Int32] = Tensor(Jax.jnp.argmax(t.jaxValue, axis = ev.indices.toPythonProxy)) // --- Argmin --- def argmin: Tensor0[Int32] = Tensor0(Jax.jnp.argmin(t.jaxValue)) - def argmin[L: Label, R <: Tuple](axis: Axis[L])(using ev: AxisRemover[T, L, R], l: Labels[R]): Tensor[R, Int32] = Tensor(Jax.jnp.argmin(t.jaxValue, axis = ev.index)) - def argmin[Inputs <: Tuple, R <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs], R], l: Labels[R]): Tensor[R, Int32] = Tensor(Jax.jnp.argmin(t.jaxValue, axis = ev.indices.toPythonProxy)) + def argmin[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, Int32] = Tensor(Jax.jnp.argmin(t.jaxValue, axis = ev.index)) + def argmin[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, Int32] = Tensor(Jax.jnp.argmin(t.jaxValue, axis = ev.indices.toPythonProxy)) // --- Argsort --- def argsort: Tensor[T, Int32] = Tensor(Jax.jnp.argsort(t.jaxValue)) @@ -236,23 +234,23 @@ object TensorOps: // --- Mean --- def mean: Tensor0[V] = Tensor0(Jax.jnp.mean(t.jaxValue)) - def mean[L: Label, R <: Tuple](axis: Axis[L])(using ev: AxisRemover[T, L, R], l: Labels[R]): Tensor[R, V] = Tensor(Jax.jnp.mean(t.jaxValue, axis = ev.index)) - def mean[Inputs <: Tuple, R <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs], R], l: Labels[R]): Tensor[R, V] = Tensor(Jax.jnp.mean(t.jaxValue, axis = ev.indices.toPythonProxy)) + def mean[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): 
Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.mean(t.jaxValue, axis = ev.index)) + def mean[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.mean(t.jaxValue, axis = ev.indices.toPythonProxy)) // --- Std --- def std: Tensor0[V] = Tensor0(Jax.jnp.std(t.jaxValue)) - def std[L: Label, R <: Tuple](axis: Axis[L])(using ev: AxisRemover[T, L, R], l: Labels[R]): Tensor[R, V] = Tensor(Jax.jnp.std(t.jaxValue, axis = ev.index)) - def std[Inputs <: Tuple, R <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs], R], l: Labels[R]): Tensor[R, V] = Tensor(Jax.jnp.std(t.jaxValue, axis = ev.indices.toPythonProxy)) + def std[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.std(t.jaxValue, axis = ev.index)) + def std[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.std(t.jaxValue, axis = ev.indices.toPythonProxy)) // --- Quantile --- def quantile(q: Float): Tensor0[V] = Tensor0(Jax.jnp.quantile(t.jaxValue, q)) - def quantile[L: Label, R <: Tuple](q: Float, axis: Axis[L])(using ev: AxisRemover[T, L, R], l: Labels[R]): Tensor[R, V] = Tensor(Jax.jnp.quantile(t.jaxValue, q, axis = ev.index)) - def quantile[Inputs <: Tuple, R <: Tuple](q: Float, axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs], R], l: Labels[R]): Tensor[R, V] = Tensor(Jax.jnp.quantile(t.jaxValue, q, axis = ev.indices.toPythonProxy)) + def quantile[L: Label](q: Float, axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.quantile(t.jaxValue, q, axis = ev.index)) + def quantile[Inputs <: Tuple](q: Float, axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.quantile(t.jaxValue, q, 
axis = ev.indices.toPythonProxy)) // --- Median --- def median: Tensor0[V] = Tensor0(Jax.jnp.median(t.jaxValue)) - def median[L: Label, R <: Tuple](axis: Axis[L])(using ev: AxisRemover[T, L, R], l: Labels[R]): Tensor[R, V] = Tensor(Jax.jnp.median(t.jaxValue, axis = ev.index)) - def median[Inputs <: Tuple, R <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs], R], l: Labels[R]): Tensor[R, V] = Tensor(Jax.jnp.median(t.jaxValue, axis = ev.indices.toPythonProxy)) + def median[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.median(t.jaxValue, axis = ev.index)) + def median[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.median(t.jaxValue, axis = ev.indices.toPythonProxy)) end Reduction @@ -260,50 +258,67 @@ object TensorOps: extension [T <: Tuple: Labels, V](tensor: Tensor[T, V]) - def outerProduct[OtherShape <: Tuple: Labels, Out <: Tuple](other: Tensor[OtherShape, V])(using - primeConcat: PrimeConcat.Aux[T, OtherShape, Out], - labels: Labels[Out] - ): Tensor[Out, V] = - Tensor( - // Jax outer product flattens, reshape required - Jax.jnp.reshape( - Jax.jnp.outer(tensor.jaxValue, other.jaxValue), - (tensor.shape.dimensions ++ other.shape.dimensions).toPythonProxy - ) + /** Computes the outer product of this tensor with another tensor. + * Automatically primes the labels of the resulting tensor to avoid label collisions. 
+ */ + def outerProduct[OtherShape <: Tuple: Labels](other: Tensor[OtherShape, V])(using + primeConcat: PrimeConcat[T, OtherShape], + labels: Labels[primeConcat.Out] + ): Tensor[primeConcat.Out, V] = Tensor( + // Jax outer product flattens, reshape required + Jax.jnp.reshape( + Jax.jnp.outer(tensor.jaxValue, other.jaxValue), + (tensor.shape.dimensions ++ other.shape.dimensions).toPythonProxy ) + ) + /** Computes the dot product of this tensor with another tensor along the specified axis. + * The axis must be present in both tensors and will be contracted (removed) from the resulting tensor. + * + * @param axis The axis along which to contract. Must be present in both tensors. + * @param other The other tensor to contract with. + */ def dot[ ContractAxis, - OtherShape <: Tuple, - R1 <: Tuple, - R2 <: Tuple, - Out <: Tuple + OtherShape <: Tuple ](axis: Axis[ContractAxis])(other: Tensor[OtherShape, V])(using - ev: AxisRemover[T, ContractAxis, R1], - evOther: AxisRemover[OtherShape, ContractAxis, R2], - primeConcat: PrimeConcat.Aux[R1, R2, Out], - labelsOut: Labels[Out] - ): Tensor[Out, V] = + ev: AxisRemover[T, ContractAxis], + evOther: AxisRemover[OtherShape, ContractAxis] + )(using + primeConcat: PrimeConcat[ev.RemainingAxes, evOther.RemainingAxes], + labelsOut: Labels[primeConcat.Out] + ): Tensor[primeConcat.Out, V] = val axesTuple1 = Jax.Dynamic.global.tuple(Seq(ev.index).toPythonProxy) val axesTuple2 = Jax.Dynamic.global.tuple(Seq(evOther.index).toPythonProxy) val axesPair = Jax.Dynamic.global.tuple(Seq(axesTuple1, axesTuple2).toPythonProxy) Tensor(Jax.jnp.tensordot(tensor.jaxValue, other.jaxValue, axes = axesPair)) + /** Computes the dot product of this tensor with another tensor along the specified pair of axes. + * The axes must be present in their respective tensors and will be contracted (removed) from the resulting tensor. + * + * @param axis The pair of axes along which to contract. Each axis must be present in its respective tensor. 
+ * @param other The other tensor to contract with. + * + * Example usage: + * {{{ + * val t1: Tensor[("A", "B", "C"), Float] = ??? + * val t2: Tensor[("D", "E", "F"), Float] = ??? + * val result = t1.dot(Axis["B" ~ "F"])(t2) + * }}} + */ @targetName("dotOn") def dot[ ContractAxisA, ContractAxisB, - OtherShape <: Tuple, - R1 <: Tuple, - R2 <: Tuple, - Out <: Tuple + OtherShape <: Tuple ](axis: Axis[ContractAxisA ~ ContractAxisB])(other: Tensor[OtherShape, V])(using - ev: AxisRemover[T, ContractAxisA, R1], - evOther: AxisRemover[OtherShape, ContractAxisB, R2], - primeConcat: PrimeConcat.Aux[R1, R2, Out], - outLabels: Labels[Out] - ): Tensor[Out, V] = + ev: AxisRemover[T, ContractAxisA], + evOther: AxisRemover[OtherShape, ContractAxisB] + )(using + primeConcat: PrimeConcat[ev.RemainingAxes, evOther.RemainingAxes], + outLabels: Labels[primeConcat.Out] + ): Tensor[primeConcat.Out, V] = val axesTuple1 = Jax.Dynamic.global.tuple(Seq(ev.index).toPythonProxy) val axesTuple2 = Jax.Dynamic.global.tuple(Seq(evOther.index).toPythonProxy) val axesPair = Jax.Dynamic.global.tuple(Seq(axesTuple1, axesTuple2).toPythonProxy) @@ -312,6 +327,8 @@ object TensorOps: end Contraction + /** Convolution operations. 
+ */ object Convolution: enum Padding: @@ -508,10 +525,10 @@ object TensorOps: extension [T <: Tuple: Labels, V](t: Tensor[T, V]) - def diagonal[L1: Label, L2: Label, R <: Tuple](axis1: Axis[L1], axis2: Axis[L2], offset: Int = 0)(using - ev: AxesRemover[T, (L1, L2), R], - labels: Labels[R] - ): Tensor[R *: L1 *: EmptyTuple, V] = + def diagonal[L1: Label, L2: Label](axis1: Axis[L1], axis2: Axis[L2], offset: Int = 0)(using + ev: AxesRemover[T, (L1, L2)], + labels: Labels[ev.RemainingAxes] + ): Tensor[ev.RemainingAxes *: L1 *: EmptyTuple, V] = Tensor(Jax.jnp.diagonal(t.jaxValue, offset = offset, axis1 = ev.indices(0), axis2 = ev.indices(1))) extension [L1: Label, L2: Label, V](t: Tensor2[L1, L2, V]) @@ -525,10 +542,10 @@ object TensorOps: extension [T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V]) - def trace[L1: Label, L2: Label, R <: Tuple](axis1: Axis[L1], axis2: Axis[L2], offset: Int = 0)(using - ev: AxesRemover[T, (L1, L2), R], - labels: Labels[R] - ): Tensor[R, V] = Tensor(Jax.jnp.trace(t.jaxValue, offset = offset, axis1 = ev.indices(0), axis2 = ev.indices(1))) + def trace[L1: Label, L2: Label](axis1: Axis[L1], axis2: Axis[L2], offset: Int = 0)(using + ev: AxesRemover[T, (L1, L2)], + labels: Labels[ev.RemainingAxes] + ): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.trace(t.jaxValue, offset = offset, axis1 = ev.indices(0), axis2 = ev.indices(1))) extension [L1: Label, L2: Label, V: IsNumber](t: Tensor2[L1, L2, V]) @@ -543,10 +560,10 @@ object TensorOps: def norm: Tensor0[V] = Tensor0(Jax.jnp.linalg.norm(t.jaxValue)) def inv: Tensor[T, V] = Tensor(Jax.jnp.linalg.inv(t.jaxValue)) - def det[L1: Label, L2: Label, R <: Tuple](axis1: Axis[L1], axis2: Axis[L2])(using - ev: AxesRemover[T, (L1, L2), R], - labels: Labels[R] - ): Tensor[R, V] = + def det[L1: Label, L2: Label](axis1: Axis[L1], axis2: Axis[L2])(using + ev: AxesRemover[T, (L1, L2)], + labels: Labels[ev.RemainingAxes] + ): Tensor[ev.RemainingAxes, V] = // JAX det only works on the last two axes (-2, -1). 
We must move the user's selected axes to the end. val moved = Jax.jnp.moveaxis( t.jaxValue, @@ -994,14 +1011,14 @@ object TensorOps: * @param unstackAxis the axis to split, specified as an Axis (e.g. Axis[Ax1]) * @return a sequence of tensors resulting from the split, each with the specified axis removed */ - def unstack[L: Label, R <: Tuple](unstackAxis: Axis[L])(using + def unstack[L: Label](unstackAxis: Axis[L])(using labels: Labels[T], - ev: AxisRemover[T, L, R], - labelR: Labels[R] - ): Seq[Tensor[R, V]] = + ev: AxisRemover[T, L], + labelR: Labels[ev.RemainingAxes] + ): Seq[Tensor[ev.RemainingAxes, V]] = val axisIdx = ev.index val unstacked = Jax.jnp.split(tensor.jaxValue, tensor.shape.dimensions(axisIdx), axis = axisIdx).as[Seq[Jax.PyDynamic]] - unstacked.map(x => Tensor[R, V](x)) + unstacked.map(x => Tensor[ev.RemainingAxes, V](x)) def chunk[splitL: Label](splitAxis: Axis[splitL], chunkSize: Int)(using labels: Labels[T], @@ -1010,79 +1027,79 @@ object TensorOps: val res = Jax.jnp.split(tensor.jaxValue, chunkSize, axis = axisIndex.index).as[Seq[Jax.PyDynamic]] res.map(x => Tensor[T, V](x)) - def slice[Inputs <: Tuple, LabelsToRemove <: Tuple, R <: Tuple]( + def slice[Inputs <: Tuple, LabelsToRemove <: Tuple]( inputs: Inputs )(using sliceExtractor: SliceLabelExtractor[Inputs, LabelsToRemove], - ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Inputs], R], - labels: Labels[R] - ): Tensor[R, V] = + ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Inputs]], + labels: Labels[ev.RemainingAxes] + ): Tensor[ev.RemainingAxes, V] = val pyIndices = tensor.calcPyIndices(inputs, ev.indices) Tensor(tensor.jaxValue.bracketAccess(pyIndices)) // Convenience overload for AxisAtIndex - def slice[L, LabelsToRemove <: Tuple, R <: Tuple]( + def slice[L, LabelsToRemove <: Tuple]( selector: AxisAtIndex[L] )(using sliceExtractor: SliceLabelExtractor[Tuple1[AxisAtIndex[L]], LabelsToRemove], - ev: AxesConditionalRemover[T, LabelsToRemove, 
ExtractLabels[Tuple1[AxisAtIndex[L]]], R], - labels: Labels[R] - ): Tensor[R, V] = slice(Tuple1(selector)) + ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtIndex[L]]]], + labels: Labels[ev.RemainingAxes] + ): Tensor[ev.RemainingAxes, V] = slice(Tuple1(selector)) // Convenience overload for AxisAtRange - def slice[L, LabelsToRemove <: Tuple, R <: Tuple]( + def slice[L, LabelsToRemove <: Tuple]( selector: AxisAtRange[L] )(using sliceExtractor: SliceLabelExtractor[Tuple1[AxisAtRange[L]], LabelsToRemove], - ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtRange[L]]], R], - labels: Labels[R] - ): Tensor[R, V] = slice(Tuple1(selector)) + ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtRange[L]]]], + labels: Labels[ev.RemainingAxes] + ): Tensor[ev.RemainingAxes, V] = slice(Tuple1(selector)) // Convenience overload for AxisAtIndices - def slice[L, LabelsToRemove <: Tuple, R <: Tuple]( + def slice[L, LabelsToRemove <: Tuple]( selector: AxisAtIndices[L] )(using sliceExtractor: SliceLabelExtractor[Tuple1[AxisAtIndices[L]], LabelsToRemove], - ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtIndices[L]]], R], - labels: Labels[R] - ): Tensor[R, V] = slice(Tuple1(selector)) + ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtIndices[L]]]], + labels: Labels[ev.RemainingAxes] + ): Tensor[ev.RemainingAxes, V] = slice(Tuple1(selector)) // Convenience overload for AxisAtTensorIndex - def slice[L, LabelsToRemove <: Tuple, R <: Tuple]( + def slice[L, LabelsToRemove <: Tuple]( selector: AxisAtTensorIndex[L] )(using sliceExtractor: SliceLabelExtractor[Tuple1[AxisAtTensorIndex[L]], LabelsToRemove], - ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtTensorIndex[L]]], R], - labels: Labels[R] - ): Tensor[R, V] = slice(Tuple1(selector)) + ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtTensorIndex[L]]]], + labels: 
Labels[ev.RemainingAxes] + ): Tensor[ev.RemainingAxes, V] = slice(Tuple1(selector)) // Convenience overload for AxisAtTupleIndices - def slice[L, U <: NonEmptyTuple, LabelsToRemove <: Tuple, R <: Tuple]( + def slice[L, U <: NonEmptyTuple, LabelsToRemove <: Tuple]( selector: AxisAtTupleIndices[L, U] )(using sliceExtractor: SliceLabelExtractor[Tuple1[AxisAtTupleIndices[L, U]], LabelsToRemove], - ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtTupleIndices[L, U]]], R], - labels: Labels[R] - ): Tensor[R, V] = slice(Tuple1(selector)) + ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtTupleIndices[L, U]]]], + labels: Labels[ev.RemainingAxes] + ): Tensor[ev.RemainingAxes, V] = slice(Tuple1(selector)) - def take[L1, L2: Label, R <: Tuple]( + def take[L1, L2: Label]( axis: Axis[L1] )( indices: Tensor1[L2, Int32] )(using - ev: AxisRemover[T, L1, R], - labels: Labels[R] - ): Tensor[Tuple.Concat[Tuple1[L2], R], V] = + ev: AxisRemover[T, L1], + labels: Labels[ev.RemainingAxes] + ): Tensor[Tuple.Concat[Tuple1[L2], ev.RemainingAxes], V] = val result = Jax.jnp.take(tensor.jaxValue, indices.jaxValue, axis = ev.index) Tensor(result) - def set[Inputs <: Tuple, LabelsToRemove <: Tuple, R <: Tuple]( + def set[Inputs <: Tuple, LabelsToRemove <: Tuple]( inputs: Inputs )(using sliceExtractor: SliceLabelExtractor[Inputs, LabelsToRemove], - ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Inputs], R], + ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Inputs]], labels: Labels[T] - )(value: Tensor[R, V]): Tensor[T, V] = + )(value: Tensor[ev.RemainingAxes, V]): Tensor[T, V] = val pyIndices = tensor.calcPyIndices(inputs, ev.indices) val result = tensor.jaxValue.at.bracketAccess(pyIndices).set(value.jaxValue) Tensor(result) @@ -1092,7 +1109,7 @@ object TensorOps: inputs: Inputs )(using sliceExtractor: SliceLabelExtractor[Inputs, LabelsToRemove], - ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Inputs], 
EmptyTuple], + ev: AxesConditionalRemover.Aux[T, LabelsToRemove, ExtractLabels[Inputs], EmptyTuple], labels: Labels[T] )(value: Float): Tensor[T, V] = val pyIndices = tensor.calcPyIndices(inputs, ev.indices) @@ -1100,40 +1117,40 @@ object TensorOps: Tensor(result) // Convenience overload for AxisAtIndex - def set[L, LabelsToRemove <: Tuple, R <: Tuple]( + def set[L, LabelsToRemove <: Tuple]( selector: AxisAtIndex[L] )(using sliceExtractor: SliceLabelExtractor[Tuple1[AxisAtIndex[L]], LabelsToRemove], - ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtIndex[L]]], R], + ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtIndex[L]]]], labels: Labels[T] - )(value: Tensor[R, V]): Tensor[T, V] = set(Tuple1(selector))(value) + )(value: Tensor[ev.RemainingAxes, V]): Tensor[T, V] = set(Tuple1(selector))(value) // Convenience overload for AxisAtRange - def set[L, LabelsToRemove <: Tuple, R <: Tuple]( + def set[L, LabelsToRemove <: Tuple]( selector: AxisAtRange[L] )(using sliceExtractor: SliceLabelExtractor[Tuple1[AxisAtRange[L]], LabelsToRemove], - ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtRange[L]]], R], + ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtRange[L]]]], labels: Labels[T] - )(value: Tensor[R, V]): Tensor[T, V] = set(Tuple1(selector))(value) + )(value: Tensor[ev.RemainingAxes, V]): Tensor[T, V] = set(Tuple1(selector))(value) // Convenience overload for AxisAtIndices - def set[L, LabelsToRemove <: Tuple, R <: Tuple]( + def set[L, LabelsToRemove <: Tuple]( selector: AxisAtIndices[L] )(using sliceExtractor: SliceLabelExtractor[Tuple1[AxisAtIndices[L]], LabelsToRemove], - ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtIndices[L]]], R], + ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtIndices[L]]]], labels: Labels[T] - )(value: Tensor[R, V]): Tensor[T, V] = set(Tuple1(selector))(value) + )(value: 
Tensor[ev.RemainingAxes, V]): Tensor[T, V] = set(Tuple1(selector))(value) // Convenience overload for AxisAtTensorIndex - def set[L, LabelsToRemove <: Tuple, R <: Tuple]( + def set[L, LabelsToRemove <: Tuple]( selector: AxisAtTensorIndex[L] )(using sliceExtractor: SliceLabelExtractor[Tuple1[AxisAtTensorIndex[L]], LabelsToRemove], - ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtTensorIndex[L]]], R], + ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Tuple1[AxisAtTensorIndex[L]]]], labels: Labels[T] - )(value: Tensor[R, V]): Tensor[T, V] = set(Tuple1(selector))(value) + )(value: Tensor[ev.RemainingAxes, V]): Tensor[T, V] = set(Tuple1(selector))(value) def rearrange[Axes <: Tuple, Status <: ValidationResult](newOrder: Axes)(using Labels[UnwrapAxes[Axes]] @@ -1298,10 +1315,10 @@ object TensorOps: val newShape = 1 +: tensor.shape.dimensions Tensor(Jax.jnp.reshape(tensor.jaxValue, newShape.toPythonProxy)) - def squeeze[L: Label, R <: Tuple](axis: Axis[L])(using - ev: AxisRemover[T, L, R], - labels: Labels[R] - ): Tensor[R, V] = + def squeeze[L: Label](axis: Axis[L])(using + ev: AxisRemover[T, L], + labels: Labels[ev.RemainingAxes] + ): Tensor[ev.RemainingAxes, V] = require( tensor.shape.dimensions(ev.index) == 1, s"Cannot squeeze axis ${summon[Label[L]].name} of size ${tensor.shape.dimensions(ev.index)}" @@ -1335,14 +1352,14 @@ object TensorOps: type ShapesOf[Tensors <: Tuple] = Tuple.Map[Tensors, ExtractShape] type ValuesOf[Tensors <: Tuple] = Tuple.Map[Tensors, ExtractValue] - def zipvmap[L: Label, Inputs <: Tuple, OutShape <: Tuple: Labels, R <: Tuple, OutV]( + def zipvmap[L: Label, Inputs <: Tuple, OutShape <: Tuple: Labels, OutV]( axis: Axis[L] )( tensors: Inputs // This is a Tuple of Tensors )(using - ev: SharedAxisRemover[ShapesOf[Inputs], L, R] + ev: SharedAxisRemover[ShapesOf[Inputs], L] )( - f: TensorsOf[R, ValuesOf[Inputs]] => Tensor[OutShape, OutV] + f: TensorsOf[ev.RemainingAxes, ValuesOf[Inputs]] => Tensor[OutShape, 
OutV] ): Tensor[L *: OutShape, OutV] = val fpy = (args: py.Dynamic) => OnError.traceStack: @@ -1350,7 +1367,7 @@ object TensorOps: Tensor(jaxArr)(using LabelsImpl(labels)) val inputTuple = Tuple.fromArray(tensorList.toArray) - val result = f(inputTuple.asInstanceOf[TensorsOf[R, ValuesOf[Inputs]]]) + val result = f(inputTuple.asInstanceOf[TensorsOf[ev.RemainingAxes, ValuesOf[Inputs]]]) result.jaxValue val jaxInputs = py.Dynamic.global.tuple(tensors.toArray.map(_.asInstanceOf[Tensor[?, ?]].jaxValue).toPythonProxy) @@ -1366,18 +1383,18 @@ object TensorOps: extension [T <: Tuple: Labels, V](t: Tensor[T, V]) - def vmap[VmapAxis: Label, OuterShape <: Tuple: Labels, R <: Tuple, V2]( + def vmap[VmapAxis: Label, OuterShape <: Tuple: Labels, V2]( axis: Axis[VmapAxis] )(using - ev: AxisRemover[T, VmapAxis, R] + ev: AxisRemover[T, VmapAxis] )( - f: Tensor[R, V] => Tensor[OuterShape, V2] + f: Tensor[ev.RemainingAxes, V] => Tensor[OuterShape, V2] )(using - labels: Labels[R] + labels: Labels[ev.RemainingAxes] ): Tensor[VmapAxis *: OuterShape, V2] = val fpy = (jxpr: Jax.PyDynamic) => OnError.traceStack: - val innerTensor = Tensor[R, V](jxpr) + val innerTensor = Tensor[ev.RemainingAxes, V](jxpr) val result = f(innerTensor) result.jaxValue @@ -1405,23 +1422,23 @@ object TensorOps: ) ) - def zipvmap[L: Label, T2 <: Tuple, R <: Tuple, OutShape <: Tuple: Labels, OutV](axis: Axis[L])( + def zipvmap[L: Label, T2 <: Tuple, OutShape <: Tuple: Labels, OutV](axis: Axis[L])( other: Tensor[T2, V] )(using - ev: SharedAxisRemover[(T, T2), L, R] + ev: SharedAxisRemover[(T, T2), L] )( - f: TensorsOf[R, (V, V)] => Tensor[OutShape, OutV] + f: TensorsOf[ev.RemainingAxes, (V, V)] => Tensor[OutShape, OutV] ): Tensor[L *: OutShape, OutV] = ZipVmap.zipvmap(axis)(t, other)(f) - def vreduce[L: Label, R <: Tuple]( + def vreduce[L: Label]( axis: Axis[L] )( f: Tensor[Tuple1[L], V] => Tensor0[V] )(using - ev: AxisRemover[T, L, R], - labels: Labels[R] - ): Tensor[R, V] = + ev: AxisRemover[T, L], + labels: 
Labels[ev.RemainingAxes] + ): Tensor[ev.RemainingAxes, V] = val fpy = (jxpr: Jax.PyDynamic) => OnError.traceStack: val inputTensor = Tensor[Tuple1[L], V](jxpr)