Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ val wrong = t.sum(Axis[C])
//
// dimwit.tensor.ShapeTypeHelpers.AxisRemover.bridge[
// (repl.MdocSession.MdocApp.A, repl.MdocSession.MdocApp.B),
// repl.MdocSession.MdocApp.C, Tuple](
// repl.MdocSession.MdocApp.C, R](
// dimwit.tensor.ShapeTypeHelpers.AxisIndex.tail[repl.MdocSession.MdocApp.A,
// repl.MdocSession.MdocApp.B *: EmptyTuple.type, repl.MdocSession.MdocApp.C](
// dimwit.tensor.ShapeTypeHelpers.AxisIndex.tail[repl.MdocSession.MdocApp.B,
Expand Down Expand Up @@ -208,7 +208,7 @@ val summed = wrongAxis.sum(Axis[B]) // B not in shape!
// I found:
//
// dimwit.tensor.ShapeTypeHelpers.AxisRemover.bridge[
// Tuple1[repl.MdocSession.MdocApp.A], repl.MdocSession.MdocApp.B, Tuple](
// Tuple1[repl.MdocSession.MdocApp.A], repl.MdocSession.MdocApp.B, R](
// dimwit.tensor.ShapeTypeHelpers.AxisIndex.tail[repl.MdocSession.MdocApp.A,
// EmptyTuple.type, repl.MdocSession.MdocApp.B](
// dimwit.tensor.ShapeTypeHelpers.AxisIndex.concatRight[A², B², L]),
Expand Down Expand Up @@ -315,7 +315,7 @@ val wrong = t.sum(Axis[C])
// I found:
//
// dimwit.tensor.ShapeTypeHelpers.AxisRemover.bridge[
// (MdocApp0.this.A, MdocApp0.this.B), MdocApp0.this.C, Tuple](
// (MdocApp0.this.A, MdocApp0.this.B), MdocApp0.this.C, R](
// dimwit.tensor.ShapeTypeHelpers.AxisIndex.tail[MdocApp0.this.A,
// MdocApp0.this.B *: EmptyTuple.type, MdocApp0.this.C](
// dimwit.tensor.ShapeTypeHelpers.AxisIndex.tail[MdocApp0.this.B,
Expand Down Expand Up @@ -436,7 +436,7 @@ val wrong = m1.dot(Axis[B])(m2)
// I found:
//
// dimwit.tensor.ShapeTypeHelpers.AxisRemover.bridge[
// (MdocApp1.this.C, MdocApp1.this.D), MdocApp1.this.B, Tuple](
// (MdocApp1.this.C, MdocApp1.this.D), MdocApp1.this.B, R](
// dimwit.tensor.ShapeTypeHelpers.AxisIndex.tail[MdocApp1.this.C,
// MdocApp1.this.D *: EmptyTuple.type, MdocApp1.this.B](
// dimwit.tensor.ShapeTypeHelpers.AxisIndex.tail[MdocApp1.this.D,
Expand All @@ -455,7 +455,7 @@ val wrong = m1.dot(Axis[B])(m2)
// I found:
//
// dimwit.tensor.ShapeTypeHelpers.AxisRemover.bridge[
// (MdocApp1.this.C, MdocApp1.this.D), MdocApp1.this.B, Tuple](
// (MdocApp1.this.C, MdocApp1.this.D), MdocApp1.this.B, R](
// dimwit.tensor.ShapeTypeHelpers.AxisIndex.tail[MdocApp1.this.C,
// MdocApp1.this.D *: EmptyTuple.type, MdocApp1.this.B](
// dimwit.tensor.ShapeTypeHelpers.AxisIndex.tail[MdocApp1.this.D,
Expand Down Expand Up @@ -1128,7 +1128,7 @@ val wrong = m1.dot(Axis[B])(m2) // Axis[B] not in m2
// I found:
//
// dimwit.tensor.ShapeTypeHelpers.AxisRemover.bridge[
// (MdocApp11.this.C, MdocApp11.this.D), MdocApp11.this.B, Tuple](
// (MdocApp11.this.C, MdocApp11.this.D), MdocApp11.this.B, R](
// dimwit.tensor.ShapeTypeHelpers.AxisIndex.tail[MdocApp11.this.C,
// MdocApp11.this.D *: EmptyTuple.type, MdocApp11.this.B](
// dimwit.tensor.ShapeTypeHelpers.AxisIndex.tail[MdocApp11.this.D,
Expand Down Expand Up @@ -1204,7 +1204,7 @@ val wrong = t.sum(Axis[C]) // Axis[C] not in tensor
// I found:
//
// dimwit.tensor.ShapeTypeHelpers.AxisRemover.bridge[
// (MdocApp11.this.A, MdocApp11.this.B), MdocApp11.this.C, Tuple](
// (MdocApp11.this.A, MdocApp11.this.B), MdocApp11.this.C, R](
// dimwit.tensor.ShapeTypeHelpers.AxisIndex.tail[MdocApp11.this.A,
// MdocApp11.this.B *: EmptyTuple.type, MdocApp11.this.C](
// dimwit.tensor.ShapeTypeHelpers.AxisIndex.tail[MdocApp11.this.B,
Expand Down
71 changes: 57 additions & 14 deletions core/src/main/scala/dimwit/tensor/ShapeTypeHelpers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,28 @@ object ShapeTypeHelpers:

