Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions core/src/main/scala/dimwit/autodiff/TensorTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ package dimwit.autodiff

import dimwit.jax.Jax
import dimwit.tensor.*
import scala.compiletime.*
import scala.deriving.*

import me.shadaj.scalapy.py
import me.shadaj.scalapy.py.SeqConverters

import scala.compiletime.*
import scala.deriving.*

/** A typeclass for structures that can be represented as a tree of tensors,
* which can be mapped over. Most often, a tensor tree is used to structure
* parameters of a model.
Expand Down
3 changes: 2 additions & 1 deletion core/src/main/scala/dimwit/tensor/Tensor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ class Tensor[T <: Tuple: Labels, V] private[dimwit] (
/** The device on which the tensor is stored. */
lazy val device: Device = Device(jaxValue.device)

/** Converts the tensor to a different value type, if compatible. */
/** Converts the tensor to the given vtype.
*/
def asType[V2](vtype: VType[V2]): Tensor[T, V2] = new Tensor(Jax.jnp.astype(jaxValue, JaxDType.jaxDtype(vtype.dtype)))

/** Moves the tensor to a different device. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ import me.shadaj.scalapy.readwrite.Writer

import scala.annotation.targetName

/** Provides extension methods for tensor contraction operations,
* including outer products and dot products.
*/
object ContractionOps:

