Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -1166,9 +1166,9 @@ val wrong = t1 +! t2
// Cannot broadcast tensors of shapes Tuple1[MdocApp11.this.A] and Tuple1[MdocApp11.this.A]. If same shape no broadcasting allowed!.
// I found:
//
// dimwit.tensor.TensorOpsUtil.Broadcast.broadcastLeft[Tuple1[MdocApp11.this.A],
// Tuple1[MdocApp11.this.A], (dimwit.tensor.DType.Float32 : dimwit.tensor.DType)]
// (
// dimwit.tensor.tensorops.TensorOpsUtil.Broadcast.broadcastLeft[
// Tuple1[MdocApp11.this.A], Tuple1[MdocApp11.this.A],
// (dimwit.tensor.DType.Float32 : dimwit.tensor.DType)](
// dimwit.tensor.Labels.concat[MdocApp11.this.A, EmptyTuple.type](
// this.A.derived$Label, dimwit.tensor.Labels.namesOfEmpty),
// dimwit.tensor.Labels.concat[MdocApp11.this.A, EmptyTuple.type](
Expand Down
1 change: 1 addition & 0 deletions core/src/main/scala/dimwit/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ package object dimwit:

// Export operations
export dimwit.tensor.TensorOps.*
export dimwit.tensor.ValueOps.*

// Export devices
export dimwit.hardware.Device
Expand Down
1,615 changes: 22 additions & 1,593 deletions core/src/main/scala/dimwit/tensor/TensorOps.scala

Large diffs are not rendered by default.

37 changes: 37 additions & 0 deletions core/src/main/scala/dimwit/tensor/ValueOps.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package dimwit.tensor

import dimwit.tensor.TensorOps.IsNumber
import dimwit.tensor.TensorOps.IsFloating
import dimwit.tensor.tensorops.ElementWiseOps.add
import dimwit.tensor.tensorops.ElementWiseOps.subtract
import dimwit.tensor.tensorops.ElementWiseOps.multiply
import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast
import dimwit.tensor.tensorops.ElementWiseOps.divide

object ValueOps:

extension [V: IsNumber](t: Tensor0[V])

def +(t2: Tensor0[V]): Tensor0[V] = TensorOps.add(t, t2)
def -(t2: Tensor0[V]): Tensor0[V] = TensorOps.subtract(t, t2)
def *(t2: Tensor0[V]): Tensor0[V] = TensorOps.multiply(t, t2)

extension [V: IsFloating](t: Tensor0[V])

def /(scalar: Tensor0[V]): Tensor0[V] = TensorOps.divide(t, scalar)

extension (scalar: Float)

def +[V: IsNumber](t: Tensor0[V]): Tensor0[V] = add(Tensor0.likeDType(t)(scalar), t)
def +![T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V])(using bc: Broadcast[EmptyTuple, T, V]): Tensor[bc.Out, V] = bc.applyTo(Tensor0.likeDType(t)(scalar), t)(add)

def -[V: IsNumber](t: Tensor0[V]): Tensor0[V] = subtract(Tensor0.likeDType(t)(scalar), t)
def -![T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V])(using bc: Broadcast[EmptyTuple, T, V]): Tensor[bc.Out, V] = bc.applyTo(Tensor0.likeDType(t)(scalar), t)(subtract)

def *[V: IsNumber](t: Tensor0[V]): Tensor0[V] = multiply(Tensor0.likeDType(t)(scalar), t)
def *![T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V])(using bc: Broadcast[EmptyTuple, T, V]): Tensor[bc.Out, V] = bc.applyTo(Tensor0.likeDType(t)(scalar), t)(multiply)

extension (scalar: Float)

def /[V: IsFloating](t: Tensor0[V]): Tensor0[V] = divide(Tensor0.likeDType(t)(scalar), t)
def /![T <: Tuple: Labels, V: IsFloating](t: Tensor[T, V])(using bc: Broadcast[EmptyTuple, T, V]): Tensor[bc.Out, V] = bc.applyTo(Tensor0.likeDType(t)(scalar), t)(divide)
98 changes: 98 additions & 0 deletions core/src/main/scala/dimwit/tensor/tensorops/ContractionOps.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package dimwit.tensor.tensorops