import TupleHelpers.*

/** Wraps each element of a tuple in an Axis */
type WrapAxes[T <: Tuple] <: Tuple = T match
case EmptyTuple => EmptyTuple
case a *: tail => Axis[a] *: WrapAxes[tail]

/** Unwraps each Axis in a tuple to get the label types */
type UnwrapAxes[T <: Tuple] <: Tuple = T match
case EmptyTuple => EmptyTuple
case Axis[a] *: tail => a *: UnwrapAxes[tail]
case h *: tail => h *: UnwrapAxes[tail]

/** Unwrap each AxisExtent in a tuple to get the label types */
type UnwrapDims[T <: Tuple] <: Tuple = T match
case EmptyTuple => EmptyTuple
case AxisExtent[a] *: tail => a *: UnwrapDims[tail]

/** Base trait for tracking an axis in a tensor shape */
@implicitNotFound("Axis[${Axis}] not found in Tensor[${TensorShape}]")
trait AxisInTensor[TensorShape <: Tuple, Axis]:
def index: Int

/** Finds the index of an axis in a tensor shape */
trait AxisIndex[Shape <: Tuple, Axis] extends AxisInTensor[Shape, Axis]

object AxisIndex:
Expand All @@ -49,16 +54,27 @@ object ShapeTypeHelpers:
given concatEnd[A <: Tuple, L]: AxisIndex[Tuple.Concat[A, Tuple1[L]], L] with
val index = -1

trait AxisRemover[TensorShape <: Tuple, Axis, RemainingShape <: Tuple] extends AxisInTensor[TensorShape, Axis]
/** Removing an axis from a tensor shape.
*
* RemainingAxes is the resulting shape after removing the axis.
*/
trait AxisRemover[TensorShape <: Tuple, Axis] extends AxisInTensor[TensorShape, Axis]:
type RemainingAxes <: Tuple

object AxisRemover:
type Aux[S <: Tuple, A, R <: Tuple] = AxisRemover[S, A] { type RemainingAxes = R }

given bridge[S <: Tuple, A, R <: Tuple](using
axisIndex: AxisIndex[S, A],
ev: RemoverAll.Aux[S, A *: EmptyTuple, R]
): AxisRemover[S, A, R] with
): AxisRemover.Aux[S, A, R] = new AxisRemover[S, A]:
type RemainingAxes = R
def index: Int = axisIndex.index

// Replace single axis with single axis
/** Replaces an axis in a tensor shape with another axis.
*
* NewShape is the resulting shape after replacement.
*/
trait AxisReplacer[TensorShape <: Tuple, Axis, AxisReplacement] extends AxisInTensor[TensorShape, Axis]:
type NewShape <: Tuple

Expand All @@ -72,7 +88,7 @@ object ShapeTypeHelpers:
def index: Int = idx.index
type NewShape = O

// Replace single axis with multiple axes
/** Replace Axis in given Tuple with the Axes in the AxisReplacements tuple */
trait AxisReplacerAll[TensorShape <: Tuple, Axis, AxisReplacements <: Tuple] extends AxisInTensor[TensorShape, Axis]:
type NewShape <: Tuple

Expand Down Expand Up @@ -105,10 +121,12 @@ object ShapeTypeHelpers:
type NewShape = O
def index: Int = s.index

/** Base trait for tracking multiple axes in a tensor shape */
@implicitNotFound("Axes [${Axes}] not all found in Tensor shape [${TensorShape}]")
trait AxesInTensor[TensorShape <: Tuple, Axes <: Tuple]:
def indices: List[Int]

/** Finds the indices of multiple axes in a tensor shape */
sealed trait AxisIndices[T <: Tuple, Axes <: Tuple] extends AxesInTensor[T, Axes]

object AxisIndices:
Expand All @@ -124,44 +142,65 @@ object ShapeTypeHelpers:
inline given [T <: Tuple, ToFind <: Tuple]: AxisIndices[T, ToFind] = AxisIndicesImpl[T, ToFind](indicesOfList[T, ToFind])