extension [T <: Tuple: Labels, V](tensor: Tensor[T, V])
Expand Down Expand Up @@ -62,7 +65,7 @@ object ContractionOps:
* {{{
* val t1: Tensor[("A", "B", "C"), Float] = ???
* val t2: Tensor[("D", "E", "F"), Float] = ???
* val result = t1.dot(Axis["B" ~ "F])(t2)
* val result = t1.dot(Axis[A]->Axis[D])(t2)
* }}}
*/
@targetName("dotOn")
Expand Down
55 changes: 54 additions & 1 deletion core/src/main/scala/dimwit/tensor/tensorops/ConvolutionOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,35 @@ import me.shadaj.scalapy.py
import me.shadaj.scalapy.py.SeqConverters
import me.shadaj.scalapy.readwrite.Writer

/** Provides extension methods for convolution operations on tensors.
* Convolution operations are restricted to 1D, 2D and 3D convolutions,
* and support both standard and transposed convolutions.
*/
object ConvolutionOps:

/** Padding options for convolution operations.
* SAME: Output size is the same as input size (with appropriate padding).
* VALID: No padding, output size is reduced based on kernel size.
*
* Refer to JAX documentation for more details on padding behavior.
* https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_general_dilated.html
*/
enum Padding:
case SAME, VALID

/** Stride type for 1D convolutions: an alias for `AxisExtent[S1]`.
* Accepted (as an alternative to a plain `Int`) by the `stride` parameter
* of `conv1d` and `transposeConv1d`.
*/
type Stride1[S1] = AxisExtent[S1]

extension [S1: Label, InChannel: Label, V: IsFloating](input: Tensor[S1 *: InChannel *: EmptyTuple, V])

/** Computes the 1D convolution of this tensor with the specified kernel tensor.
*
* @param kernel - The convolution kernel
* @param stride - Stride for the convolution.
* @param padding - Padding mode for the convolution.
* @return A new tensor representing the result of the convolution operation.
*/
def conv1d[OutChannel: Label](
kernel: Tensor[S1 *: InChannel *: OutChannel *: EmptyTuple, V],
kernel: Tensor[(S1, InChannel, OutChannel), V],
stride: Stride1[S1] | Int = 1,
padding: Padding = Padding.SAME
): Tensor[S1 *: OutChannel *: EmptyTuple, V] =
Expand All @@ -48,6 +66,13 @@ object ConvolutionOps:

extension [S1: Label, OutChannel: Label, V: IsFloating](input: Tensor[S1 *: OutChannel *: EmptyTuple, V])

/** Computes the transposed 1D convolution of
* this tensor with the specified kernel tensor.
* @param kernel - The convolution kernel
* @param stride - Stride for the convolution.
* @param padding - Padding mode for the convolution.
* @return A new tensor representing the result of the transposed convolution operation.
*/
def transposeConv1d[InChannel: Label](
kernel: Tensor[S1 *: InChannel *: OutChannel *: EmptyTuple, V],
stride: Stride1[S1] | Int = 1,
Expand Down Expand Up @@ -80,6 +105,13 @@ object ConvolutionOps:

extension [S1: Label, S2: Label, InChannel: Label, V: IsFloating](input: Tensor[S1 *: S2 *: InChannel *: EmptyTuple, V])

/** Computes the 2D convolution of this tensor with the specified kernel tensor.
*
* @param kernel - The convolution kernel tensor with shape (S1, S2, InChannel, OutChannel).
* @param stride - Stride for the convolution.
* @param padding - Padding mode for the convolution.
* @return A new tensor representing the result of the convolution operation.
*/
def conv2d[OutChannel: Label](
kernel: Tensor[S1 *: S2 *: InChannel *: OutChannel *: EmptyTuple, V],
stride: Stride2[S1, S2] | Int = 1,
Expand All @@ -106,6 +138,13 @@ object ConvolutionOps:

extension [S1: Label, S2: Label, OutChannel: Label, V: IsFloating](input: Tensor[S1 *: S2 *: OutChannel *: EmptyTuple, V])

/** Computes the transposed 2D convolution of this tensor with the specified kernel tensor.
*
* @param kernel - The convolution kernel tensor with shape (S1, S2, InChannel, OutChannel).
* @param stride - Stride for the convolution.
* @param padding - Padding mode for the convolution.
* @return A new tensor representing the result of the transposed convolution operation.
*/
def transposeConv2d[InChannel: Label](
kernel: Tensor[S1 *: S2 *: InChannel *: OutChannel *: EmptyTuple, V],
stride: Stride2[S1, S2] | Int = 1,
Expand Down Expand Up @@ -141,6 +180,13 @@ object ConvolutionOps:

extension [S1: Label, S2: Label, S3: Label, InChannel: Label, V: IsFloating](input: Tensor[S1 *: S2 *: S3 *: InChannel *: EmptyTuple, V])

/** Computes the 3D convolution of this tensor with the specified kernel tensor.
*
* @param kernel - The convolution kernel tensor
* @param stride - Stride for the convolution.
* @param padding - Padding mode for the convolution.
* @return A new tensor representing the result of the convolution operation.
*/
def conv3d[OutChannel: Label](
kernel: Tensor[S1 *: S2 *: S3 *: InChannel *: OutChannel *: EmptyTuple, V],
stride: Stride3[S1, S2, S3] | Int = 1,
Expand Down Expand Up @@ -169,6 +215,13 @@ object ConvolutionOps:

extension [S1: Label, S2: Label, S3: Label, OutChannel: Label, V: IsFloating](input: Tensor[S1 *: S2 *: S3 *: OutChannel *: EmptyTuple, V])

/** Computes the transposed 3D convolution of this tensor with the specified kernel tensor.
*
* @param kernel - The convolution kernel tensor
* @param stride - Stride for the convolution.
* @param padding - Padding mode for the convolution.
* @return A new tensor representing the result of the transposed convolution operation.
*/
def transposeConv3d[InChannel: Label](
kernel: Tensor[S1 *: S2 *: S3 *: InChannel *: OutChannel *: EmptyTuple, V],
stride: Stride3[S1, S2, S3] | Int = 1,
Expand Down
64 changes: 53 additions & 11 deletions core/src/main/scala/dimwit/tensor/tensorops/ElementWiseOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ import dimwit.tensor.VType
import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast

object ElementWiseOps:

// ---------------------------------------------------------
// General operations
// General operations on any tensor type
// ---------------------------------------------------------

/** Elementwise maximum of two tensors. */
Expand All @@ -25,6 +26,7 @@ object ElementWiseOps:
/** Elementwise minimum of two tensors. */
def minimum[T <: Tuple: Labels, V](t1: Tensor[T, V], t2: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.minimum(t1.jaxValue, t2.jaxValue))

// extension methods for comparisons
extension [T <: Tuple: Labels, V](t: Tensor[T, V])

/** Elementwise less-than comparison (`jnp.less`) between this tensor and `other`,
* returned as a tensor of `Bool` with the same labels.
*/
def <(other: Tensor[T, V]): Tensor[T, Bool] = Tensor(Jax.jnp.less(t.jaxValue, other.jaxValue))
Expand All @@ -40,33 +42,64 @@ object ElementWiseOps:
require(t.shape.dimensions == other.shape.dimensions, s"Shape mismatch: ${t.shape.dimensions} vs ${other.shape.dimensions}")
Tensor(jaxValue = Jax.jnp.equal(t.jaxValue, other.jaxValue))

/** Casts the elements of this tensor to a tensor of type Bool. */
def asBool: Tensor[T, Bool] = t.asType(VType[Bool])

/** Casts the elements of this tensor to a tensor of the given boolean type.
* @param vtype the type to cast to
*/
def asBoolean[NewV: IsBoolean](vtype: VType[NewV]): Tensor[T, NewV] = t.asType(vtype)

/** Casts the elements of this tensor to a tensor of type Int32. */
def asInt32: Tensor[T, Int32] = t.asType(VType[Int32])

/** Casts the elements of this tensor to a tensor of the given integer type.
*
* @param vtype - the type to cast to
*/
def asInt[NewV: IsInteger](vtype: VType[NewV]): Tensor[T, NewV] = t.asType(vtype)

/** Casts the elements of this tensor to a tensor of type Float32. */
def asFloat32: Tensor[T, Float32] = t.asType(VType[Float32])
def asFloat[NewV: IsFloating](vtype: VType[NewV]): Tensor[T, NewV] = t.asType(vtype)

// ---------------------------------------------------------
// IsNumber operations (IsFloat or IsInt)
// ---------------------------------------------------------
/** Casts the elements of this tensor to a tensor of the given floating point type.
*
* @param vtype - the type to cast to
*/
def asFloat[NewV: IsFloating](vtype: VType[NewV]): Tensor[T, NewV] = t.asType(vtype)

/** Performs element-wise addition of two tensors of the same shape and type. */
def add[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], t2: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.add(t1.jaxValue, t2.jaxValue))
def addScalar[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], t2: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.add(t1.jaxValue, t2.jaxValue))

/** Adds a scalar tensor to each element of a tensor. */
def addScalar[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], s: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.add(t1.jaxValue, s.jaxValue))

/** Returns a new tensor with each element negated. */
def negate[T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.negative(t.jaxValue))

/** Subtracts one tensor from another of the same shape and type, returning a new tensor. */
def subtract[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], t2: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.subtract(t1.jaxValue, t2.jaxValue))
def subtractScalar[T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V], t2: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.subtract(t.jaxValue, t2.jaxValue))

def multiply[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], t2: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.multiply(t1.jaxValue, t2.jaxValue))
def multiplyScalar[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], t2: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.multiply(t1.jaxValue, t2.jaxValue))
/** Subtracts a scalar tensor from each element of a tensor. */
def subtractScalar[T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V], s: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.subtract(t.jaxValue, s.jaxValue))

