Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions core/src/main/scala/dimwit/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,6 @@ package object dimwit:
System.gc()
Jax.gc()

@targetName("On")
infix trait ~[A, B]
object `~`:
given [A, B](using labelA: Label[A], labelB: Label[B]): Label[A ~ B] with
val name: String = s"${labelA.name}_on_${labelB.name}"

/** Combination of dimensions / labels
*
* Mentally think of this as the "product" of two dimensions.
Expand Down
4 changes: 2 additions & 2 deletions core/src/main/scala/dimwit/tensor/TensorOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import dimwit.tensor.{Label, Labels}
import dimwit.tensor.ShapeTypeHelpers.*
import dimwit.tensor.TensorOps.Functional.ZipVmap.{ShapesOf, TensorsOf}
import dimwit.tensor.TupleHelpers.*
import dimwit.{~, `|*|`, `|+|`}
import dimwit.{`|*|`, `|+|`}

import me.shadaj.scalapy.py
import me.shadaj.scalapy.py.SeqConverters
Expand Down Expand Up @@ -312,7 +312,7 @@ object TensorOps:
ContractAxisA,
ContractAxisB,
OtherShape <: Tuple
](axis: Axis[ContractAxisA ~ ContractAxisB])(other: Tensor[OtherShape, V])(using
](axisPair: (Axis[ContractAxisA], Axis[ContractAxisB]))(other: Tensor[OtherShape, V])(using
ev: AxisRemover[T, ContractAxisA],
evOther: AxisRemover[OtherShape, ContractAxisB]
)(using
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ class TensorOpsContractionSuite extends DimwitTest:
)

describe("dot on different axis labels (A1 ~ A2)"):
it("Tensor2[A, B] and Tensor2[C, D] using Axis mapping (A ~ C)"):
it("Tensor2[A, B] and Tensor2[C, D] using Axis mapping (Axis[A] -> Axis[C])"):
val mCD = m2.relabelAll((Axis[C], Axis[D]))

val res = m1.dot(Axis[A ~ C])(mCD)
val res = m1.dot(Axis[A] -> Axis[C])(mCD)

res.shape.labels shouldBe List("B", "D")
res should approxEqual(
Expand All @@ -57,10 +57,10 @@ class TensorOpsContractionSuite extends DimwitTest:
)
)

it("~ should respect position-aware mapping in types"):
it("Axis mapping should respect position-aware mapping in types"):
val mCD = m2.relabelAll((Axis[C], Axis[D]))
"m1.dot(Axis[A ~ C])(mCD)" should compile
"m1.dot(Axis[C ~ A])(mCD)" shouldNot compile
"m1.dot(Axis[A] -> Axis[C])(mCD)" should compile
"m1.dot(Axis[C] -> Axis[A])(mCD)" shouldNot compile

describe("outerProduct"):
it("Tensor1[A] and Tensor1[B] to Tensor2[A, B]"):
Expand Down
Loading