end AxisIndices
trait AxesRemover[TensorShape <: Tuple, Axes <: Tuple, RemainingShape <: Tuple] extends AxesInTensor[TensorShape, Axes]

/** Removes multiple axes from a tensor shape. */
trait AxesRemover[TensorShape <: Tuple, Axes <: Tuple] extends AxesInTensor[TensorShape, Axes]:
type RemainingAxes <: Tuple

object AxesRemover:
type Aux[T <: Tuple, Axes <: Tuple, R <: Tuple] = AxesRemover[T, Axes] { type RemainingAxes = R }

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding Aux Pattern?


given bridge[T <: Tuple, Axes <: Tuple, R <: Tuple](using
idx: AxisIndices[T, Axes],
ev: RemoverAll.Aux[T, Axes, R]
): AxesRemover[T, Axes, R] with
): AxesRemover.Aux[T, Axes, R] = new AxesRemover[T, Axes]:
type RemainingAxes = R
def indices: List[Int] = idx.indices

trait AxesConditionalRemover[TensorShape <: Tuple, RemovedAxis <: Tuple, IndexAxes <: Tuple, RemainingShape <: Tuple] extends AxesInTensor[TensorShape, IndexAxes]
/** Removes [[RemovedAxis]] from a tensor shape while computing runtime indices
* for [[IndexAxes]].
*/
trait AxesConditionalRemover[TensorShape <: Tuple, RemovedAxis <: Tuple, IndexAxes <: Tuple] extends AxesInTensor[TensorShape, IndexAxes]:
type RemainingAxes <: Tuple

object AxesConditionalRemover:
type Aux[T <: Tuple, RA <: Tuple, IA <: Tuple, R <: Tuple] = AxesConditionalRemover[T, RA, IA] { type RemainingAxes = R }

given bridge[T <: Tuple, RemovedAxis <: Tuple, IndexAxes <: Tuple, R <: Tuple](using
idx: AxisIndices[T, IndexAxes],
ev: RemoverAll.Aux[T, RemovedAxis, R]
): AxesConditionalRemover[T, RemovedAxis, IndexAxes, R] with
): AxesConditionalRemover.Aux[T, RemovedAxis, IndexAxes, R] = new AxesConditionalRemover[T, RemovedAxis, IndexAxes]:
type RemainingAxes = R
def indices = idx.indices

/** Removes a shared axis from multiple tensor shapes while computing runtime indices
* for the remaining axes.
*/
@implicitNotFound("Axis[${Axis}] not found in ${Shapes}}")
trait SharedAxisRemover[Shapes <: Tuple, Axis, Sliced <: Tuple]:
trait SharedAxisRemover[Shapes <: Tuple, Axis]:
type RemainingAxes <: Tuple
def indices: List[Int]
def shapesLabels: List[List[String]]

object SharedAxisRemover:
type Aux[S <: Tuple, A, O <: Tuple] = SharedAxisRemover[S, A] { type RemainingAxes = O }

given empty[Axis]: SharedAxisRemover[EmptyTuple, Axis, EmptyTuple] with
given empty[Axis]: SharedAxisRemover.Aux[EmptyTuple, Axis, EmptyTuple] = new SharedAxisRemover[EmptyTuple, Axis]:
type RemainingAxes = EmptyTuple
def indices = Nil
def shapesLabels = Nil
type Sliced = EmptyTuple

given cons[H <: Tuple, T <: Tuple, Axis, R <: Tuple, TailOut <: Tuple](using
evH: AxisRemover[H, Axis, R],
evT: SharedAxisRemover[T, Axis, TailOut],
evH: AxisRemover.Aux[H, Axis, R],
evT: SharedAxisRemover.Aux[T, Axis, TailOut],
rLabels: Labels[R]
): SharedAxisRemover[H *: T, Axis, R *: TailOut] with
): SharedAxisRemover.Aux[H *: T, Axis, R *: TailOut] = new SharedAxisRemover[H *: T, Axis]:
type RemainingAxes = R *: TailOut
def indices = evH.index :: evT.indices
def shapesLabels = List(rLabels.names) ++ evT.shapesLabels

/** Extracts the dimensions of a tensor shape into a Map of label names to sizes.
*/
trait DimExtractor[T]:
def extract(t: T): Map[String, Int]

Expand All @@ -181,6 +220,10 @@ object ShapeTypeHelpers:
def extract(t: AxisExtent[L]) =
Map(label.name -> t.size)

/** Merges multiple axes in a tensor shape into a single axis.
*
* NewShape is the resulting shape after merging the axes.
*/
@implicitNotFound("Cannot merge axes ${ToMerge} in shape ${S}. Ensure all axes exist.")
trait AxesMerger[S <: Tuple, ToMerge <: Tuple]:
type NewShape <: Tuple
Expand Down
Loading
Loading