/** Multiplies two tensors of the same shape and type element-wise, returning a new tensor. */
def multiply[T <: Tuple: Labels, V: IsNumber](
t1: Tensor[T, V],
t2: Tensor[T, V]
): Tensor[T, V] = Tensor(Jax.jnp.multiply(t1.jaxValue, t2.jaxValue))

/** Multiplies each element of a tensor by a scalar tensor, returning a new tensor. */
def multiplyScalar[T <: Tuple: Labels, V: IsNumber](t1: Tensor[T, V], s: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.multiply(t1.jaxValue, s.jaxValue))

// extension methods for the binary operations on two tensors
extension [T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V])

/** Elementwise sum of this tensor and `other`; delegates to `add`. */
def +(other: Tensor[T, V]): Tensor[T, V] = add(t, other)
/** Elementwise difference of this tensor and `other`; delegates to `subtract`. */
def -(other: Tensor[T, V]): Tensor[T, V] = subtract(t, other)
/** Elementwise product of this tensor and `other`; delegates to `multiply`. */
def *(other: Tensor[T, V]): Tensor[T, V] = multiply(t, other)

// extension methods for the scalar operations.
extension [T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V])

def +![O <: Tuple](other: Tensor[O, V])(using bc: Broadcast[T, O, V]): Tensor[bc.Out, V] = bc.applyTo(t, other)(add)
Expand All @@ -77,18 +110,24 @@ object ElementWiseOps:
def *![O <: Tuple](other: Tensor[O, V])(using bc: Broadcast[T, O, V]): Tensor[bc.Out, V] = bc.applyTo(t, other)(multiply)
/** Multiplies each element of this tensor by the scalar tensor `other`;
* delegates to `multiplyScalar`.
*/
def scale(other: Tensor0[V]): Tensor[T, V] = multiplyScalar(t, other)