import dimwit.tensor.Tensor
import dimwit.tensor.Labels
import dimwit.jax.Jax
import dimwit.tensor.DType.Bool
import dimwit.tensor.Tensor0
import dimwit.tensor.TensorOps.IsBoolean
import dimwit.tensor.VType
import dimwit.tensor.DType.Int32
import dimwit.tensor.DType.Float32
import dimwit.tensor.TensorOps.IsInteger
import dimwit.tensor.TensorOps.IsFloating
import dimwit.tensor.TensorOps.IsNumber
import dimwit.tensor.Label
import dimwit.tensor.ShapeTypeHelpers.AxisRemover
import dimwit.tensor.ShapeTypeHelpers.AxesRemover
import dimwit.tensor.Axis
import dimwit.tensor.ShapeTypeHelpers.UnwrapAxes
import dimwit.tensor.ShapeTypeHelpers.AxisIndex
import dimwit.tensor.ShapeTypeHelpers.AxisIndices

import me.shadaj.scalapy.py
import me.shadaj.scalapy.py.SeqConverters
import me.shadaj.scalapy.readwrite.{Reader, Writer}
import dimwit.tensor.TupleHelpers.PrimeConcat
import scala.annotation.targetName

object ContractionOps:

extension [T <: Tuple: Labels, V](tensor: Tensor[T, V])

/** Computes the outer product of this tensor with another tensor.
* Automatically primes the labels of the resulting tensor to avoid label collisions.
*/
def outerProduct[OtherShape <: Tuple: Labels](other: Tensor[OtherShape, V])(using
primeConcat: PrimeConcat[T, OtherShape],
labels: Labels[primeConcat.Out]
): Tensor[primeConcat.Out, V] = Tensor(
// Jax outer product flattens, reshape required
Jax.jnp.reshape(
Jax.jnp.outer(tensor.jaxValue, other.jaxValue),
(tensor.shape.dimensions ++ other.shape.dimensions).toPythonProxy
)
)

/** Computes the dot product of this tensor with another tensor along the specified axis.
* The axis must be present in both tensors and will be contracted (removed) from the resulting tensor.
*
* @param axis The axis along which to contract. Must be present in both tensors.
* @param other The other tensor to contract with.
*/
def dot[
ContractAxis,
OtherShape <: Tuple
](axis: Axis[ContractAxis])(other: Tensor[OtherShape, V])(using
ev: AxisRemover[T, ContractAxis],
evOther: AxisRemover[OtherShape, ContractAxis]
)(using
primeConcat: PrimeConcat[ev.RemainingAxes, evOther.RemainingAxes],
labelsOut: Labels[primeConcat.Out]
): Tensor[primeConcat.Out, V] =
val axesTuple1 = Jax.Dynamic.global.tuple(Seq(ev.index).toPythonProxy)
val axesTuple2 = Jax.Dynamic.global.tuple(Seq(evOther.index).toPythonProxy)
val axesPair = Jax.Dynamic.global.tuple(Seq(axesTuple1, axesTuple2).toPythonProxy)

Tensor(Jax.jnp.tensordot(tensor.jaxValue, other.jaxValue, axes = axesPair))

/** Computes the dot product of this tensor with another tensor along the specified pair of axes.
* The axes must be present in their respective tensors and will be contracted (removed) from the resulting tensor.
*
* @param axis The pair of axes along which to contract. Each axis must be present in its respective tensor.
* @param other The other tensor to contract with.
*
* Example usage:
* {{{
* val t1: Tensor[("A", "B", "C"), Float] = ???
* val t2: Tensor[("D", "E, "F), Float] = ???
* val result = t1.dot(Axis["B" ~ "F])(t2)
* }}}
*/
@targetName("dotOn")
def dot[
ContractAxisA,
ContractAxisB,
OtherShape <: Tuple
](axisPair: (Axis[ContractAxisA], Axis[ContractAxisB]))(other: Tensor[OtherShape, V])(using
ev: AxisRemover[T, ContractAxisA],
evOther: AxisRemover[OtherShape, ContractAxisB]
)(using
primeConcat: PrimeConcat[ev.RemainingAxes, evOther.RemainingAxes],
outLabels: Labels[primeConcat.Out]
): Tensor[primeConcat.Out, V] =
val axesTuple1 = Jax.Dynamic.global.tuple(Seq(ev.index).toPythonProxy)
val axesTuple2 = Jax.Dynamic.global.tuple(Seq(evOther.index).toPythonProxy)
val axesPair = Jax.Dynamic.global.tuple(Seq(axesTuple1, axesTuple2).toPythonProxy)

