From c7cfd1e412b71b5f7b768399f4513c1b5efb308c Mon Sep 17 00:00:00 2001 From: Marcel Luethi Date: Tue, 23 Jun 2026 06:00:16 +0200 Subject: [PATCH] change dot on to accept pair of axis values instead of special axis type --- core/src/main/scala/dimwit/package.scala | 6 ------ core/src/main/scala/dimwit/tensor/TensorOps.scala | 4 ++-- .../dimwit/tensor/TensorOpsContractionSuite.scala | 10 +++++----- 3 files changed, 7 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/dimwit/package.scala b/core/src/main/scala/dimwit/package.scala index e2594438..3a9fef85 100644 --- a/core/src/main/scala/dimwit/package.scala +++ b/core/src/main/scala/dimwit/package.scala @@ -31,12 +31,6 @@ package object dimwit: System.gc() Jax.gc() - @targetName("On") - infix trait ~[A, B] - object `~`: - given [A, B](using labelA: Label[A], labelB: Label[B]): Label[A ~ B] with - val name: String = s"${labelA.name}_on_${labelB.name}" - /** Combination of dimensions / labels * * Mentally think of this as the "product" of two dimensions. diff --git a/core/src/main/scala/dimwit/tensor/TensorOps.scala b/core/src/main/scala/dimwit/tensor/TensorOps.scala index d16d6656..9a9e6950 100644 --- a/core/src/main/scala/dimwit/tensor/TensorOps.scala +++ b/core/src/main/scala/dimwit/tensor/TensorOps.scala @@ -10,7 +10,7 @@ import dimwit.tensor.{Label, Labels} import dimwit.tensor.ShapeTypeHelpers.* import dimwit.tensor.TensorOps.Functional.ZipVmap.{ShapesOf, TensorsOf} import dimwit.tensor.TupleHelpers.* -import dimwit.{~, `|*|`, `|+|`} +import dimwit.{`|*|`, `|+|`} import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters @@ -312,7 +312,7 @@ object TensorOps: ContractAxisA, ContractAxisB, OtherShape <: Tuple - ](axis: Axis[ContractAxisA ~ ContractAxisB])(other: Tensor[OtherShape, V])(using + ](axisPair: (Axis[ContractAxisA], Axis[ContractAxisB]))(other: Tensor[OtherShape, V])(using ev: AxisRemover[T, ContractAxisA], evOther: AxisRemover[OtherShape, ContractAxisB] )(using diff --git a/core/src/test/scala/dimwit/tensor/TensorOpsContractionSuite.scala b/core/src/test/scala/dimwit/tensor/TensorOpsContractionSuite.scala index 89f33e57..262337f9 100644 --- a/core/src/test/scala/dimwit/tensor/TensorOpsContractionSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorOpsContractionSuite.scala @@ -45,10 +45,10 @@ class TensorOpsContractionSuite extends DimwitTest: ) describe("dot on different axis labels (A1 ~ A2)"): - it("Tensor2[A, B] and Tensor2[C, D] using Axis mapping (A ~ C)"): + it("Tensor2[A, B] and Tensor2[C, D] using Axis mapping (Axis[A] -> Axis[C])"): val mCD = m2.relabelAll((Axis[C], Axis[D])) - val res = m1.dot(Axis[A ~ C])(mCD) + val res = m1.dot(Axis[A] -> Axis[C])(mCD) res.shape.labels shouldBe List("B", "D") res should approxEqual( @@ -57,10 +57,10 @@ class TensorOpsContractionSuite extends DimwitTest: ) ) - it("~ should respect position-aware mapping in types"): + it("Axis mapping should respect position-aware mapping in types"): val mCD = m2.relabelAll((Axis[C], Axis[D])) - "m1.dot(Axis[A ~ C])(mCD)" should compile - "m1.dot(Axis[C ~ A])(mCD)" shouldNot compile + "m1.dot(Axis[A] -> Axis[C])(mCD)" should compile + "m1.dot(Axis[C] -> Axis[A])(mCD)" shouldNot compile describe("outerProduct"): it("Tensor1[A] and Tensor1[B] to Tensor2[A, B]"):