// extension methods
extension [T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V])
/** Elementwise absolute value (`jnp.abs`). */
def abs: Tensor[T, V] = Tensor(Jax.jnp.abs(t.jaxValue))
/** Elementwise sign of each element (`jnp.sign`). */
def sign: Tensor[T, V] = Tensor(Jax.jnp.sign(t.jaxValue))
/** Clips each element to the range given by the scalar tensors `min` and `max` (`jnp.clip`).
*
* @param min scalar tensor holding the lower bound
* @param max scalar tensor holding the upper bound
*/
def clip(min: Tensor0[V], max: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.clip(t.jaxValue, min.jaxValue, max.jaxValue))
/** Raises each element to the power given by the scalar tensor `n` (`jnp.power`).
*
* @param n scalar tensor holding the exponent
*/
def pow(n: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.power(t.jaxValue, n.jaxValue))

// ---------------------------------------------------------
// IsFloat operations
// Operations on Floating tensors
// ---------------------------------------------------------

/** Divides two tensors of the same shape and type element-wise, returning a new tensor. */
def divide[T <: Tuple: Labels, V: IsFloating](t1: Tensor[T, V], t2: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.divide(t1.jaxValue, t2.jaxValue))

/** Divides each element of a tensor by a scalar tensor, returning a new tensor. */
def divideScalar[T <: Tuple: Labels, V: IsFloating](t1: Tensor[T, V], t2: Tensor0[V]): Tensor[T, V] = Tensor(Jax.jnp.divide(t1.jaxValue, t2.jaxValue))

// extension methods on floating tensors
extension [T <: Tuple: Labels, V: IsFloating](t: Tensor[T, V])

def /(other: Tensor[T, V]): Tensor[T, V] = divide(t, other)
Expand All @@ -115,10 +154,13 @@ object ElementWiseOps:
// ---------------------------------------------------------
// IsBoolean operations
// ---------------------------------------------------------

extension [T <: Tuple: Labels, V: IsBoolean](t: Tensor[T, V])

/** Returns true if all elements of the tensor are true, false otherwise. */
def all: Tensor0[V] = Tensor0(Jax.jnp.all(t.jaxValue))

/** Returns true if any element of the tensor is true, false otherwise. */
def any: Tensor0[V] = Tensor0(Jax.jnp.any(t.jaxValue))

/** Returns a tensor of the same shape with each element negated (logical NOT). */
def unary_! : Tensor[T, V] = Tensor(Jax.jnp.logical_not(t.jaxValue))
67 changes: 54 additions & 13 deletions core/src/main/scala/dimwit/tensor/tensorops/FunctionalOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@ import me.shadaj.scalapy.readwrite.Reader
import me.shadaj.scalapy.readwrite.Writer

object FunctionalOps:
// -----------------------------------------------------------
// 5. Functional Operations (Higher Order)
// Lifting functions over axes
// -----------------------------------------------------------

object ZipVmap:

Expand All @@ -38,6 +34,23 @@ object FunctionalOps:
/** Tuple type obtained by mapping `ExtractShape` over every element type of `Tensors`. */
type ShapesOf[Tensors <: Tuple] = Tuple.Map[Tensors, ExtractShape]
/** Tuple type obtained by mapping `ExtractValue` over every element type of `Tensors`. */
type ValuesOf[Tensors <: Tuple] = Tuple.Map[Tensors, ExtractValue]

/** Zips the given tensors along the specified axis
* and applies the function `f` to the zipped tensors.
*
* @param axis The axis along which to zip the tensors.
* @param tensors A tuple of tensors to be zipped.
* @param f A function that takes a tuple of tensors (with the specified axis removed) and returns a new tensor.
* @return A new tensor resulting from applying `f`
*
* Example usage:
* {{{
* val tensor1: Tensor[(A, B), Int] = ...
* val tensor2: Tensor[(A, B), Int] = ...
* val result: Tensor[(A, C), Int] = ZipVmap.zipvmap(Axis[A])(tensor1, tensor2) { case (t1, t2) =>
* // Perform operations on t1 and t2, which are tensors with axis A removed, and return a new tensor
* ...
* }
* }}}
*/
def zipvmap[L: Label, Inputs <: Tuple, OutShape <: Tuple: Labels, OutV](
axis: Axis[L]
)(
Expand Down Expand Up @@ -69,6 +82,29 @@ object FunctionalOps:

extension [T <: Tuple: Labels, V](t: Tensor[T, V])

/** Zips the current tensor with another tensor along the specified axis
* and applies the function `f` to the zipped tensors.
*
* @param axis The axis along which to zip the tensors.
* @param other The other tensor to be zipped with the current tensor.
* @param f A function that takes a tuple of tensors (with the specified axis removed) and returns a new tensor.
* @return A new tensor resulting from applying `f` to the zipped tensors.
*/
def zipvmap[L: Label, T2 <: Tuple, OutShape <: Tuple: Labels, OutV](axis: Axis[L])(
other: Tensor[T2, V]
)(using
ev: SharedAxisRemover[(T, T2), L]
)(
f: TensorsOf[ev.RemainingAxes, (V, V)] => Tensor[OutShape, OutV]
): Tensor[L *: OutShape, OutV] =
ZipVmap.zipvmap(axis)(t, other)(f)

/** Vectorized mapping over a specified axis of the tensor.
*
* @param axis The axis along which to apply the function `f`.
* @param f A function that takes a tensor with the specified axis removed and returns a new tensor.
* @return A new tensor resulting from applying `f` to each slice along the specified axis.
*/
def vmap[VmapAxis: Label, OuterShape <: Tuple: Labels, V2](
axis: Axis[VmapAxis]
)(using
Expand All @@ -86,6 +122,13 @@ object FunctionalOps:

Tensor(Jax.jax_helper.vmap(fpy, ev.index)(t.jaxValue))

/** Apply a function independently to each 1D slice along a labeled axis.
*
* @param axis The axis along which to apply the function `f`.
* @param f A function f that is applied to each L-axis slice; it may rename that axis to NewL and change the element type.
*
* @return A new tensor resulting from applying `f` to each slice along the specified axis.
*/
def vapply[L: Label, NewL, R <: Tuple, NewV](
axis: Axis[L]
)(
Expand All @@ -108,15 +151,13 @@ object FunctionalOps:
)
)

def zipvmap[L: Label, T2 <: Tuple, OutShape <: Tuple: Labels, OutV](axis: Axis[L])(
other: Tensor[T2, V]
)(using
ev: SharedAxisRemover[(T, T2), L]
)(
f: TensorsOf[ev.RemainingAxes, (V, V)] => Tensor[OutShape, OutV]
): Tensor[L *: OutShape, OutV] =
ZipVmap.zipvmap(axis)(t, other)(f)

/** Reduce a tensor along a labeled axis by applying a function to each 1D slice.
*
* @param axis The axis along which to reduce the tensor.
* @param f A function f that is applied to each L-axis slice; it must return a scalar (Tensor0).
*
* @return A new tensor resulting from reducing the specified axis.
*/
def vreduce[L: Label](
axis: Axis[L]
)(
Expand Down
Loading
Loading