Tensor(Jax.jnp.tensordot(tensor.jaxValue, other.jaxValue, axes = axesPair))
214 changes: 214 additions & 0 deletions core/src/main/scala/dimwit/tensor/tensorops/ConvolutionOps.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
package dimwit.tensor.tensorops

import dimwit.tensor.Tensor
import dimwit.tensor.Labels
import dimwit.jax.Jax
import dimwit.tensor.DType.Bool
import dimwit.tensor.Tensor0
import dimwit.tensor.TensorOps.IsBoolean
import dimwit.tensor.VType
import dimwit.tensor.DType.Int32
import dimwit.tensor.DType.Float32
import dimwit.tensor.TensorOps.IsInteger
import dimwit.tensor.TensorOps.IsFloating
import dimwit.tensor.TensorOps.IsNumber
import dimwit.tensor.Label
import dimwit.tensor.ShapeTypeHelpers.AxisRemover
import dimwit.tensor.ShapeTypeHelpers.AxesRemover
import dimwit.tensor.Axis
import dimwit.tensor.ShapeTypeHelpers.UnwrapAxes
import dimwit.tensor.ShapeTypeHelpers.AxisIndex
import dimwit.tensor.ShapeTypeHelpers.AxisIndices

import me.shadaj.scalapy.py
import me.shadaj.scalapy.py.SeqConverters
import me.shadaj.scalapy.readwrite.{Reader, Writer}
import dimwit.tensor.AxisExtent
import dimwit.tensor.TensorOps.swap

object ConvolutionOps:

enum Padding:
case SAME, VALID

type Stride1[S1] = AxisExtent[S1]

extension [S1: Label, InChannel: Label, V: IsFloating](input: Tensor[S1 *: InChannel *: EmptyTuple, V])

def conv1d[OutChannel: Label](
kernel: Tensor[S1 *: InChannel *: OutChannel *: EmptyTuple, V],
stride: Stride1[S1] | Int = 1,
padding: Padding = Padding.SAME
): Tensor[S1 *: OutChannel *: EmptyTuple, V] =
require(
input.shape(Axis[InChannel]) == kernel.shape(Axis[InChannel]),
s"Input channels mismatch: input has ${input.shape(Axis[InChannel])} channels, kernel expects ${kernel.shape(Axis[InChannel])} channels"
)
val strides = stride match
case s: Int => Seq(s)
case ae: AxisExtent[S1] => Seq(ae.size)
// JAX requires input and kernel to have same rank, so we must add (and remove) dummy dim to input.
val batchInput = Jax.jnp.expand_dims(input.jaxValue, axis = 0) // add dummy dim
val convResult = Jax.lax.conv_general_dilated(
lhs = batchInput,
rhs = kernel.jaxValue,
window_strides = strides.toPythonProxy,
padding = padding.toString,
dimension_numbers = py.Dynamic.global.tuple(Seq("NHC", "HIO", "NHC").toPythonProxy)
)
val unbatchedRes = Jax.jnp.squeeze(convResult, axis = 0) // remove dummy dim
Tensor(unbatchedRes)

extension [S1: Label, OutChannel: Label, V: IsFloating](input: Tensor[S1 *: OutChannel *: EmptyTuple, V])

def transposeConv1d[InChannel: Label](
kernel: Tensor[S1 *: InChannel *: OutChannel *: EmptyTuple, V],
stride: Stride1[S1] | Int = 1,
padding: Padding = Padding.SAME
): Tensor[S1 *: InChannel *: EmptyTuple, V] =
require(
input.shape(Axis[OutChannel]) == kernel.shape(Axis[OutChannel]),
s"Input channels mismatch: input has ${input.shape(Axis[OutChannel])} channels (OutChannel), kernel expects ${kernel.shape(Axis[OutChannel])}"
)
val strides = stride match
case s: Int => Seq(s)
case ex: AxisExtent[S1] => Seq(ex.size)

// kernel -> kernal adjoint: swap in/out channels and flip spatial dims
var kernelAdjoint = kernel.swap(Axis[InChannel], Axis[OutChannel]).jaxValue
kernelAdjoint = Jax.jnp.flip(kernelAdjoint, axis = 0) // flip S1

val batchInput = Jax.jnp.expand_dims(input.jaxValue, axis = 0) // add dummy dim
val convResult = Jax.lax.conv_transpose(
lhs = batchInput,
rhs = kernelAdjoint,
strides = strides.toPythonProxy,
padding = padding.toString,
dimension_numbers = py.Dynamic.global.tuple(Seq("NHC", "HIO", "NHC").toPythonProxy)
)
val unbatchedRes = Jax.jnp.squeeze(convResult, axis = 0) // remove dummy dim
Tensor(unbatchedRes)

type Stride2[S1, S2] = (AxisExtent[S1], AxisExtent[S2])

extension [S1: Label, S2: Label, InChannel: Label, V: IsFloating](input: Tensor[S1 *: S2 *: InChannel *: EmptyTuple, V])

def conv2d[OutChannel: Label](
kernel: Tensor[S1 *: S2 *: InChannel *: OutChannel *: EmptyTuple, V],
stride: Stride2[S1, S2] | Int = 1,
padding: Padding = Padding.SAME
): Tensor[S1 *: S2 *: OutChannel *: EmptyTuple, V] =
require(
input.shape(Axis[InChannel]) == kernel.shape(Axis[InChannel]),
s"Input channels mismatch: input has ${input.shape(Axis[InChannel])} channels, kernel expects ${kernel.shape(Axis[InChannel])} channels"
)
val strides = stride match
case s: Int => Seq(s, s)
case (ae1, ae2) => Seq(ae1.size, ae2.size)
// JAX requires input and kernel to have same rank, so we must add (and remove) dummy dim to input.
val batchInput = Jax.jnp.expand_dims(input.jaxValue, axis = 0) // add dummy dim
val convResult = Jax.lax.conv_general_dilated(
lhs = batchInput,
rhs = kernel.jaxValue,
window_strides = strides.toPythonProxy,
padding = padding.toString,
dimension_numbers = py.Dynamic.global.tuple(Seq("NHWC", "HWIO", "NHWC").toPythonProxy)
)
val unbatchedRes = Jax.jnp.squeeze(convResult, axis = 0) // remove dummy dim
Tensor(unbatchedRes)

extension [S1: Label, S2: Label, OutChannel: Label, V: IsFloating](input: Tensor[S1 *: S2 *: OutChannel *: EmptyTuple, V])

def transposeConv2d[InChannel: Label](
kernel: Tensor[S1 *: S2 *: InChannel *: OutChannel *: EmptyTuple, V],
stride: Stride2[S1, S2] | Int = 1,
padding: Padding = Padding.SAME
): Tensor[S1 *: S2 *: InChannel *: EmptyTuple, V] =
require(
input.shape(Axis[OutChannel]) == kernel.shape(Axis[OutChannel]),
s"Input channels mismatch: input has ${input.shape(Axis[OutChannel])} channels (OutChannel), kernel expects ${kernel.shape(Axis[OutChannel])}"
)

// JAX requires input and kernel to have same rank. Add dummy batch dim if needed.
val strides = stride match
case s: Int => Seq(s, s)
case (ae1, ae2) => Seq(ae1.size, ae2.size)

// kernel -> kernal adjoint: swap in/out channels and flip spatial dims
var kernelAdjoint = kernel.swap(Axis[InChannel], Axis[OutChannel]).jaxValue
kernelAdjoint = Jax.jnp.flip(kernelAdjoint, axis = 0) // flip S1
kernelAdjoint = Jax.jnp.flip(kernelAdjoint, axis = 1) // flip S2

val batchInput = Jax.jnp.expand_dims(input.jaxValue, axis = 0) // add dummy dim
val convResult = Jax.lax.conv_transpose(
lhs = batchInput,
rhs = kernelAdjoint,
strides = strides.toPythonProxy,
padding = padding.toString,
dimension_numbers = py.Dynamic.global.tuple(Seq("NHWC", "HWIO", "NHWC").toPythonProxy)
)
val unbatchedRes = Jax.jnp.squeeze(convResult, axis = 0) // remove dummy dim
Tensor(unbatchedRes)

type Stride3[S1, S2, S3] = (AxisExtent[S1], AxisExtent[S2], AxisExtent[S3])

extension [S1: Label, S2: Label, S3: Label, InChannel: Label, V: IsFloating](input: Tensor[S1 *: S2 *: S3 *: InChannel *: EmptyTuple, V])

def conv3d[OutChannel: Label](
kernel: Tensor[S1 *: S2 *: S3 *: InChannel *: OutChannel *: EmptyTuple, V],
stride: Stride3[S1, S2, S3] | Int = 1,
padding: Padding = Padding.SAME
): Tensor[S1 *: S2 *: S3 *: OutChannel *: EmptyTuple, V] =
require(
input.shape(Axis[InChannel]) == kernel.shape(Axis[InChannel]),
s"Input channels mismatch: input has ${input.shape(Axis[InChannel])} channels, kernel expects ${kernel.shape(Axis[InChannel])} channels"
)
val strides = stride match
case s: Int => Seq(s, s, s)
case (dim1, dim2, dim3) => Seq(dim1.size, dim2.size, dim3.size)

// JAX requires input and kernel to have same rank, so we must add (and remove) dummy dim to input.
// 3D Layout: NDHWC (Batch, Depth, Height, Width, Channel)
val batchInput = Jax.jnp.expand_dims(input.jaxValue, axis = 0) // add dummy dim
val convResult = Jax.lax.conv_general_dilated(
lhs = batchInput,
rhs = kernel.jaxValue,
window_strides = strides.toPythonProxy,
padding = padding.toString,
dimension_numbers = py.Dynamic.global.tuple(Seq("NDHWC", "DHWIO", "NDHWC").toPythonProxy)
)
val unbatchedRes = Jax.jnp.squeeze(convResult, axis = 0) // remove dummy dim
Tensor(unbatchedRes)

extension [S1: Label, S2: Label, S3: Label, OutChannel: Label, V: IsFloating](input: Tensor[S1 *: S2 *: S3 *: OutChannel *: EmptyTuple, V])

def transposeConv3d[InChannel: Label](
kernel: Tensor[S1 *: S2 *: S3 *: InChannel *: OutChannel *: EmptyTuple, V],
stride: Stride3[S1, S2, S3] | Int = 1,
padding: Padding = Padding.SAME
): Tensor[S1 *: S2 *: S3 *: InChannel *: EmptyTuple, V] =
require(
input.shape(Axis[OutChannel]) == kernel.shape(Axis[OutChannel]),
s"Input channels mismatch: input has ${input.shape(Axis[OutChannel])} channels (OutChannel), kernel expects ${kernel.shape(Axis[OutChannel])}"
)

val strides = stride match
case s: Int => Seq(s, s, s)
case (ae1, ae2, ae3) => Seq(ae1.size, ae2.size, ae3.size)

// kernel -> kernel adjoint: swap in/out channels and flip all spatial dims
var kernelAdjoint = kernel.swap(Axis[InChannel], Axis[OutChannel]).jaxValue
kernelAdjoint = Jax.jnp.flip(kernelAdjoint, axis = 0) // flip S1 (Depth)
kernelAdjoint = Jax.jnp.flip(kernelAdjoint, axis = 1) // flip S2 (Height)
kernelAdjoint = Jax.jnp.flip(kernelAdjoint, axis = 2) // flip S3 (Width)

val batchInput = Jax.jnp.expand_dims(input.jaxValue, axis = 0) // add dummy dim
val convResult = Jax.lax.conv_transpose(
lhs = batchInput,
rhs = kernelAdjoint,
strides = strides.toPythonProxy,
padding = padding.toString,
dimension_numbers = py.Dynamic.global.tuple(Seq("NDHWC", "DHWIO", "NDHWC").toPythonProxy)
)
val unbatchedRes = Jax.jnp.squeeze(convResult, axis = 0) // remove dummy dim
Tensor(unbatchedRes)
Loading
Loading