From d9c144d705f107c4a35ca0b9a937025dd85fcdcb Mon Sep 17 00:00:00 2001 From: Marcel Luethi Date: Thu, 25 Jun 2026 08:54:28 +0200 Subject: [PATCH 1/5] enable scalafix --- build.sbt | 9 +++++++++ project/.scalafix.conf | 21 +++++++++++++++++++++ project/plugins.sbt | 1 + 3 files changed, 31 insertions(+) create mode 100644 project/.scalafix.conf diff --git a/build.sbt b/build.sbt index fc2f4ec6..f44e43d6 100644 --- a/build.sbt +++ b/build.sbt @@ -34,6 +34,15 @@ ThisBuild / developers := List( ) ) +// Setup for Scalafix and SemanticDB +inThisBuild(Seq( + semanticdbEnabled := true, + semanticdbVersion := scalafixSemanticdb.revision +)) + +ThisBuild / scalafixDependencies += + "com.github.liancheng" %% "organize-imports" % "0.6.0" + addCommandAlias("testAndCoverage", "; clean; coverage; test; coverageReport") lazy val root = (project in file(".")) diff --git a/project/.scalafix.conf b/project/.scalafix.conf new file mode 100644 index 00000000..f184450f --- /dev/null +++ b/project/.scalafix.conf @@ -0,0 +1,21 @@ +rules = [ + OrganizeImports +] + +OrganizeImports { + groups = [ + "re:java\\..*" + "re:scala\\..*" + "re:javax\\..*" + "re:org\\.scala-lang\\..*" + "*" + ] + + importSelectorsOrder = Ascii + removeUnused = true + targetDialect = Scala3 + blankLines = Auto + expandRelative = false + groupSeparately = [ByTypeGivens] + groupedImports = Merge +} diff --git a/project/plugins.sbt b/project/plugins.sbt index 47b4a6e5..a1e890ba 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -3,4 +3,5 @@ addSbtPlugin("org.scalameta" % "sbt-mdoc" % "2.8.2") addSbtPlugin("org.scoverage" % "sbt-scoverage" % "2.3.0") addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "3.11.3") addSbtPlugin("com.github.sbt" % "sbt-pgp" % "2.3.1") +addSbtPlugin("ch.epfl.scala" % "sbt-scalafix" % "0.14.7") libraryDependencies += "ai.kien" %% "python-native-libs" % "0.2.2" From 5f580bf364d094e143777f10dd31cfd1833bc3f2 Mon Sep 17 00:00:00 2001 From: Marcel Luethi Date: Sat, 27 Jun 2026 21:46:26 +0200 Subject: [PATCH 2/5] avoid cyclic reference by using private impl of withLocalCleanup withLocalCleanup gave a cyclic import error after cleaning up imports, showing that there was something fishy before. The error can be avoided by declaring a function directly in the package object, which references an impl function. --- core/src/main/scala/dimwit/MemoryHelper.scala | 6 ++--- core/src/main/scala/dimwit/package.scala | 23 +++++++++++++++---- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/dimwit/MemoryHelper.scala b/core/src/main/scala/dimwit/MemoryHelper.scala index f59f3e63..ac7604a9 100644 --- a/core/src/main/scala/dimwit/MemoryHelper.scala +++ b/core/src/main/scala/dimwit/MemoryHelper.scala @@ -1,15 +1,15 @@ package dimwit -import me.shadaj.scalapy.py import dimwit.autodiff.TensorTree +import me.shadaj.scalapy.py private[dimwit] object MemoryHelper: - def withLocalCleanup(f: => Unit): Unit = + private[dimwit] def withLocalCleanupImpl(f: => Unit): Unit = py.local: f - def withLocalCleanup[A: TensorTree](f: => A): A = + private[dimwit] def withLocalCleanupImpl[A: TensorTree](f: => A): A = val lifeRaft = me.shadaj.scalapy.py.Dynamic.global.list() py.local: val res = f diff --git a/core/src/main/scala/dimwit/package.scala b/core/src/main/scala/dimwit/package.scala index 168f26b8..5242a5a1 100644 --- a/core/src/main/scala/dimwit/package.scala +++ b/core/src/main/scala/dimwit/package.scala @@ -1,7 +1,13 @@ -import scala.annotation.targetName - import dimwit.jax.Jax -import dimwit.tensor.{Axis, AxisExtent, AxisSelector, AxisAtIndex, AxisAtRange, AxisAtIndices, AxisAtTensorIndex} +import dimwit.tensor.Axis +import dimwit.tensor.AxisAtIndex +import dimwit.tensor.AxisAtIndices +import dimwit.tensor.AxisAtRange +import dimwit.tensor.AxisAtTensorIndex +import dimwit.tensor.AxisExtent +import dimwit.tensor.AxisSelector + +import scala.annotation.targetName package object dimwit: import scala.compiletime.ops.string.+ @@ -93,7 +99,16 @@ package object dimwit: // export some stats types export dimwit.stats.{Prob, LogProb} export dimwit.stats.{Distribution, IndependentDistribution, MultivariateDistribution, UnivariateDistribution} - export dimwit.MemoryHelper.withLocalCleanup + + /** Memory management helpfer making sure + * all python objects allocated ar freed + * after the function is executed. + */ + def withLocalCleanup(f: => Unit): Unit = + MemoryHelper.withLocalCleanupImpl(f) + + def withLocalCleanup[A: TensorTree](f: => A): A = + MemoryHelper.withLocalCleanupImpl(f) /** Explicitly configures the Python environment before any ScalaPy call. * Call this function at the start of your program (before any `py.*` call) From eb9d5d31034aba888ef60f2f454451ecefb9ec25 Mon Sep 17 00:00:00 2001 From: Marcel Luethi Date: Sat, 27 Jun 2026 21:47:50 +0200 Subject: [PATCH 3/5] organize imports and remove unused imports --- core/src/main/scala/dimwit/OnError.scala | 2 +- .../main/scala/dimwit/autodiff/Autodiff.scala | 9 ++- .../scala/dimwit/autodiff/FloatTree.scala | 8 +-- .../src/main/scala/dimwit/autodiff/Grad.scala | 3 +- .../scala/dimwit/autodiff/TensorTree.scala | 6 +- .../main/scala/dimwit/hardware/Device.scala | 7 +- .../scala/dimwit/hardware/DeviceBackend.scala | 2 +- core/src/main/scala/dimwit/jax/Einops.scala | 1 - core/src/main/scala/dimwit/jax/Jax.scala | 7 +- core/src/main/scala/dimwit/jax/JaxDType.scala | 2 +- core/src/main/scala/dimwit/jax/Jit.scala | 11 ++- .../scala/dimwit/nn/ActivationFunctions.scala | 7 +- .../dimwit/optimizer/GradientOptimizer.scala | 10 ++- .../main/scala/dimwit/python/PyBridge.scala | 6 +- .../scala/dimwit/python/PythonSetup.scala | 4 +- .../src/main/scala/dimwit/random/Random.scala | 13 ++-- .../scala/dimwit/stats/Distributions.scala | 5 +- .../stats/IndependentDistributions.scala | 9 ++- .../stats/MultivariateDistributions.scala | 8 +-- .../stats/UnivariateDistributions.scala | 6 +- .../scala/dimwit/tensor/ArrayReader.scala | 3 +- .../scala/dimwit/tensor/ArrayWriter.scala | 12 ++-- core/src/main/scala/dimwit/tensor/Axis.scala | 3 +- core/src/main/scala/dimwit/tensor/DType.scala | 9 ++- .../main/scala/dimwit/tensor/HasScalar.scala | 3 +- .../src/main/scala/dimwit/tensor/Labels.scala | 5 +- core/src/main/scala/dimwit/tensor/Shape.scala | 6 +- .../dimwit/tensor/ShapeTypeHelpers.scala | 3 +- .../src/main/scala/dimwit/tensor/Tensor.scala | 25 +++---- .../main/scala/dimwit/tensor/TensorOps.scala | 22 +----- .../scala/dimwit/tensor/TupleHelpers.scala | 15 ++-- core/src/main/scala/dimwit/tensor/VType.scala | 4 -- .../main/scala/dimwit/tensor/ValueOps.scala | 6 +- .../tensor/tensorops/ContractionOps.scala | 26 ++----- .../tensor/tensorops/ConvolutionOps.scala | 27 ++----- .../tensor/tensorops/ElementWiseOps.scala | 12 ++-- .../tensor/tensorops/FunctionalOps.scala | 35 +++------- .../tensor/tensorops/LinearAlgebraOps.scala | 28 +++----- .../tensor/tensorops/ReductionOps.scala | 25 +++---- .../tensor/tensorops/StructuralOps.scala | 70 ++++++++----------- .../dimwit/tensor/tensorops/Tensor0Ops.scala | 12 +--- .../dimwit/tensor/tensorops/Tensor1Ops.scala | 23 ++---- .../dimwit/tensor/tensorops/Tensor2Ops.scala | 21 +----- .../dimwit/tensor/tensorops/Tensor3Ops.scala | 21 ------ .../tensor/tensorops/TensorOpsUtils.scala | 3 +- examples/src/main/scala/basic/KMeans.scala | 2 +- .../main/scala/basic/LogisticRegression.scala | 7 +- .../src/main/scala/basic/SIRSimulation.scala | 4 +- .../complex/VariationalAutoencoder.scala | 23 +++--- .../src/main/scala/dataset/MNISTLoader.scala | 5 +- 50 files changed, 214 insertions(+), 372 deletions(-) diff --git a/core/src/main/scala/dimwit/OnError.scala b/core/src/main/scala/dimwit/OnError.scala index 5060f042..9d06a575 100644 --- a/core/src/main/scala/dimwit/OnError.scala +++ b/core/src/main/scala/dimwit/OnError.scala @@ -1,7 +1,7 @@ package dimwit -import java.io.StringWriter import java.io.PrintWriter +import java.io.StringWriter object OnError: diff --git a/core/src/main/scala/dimwit/autodiff/Autodiff.scala b/core/src/main/scala/dimwit/autodiff/Autodiff.scala index b26028da..871733c1 100644 --- a/core/src/main/scala/dimwit/autodiff/Autodiff.scala +++ b/core/src/main/scala/dimwit/autodiff/Autodiff.scala @@ -1,13 +1,12 @@ package dimwit.autodiff import dimwit.OnError -import dimwit.tensor.{Tensor, Tensor0, Tensor1, Tensor2, Shape} -import dimwit.tensor.ShapeTypeHelpers.AxisIndices -import dimwit.tensor.TupleHelpers.PrimeConcatType import dimwit.jax.Jax -import me.shadaj.scalapy.py -import dimwit.tensor.Label +import dimwit.tensor.Tensor +import dimwit.tensor.Tensor0 import dimwit.tensor.TensorOps.IsFloating +import dimwit.tensor.TupleHelpers.PrimeConcatType +import me.shadaj.scalapy.py object Autodiff: diff --git a/core/src/main/scala/dimwit/autodiff/FloatTree.scala b/core/src/main/scala/dimwit/autodiff/FloatTree.scala index 32128c15..2d0187d6 100644 --- a/core/src/main/scala/dimwit/autodiff/FloatTree.scala +++ b/core/src/main/scala/dimwit/autodiff/FloatTree.scala @@ -1,9 +1,9 @@ package dimwit.autodiff -import dimwit.tensor.* -import dimwit.tensor.TensorOps.* -import scala.deriving.* -import scala.compiletime.* +import dimwit.tensor.TensorOps._ +import dimwit.tensor._ + +import scala.deriving._ import scala.util.NotGiven /** A marker trait for structures that are trees of floating-point tensors. diff --git a/core/src/main/scala/dimwit/autodiff/Grad.scala b/core/src/main/scala/dimwit/autodiff/Grad.scala index aec27598..83dce811 100644 --- a/core/src/main/scala/dimwit/autodiff/Grad.scala +++ b/core/src/main/scala/dimwit/autodiff/Grad.scala @@ -1,7 +1,8 @@ package dimwit.autodiff -import dimwit.* +import dimwit._ import dimwit.jax.Jax + import scala.deriving.Mirror /** Type-level tag marking a parameter structure as gradients. diff --git a/core/src/main/scala/dimwit/autodiff/TensorTree.scala b/core/src/main/scala/dimwit/autodiff/TensorTree.scala index 3e403ee0..90fcfb81 100644 --- a/core/src/main/scala/dimwit/autodiff/TensorTree.scala +++ b/core/src/main/scala/dimwit/autodiff/TensorTree.scala @@ -1,12 +1,12 @@ package dimwit.autodiff +import dimwit.jax.Jax import dimwit.tensor.* -import dimwit.random.Random -import scala.deriving.* import scala.compiletime.* +import scala.deriving.* + import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters -import dimwit.jax.Jax /** A typeclass for structures that can be represented as a tree of tensors, * which can be mapped over. Most often, a tensor tree is used to structure diff --git a/core/src/main/scala/dimwit/hardware/Device.scala b/core/src/main/scala/dimwit/hardware/Device.scala index caf90085..f8e7e7bf 100644 --- a/core/src/main/scala/dimwit/hardware/Device.scala +++ b/core/src/main/scala/dimwit/hardware/Device.scala @@ -1,12 +1,7 @@ package dimwit.hardware -import dimwit.* -import dimwit.tensor.TupleHelpers -import dimwit.jax.Jax +import dimwit._ import me.shadaj.scalapy.py -import me.shadaj.scalapy.py.SeqConverters -import me.shadaj.scalapy.readwrite.Writer -import me.shadaj.scalapy.interpreter.PyValue case class Device private[dimwit] (private[dimwit] val jaxDevice: py.Dynamic): diff --git a/core/src/main/scala/dimwit/hardware/DeviceBackend.scala b/core/src/main/scala/dimwit/hardware/DeviceBackend.scala index 0e6bcbc0..0aa5715a 100644 --- a/core/src/main/scala/dimwit/hardware/DeviceBackend.scala +++ b/core/src/main/scala/dimwit/hardware/DeviceBackend.scala @@ -1,7 +1,7 @@ package dimwit.hardware -import me.shadaj.scalapy.py import dimwit.jax.Jax +import me.shadaj.scalapy.py enum DeviceBackend(private[dimwit] val jaxName: String): diff --git a/core/src/main/scala/dimwit/jax/Einops.scala b/core/src/main/scala/dimwit/jax/Einops.scala index 5043a8e7..18d8b607 100644 --- a/core/src/main/scala/dimwit/jax/Einops.scala +++ b/core/src/main/scala/dimwit/jax/Einops.scala @@ -2,7 +2,6 @@ package dimwit.jax import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters -import me.shadaj.scalapy.py.PyQuote object Einops: diff --git a/core/src/main/scala/dimwit/jax/Jax.scala b/core/src/main/scala/dimwit/jax/Jax.scala index ead7d4a6..e50e04d7 100644 --- a/core/src/main/scala/dimwit/jax/Jax.scala +++ b/core/src/main/scala/dimwit/jax/Jax.scala @@ -1,10 +1,9 @@ package dimwit.jax -import me.shadaj.scalapy.py -import me.shadaj.scalapy.py.SeqConverters -import me.shadaj.scalapy.py.PyQuote -import dimwit.hardware.{Device, DeviceBackend} +import dimwit.hardware.Device +import dimwit.hardware.DeviceBackend import dimwit.python.PythonSetup +import me.shadaj.scalapy.py object Jax: diff --git a/core/src/main/scala/dimwit/jax/JaxDType.scala b/core/src/main/scala/dimwit/jax/JaxDType.scala index 34d9ee0a..ffefedf7 100644 --- a/core/src/main/scala/dimwit/jax/JaxDType.scala +++ b/core/src/main/scala/dimwit/jax/JaxDType.scala @@ -1,7 +1,7 @@ package dimwit.jax -import me.shadaj.scalapy.py import dimwit.tensor.DType +import me.shadaj.scalapy.py object JaxDType: diff --git a/core/src/main/scala/dimwit/jax/Jit.scala b/core/src/main/scala/dimwit/jax/Jit.scala index 665b335f..c10c4293 100644 --- a/core/src/main/scala/dimwit/jax/Jit.scala +++ b/core/src/main/scala/dimwit/jax/Jit.scala @@ -1,13 +1,12 @@ package dimwit.jax -import dimwit.tensor.{Tensor, Shape, Labels} -import dimwit.jax.{Jax, JaxDType} +import dimwit.OnError import dimwit.autodiff.TensorTree +import dimwit.jax.Jax +import dimwit.jax.Jax.PyDynamic import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters -import dimwit.jax.Jax.PyDynamic -import me.shadaj.scalapy.py.PythonException -import dimwit.OnError + import scala.annotation.targetName object Jit: @@ -142,7 +141,7 @@ object JitDefault: object EagerCleanup: - import dimwit.MemoryHelper.withLocalCleanup + import dimwit.withLocalCleanup def eagerCleanup[T1, R: TensorTree](f: T1 => R): T1 => R = (t1) => withLocalCleanup: diff --git a/core/src/main/scala/dimwit/nn/ActivationFunctions.scala b/core/src/main/scala/dimwit/nn/ActivationFunctions.scala index 542419f4..bf67152e 100644 --- a/core/src/main/scala/dimwit/nn/ActivationFunctions.scala +++ b/core/src/main/scala/dimwit/nn/ActivationFunctions.scala @@ -1,9 +1,10 @@ package dimwit.nn -import dimwit.tensor.* -import dimwit.tensor.TensorOps.IsFloating import dimwit.jax.Jax -import dimwit.python.PyBridge.{liftPyTensor, toPyTensor} +import dimwit.python.PyBridge.liftPyTensor +import dimwit.python.PyBridge.toPyTensor +import dimwit.tensor.TensorOps.IsFloating +import dimwit.tensor._ object ActivationFunctions: diff --git a/core/src/main/scala/dimwit/optimizer/GradientOptimizer.scala b/core/src/main/scala/dimwit/optimizer/GradientOptimizer.scala index 2ffaa87e..2f9d6a4b 100644 --- a/core/src/main/scala/dimwit/optimizer/GradientOptimizer.scala +++ b/core/src/main/scala/dimwit/optimizer/GradientOptimizer.scala @@ -1,11 +1,9 @@ package dimwit.optimizer -import dimwit.* -import dimwit.autodiff.FloatTree.ops.* -import dimwit.autodiff.FloatTree.* -import dimwit.autodiff.* -import dimwit.jax.Jax -import dimwit.jax.Jit +import dimwit._ +import dimwit.autodiff.FloatTree._ +import dimwit.autodiff.FloatTree.ops._ +import dimwit.autodiff._ /** Gradient optimizer interface with functional state management. * diff --git a/core/src/main/scala/dimwit/python/PyBridge.scala b/core/src/main/scala/dimwit/python/PyBridge.scala index 2a9ce942..89334e75 100644 --- a/core/src/main/scala/dimwit/python/PyBridge.scala +++ b/core/src/main/scala/dimwit/python/PyBridge.scala @@ -1,9 +1,9 @@ package dimwit.python -import dimwit.tensor.* -import dimwit.jax.Jax -import dimwit.autodiff.TensorTree import dimwit.OnError +import dimwit.autodiff.TensorTree +import dimwit.jax.Jax +import dimwit.tensor._ import me.shadaj.scalapy.py object PyBridge: diff --git a/core/src/main/scala/dimwit/python/PythonSetup.scala b/core/src/main/scala/dimwit/python/PythonSetup.scala index 1eee00a3..cf19e6c6 100644 --- a/core/src/main/scala/dimwit/python/PythonSetup.scala +++ b/core/src/main/scala/dimwit/python/PythonSetup.scala @@ -1,7 +1,9 @@ package dimwit.python import me.shadaj.scalapy.py -import scala.sys.process.{Process, ProcessLogger} + +import scala.sys.process.Process +import scala.sys.process.ProcessLogger /** Manages Python environment setup for DimWit. * diff --git a/core/src/main/scala/dimwit/random/Random.scala b/core/src/main/scala/dimwit/random/Random.scala index 492c22d5..183bf8ea 100644 --- a/core/src/main/scala/dimwit/random/Random.scala +++ b/core/src/main/scala/dimwit/random/Random.scala @@ -1,15 +1,14 @@ package dimwit.random -import dimwit.tensor.* -import dimwit.tensor.DType.Int32 -import dimwit.tensor.TensorOps.* -import dimwit.jax.{Jax, JaxDType} import dimwit.autodiff.TensorTree -import me.shadaj.scalapy.py.SeqConverters +import dimwit.jax.Jax import dimwit.python.PyBridge.liftPyTensor -import scala.compiletime.{requireConst, constValue, ops} -import Tuple.Size +import dimwit.tensor.DType.Int32 +import dimwit.tensor.TensorOps._ import dimwit.tensor.TupleHelpers.TupleNOf +import dimwit.tensor._ + +import scala.compiletime.requireConst /** JAX-based random number generation with proper key management. * diff --git a/core/src/main/scala/dimwit/stats/Distributions.scala b/core/src/main/scala/dimwit/stats/Distributions.scala index a20f7090..629b181b 100644 --- a/core/src/main/scala/dimwit/stats/Distributions.scala +++ b/core/src/main/scala/dimwit/stats/Distributions.scala @@ -1,10 +1,7 @@ package dimwit.stats -import dimwit.* +import dimwit._ import dimwit.random.Random -import dimwit.jax.Jax -import dimwit.jax.Jax.scipy_stats as jstats -import dimwit.jax.Jax.PyDynamic import dimwit.tensor.TensorOps opaque type LogProb = Float32 diff --git a/core/src/main/scala/dimwit/stats/IndependentDistributions.scala b/core/src/main/scala/dimwit/stats/IndependentDistributions.scala index cccdaf0c..4f73de94 100644 --- a/core/src/main/scala/dimwit/stats/IndependentDistributions.scala +++ b/core/src/main/scala/dimwit/stats/IndependentDistributions.scala @@ -1,14 +1,13 @@ package dimwit.stats -import dimwit.* import dimwit.DType.Float32 -import dimwit.jax.Jax.scipy_stats as jstats +import dimwit._ import dimwit.jax.Jax -import dimwit.jax.Jax.PyDynamic +import dimwit.jax.Jax.{scipy_stats => jstats} +import dimwit.python.PyBridge.liftPyTensor +import dimwit.random.Random import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters -import dimwit.random.Random -import dimwit.python.PyBridge.liftPyTensor /** Normal (Gaussian) distribution */ class Normal[T <: Tuple: Labels, V: IsFloating](val loc: Tensor[T, V], val scale: Tensor[T, V]) extends IndependentDistribution[T, V]: diff --git a/core/src/main/scala/dimwit/stats/MultivariateDistributions.scala b/core/src/main/scala/dimwit/stats/MultivariateDistributions.scala index ed66bc7b..1fe8b413 100644 --- a/core/src/main/scala/dimwit/stats/MultivariateDistributions.scala +++ b/core/src/main/scala/dimwit/stats/MultivariateDistributions.scala @@ -1,12 +1,10 @@ package dimwit.stats -import dimwit.* -import dimwit.random.Random +import dimwit._ import dimwit.jax.Jax -import dimwit.jax.Jax.scipy_stats as jstats -import dimwit.jax.Jax.PyDynamic +import dimwit.jax.Jax.{scipy_stats => jstats} import dimwit.python.PyBridge.liftPyTensor -import me.shadaj.scalapy.readwrite.Reader +import dimwit.random.Random /** Distribution over a vector of random variables. */ diff --git a/core/src/main/scala/dimwit/stats/UnivariateDistributions.scala b/core/src/main/scala/dimwit/stats/UnivariateDistributions.scala index 3bf66df5..68e45eed 100644 --- a/core/src/main/scala/dimwit/stats/UnivariateDistributions.scala +++ b/core/src/main/scala/dimwit/stats/UnivariateDistributions.scala @@ -1,8 +1,8 @@ package dimwit.stats -import dimwit.* -import dimwit.DType.{Int32, Float32} -import dimwit.random.* +import dimwit.DType.Float32 +import dimwit.DType.Int32 +import dimwit._ import dimwit.jax.Jax import dimwit.python.PyBridge.liftPyTensor diff --git a/core/src/main/scala/dimwit/tensor/ArrayReader.scala b/core/src/main/scala/dimwit/tensor/ArrayReader.scala index 42b0b2b2..a06b4adc 100644 --- a/core/src/main/scala/dimwit/tensor/ArrayReader.scala +++ b/core/src/main/scala/dimwit/tensor/ArrayReader.scala @@ -1,8 +1,9 @@ package dimwit.tensor +import me.shadaj.scalapy.py + import java.nio.ByteBuffer import java.nio.ByteOrder -import me.shadaj.scalapy.py /** Utility object for reading flat arrays of scalar values from JAX tensors. */ diff --git a/core/src/main/scala/dimwit/tensor/ArrayWriter.scala b/core/src/main/scala/dimwit/tensor/ArrayWriter.scala index fcd4524e..bb453fcb 100644 --- a/core/src/main/scala/dimwit/tensor/ArrayWriter.scala +++ b/core/src/main/scala/dimwit/tensor/ArrayWriter.scala @@ -1,14 +1,14 @@ package dimwit.tensor -import java.nio.ByteBuffer -import java.util.Base64 -import java.nio.ByteOrder +import dimwit.jax.Jax +import dimwit.tensor.TensorOps.IsBoolean +import dimwit.tensor.TensorOps.IsFloating +import dimwit.tensor.TensorOps.IsInteger import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters import me.shadaj.scalapy.readwrite.Writer -import me.shadaj.scalapy.interpreter.PyValue -import dimwit.jax.Jax -import dimwit.tensor.TensorOps.{IsBoolean, IsInteger, IsFloating} + +import java.util.Base64 object ArrayWriter: diff --git a/core/src/main/scala/dimwit/tensor/Axis.scala b/core/src/main/scala/dimwit/tensor/Axis.scala index 597f1063..583d482e 100644 --- a/core/src/main/scala/dimwit/tensor/Axis.scala +++ b/core/src/main/scala/dimwit/tensor/Axis.scala @@ -1,9 +1,8 @@ package dimwit.tensor -import dimwit.|*| import dimwit.tensor.DType.Int32 +import dimwit.|*| -import scala.compiletime.{constValue, erasedValue, summonInline} import ShapeTypeHelpers.AxisIndex /** Instances of this class represent an axis in a tensor with a specific label `L`. diff --git a/core/src/main/scala/dimwit/tensor/DType.scala b/core/src/main/scala/dimwit/tensor/DType.scala index 8365a456..d0f26f0d 100644 --- a/core/src/main/scala/dimwit/tensor/DType.scala +++ b/core/src/main/scala/dimwit/tensor/DType.scala @@ -1,10 +1,13 @@ package dimwit.tensor import dimwit.jax.JaxDType +import dimwit.tensor.HasScalar +import dimwit.tensor.TensorOps.IsBoolean +import dimwit.tensor.TensorOps.IsFloating +import dimwit.tensor.TensorOps.IsInteger +import me.shadaj.scalapy.py + import java.nio.ByteBuffer import java.nio.ByteOrder -import me.shadaj.scalapy.py -import dimwit.tensor.TensorOps.{IsFloating, IsInteger, IsBoolean} -import dimwit.tensor.HasScalar object DType: diff --git a/core/src/main/scala/dimwit/tensor/HasScalar.scala b/core/src/main/scala/dimwit/tensor/HasScalar.scala index 727f4f2f..66d89202 100644 --- a/core/src/main/scala/dimwit/tensor/HasScalar.scala +++ b/core/src/main/scala/dimwit/tensor/HasScalar.scala @@ -1,8 +1,9 @@ package dimwit.tensor -import scala.annotation.implicitNotFound import me.shadaj.scalapy.py +import scala.annotation.implicitNotFound + @implicitNotFound( "No Scala type mapping for DType ${V}. Supported: Bool, Int8, Int16, Int32, Int64, Float32, Float64." ) diff --git a/core/src/main/scala/dimwit/tensor/Labels.scala b/core/src/main/scala/dimwit/tensor/Labels.scala index 38902a61..a2009aad 100644 --- a/core/src/main/scala/dimwit/tensor/Labels.scala +++ b/core/src/main/scala/dimwit/tensor/Labels.scala @@ -1,9 +1,8 @@ package dimwit.tensor -import scala.compiletime.* -import scala.quoted.* +import scala.quoted._ + import Tuple.:* -import dimwit.tensor.ShapeTypeHelpers.MergeLabels @scala.annotation.implicitNotFound(""" An axis label ${T} was given or inferred, which does not have a Label instance. diff --git a/core/src/main/scala/dimwit/tensor/Shape.scala b/core/src/main/scala/dimwit/tensor/Shape.scala index dec02a9f..da1d4f09 100644 --- a/core/src/main/scala/dimwit/tensor/Shape.scala +++ b/core/src/main/scala/dimwit/tensor/Shape.scala @@ -1,9 +1,11 @@ package dimwit.tensor -import scala.collection.View.Empty +import dimwit.tensor.Label +import dimwit.tensor.Labels + import scala.annotation.publicInBinary + import ShapeTypeHelpers.AxisIndex -import dimwit.tensor.{Labels, Label} /** Represents the Shape of a tensor. Conceptually, a shape is an order list of AxisExtents, * where each AxisExtent is a label associated with a size. diff --git a/core/src/main/scala/dimwit/tensor/ShapeTypeHelpers.scala b/core/src/main/scala/dimwit/tensor/ShapeTypeHelpers.scala index 9093a667..4192921d 100644 --- a/core/src/main/scala/dimwit/tensor/ShapeTypeHelpers.scala +++ b/core/src/main/scala/dimwit/tensor/ShapeTypeHelpers.scala @@ -1,8 +1,9 @@ package dimwit.tensor import scala.annotation.implicitNotFound +import scala.compiletime.erasedValue +import scala.compiletime.summonInline import scala.util.NotGiven -import scala.compiletime.{constValue, erasedValue, summonInline} /* Helpers for tracking Tensor Shape types across various operations */ object ShapeTypeHelpers: diff --git a/core/src/main/scala/dimwit/tensor/Tensor.scala b/core/src/main/scala/dimwit/tensor/Tensor.scala index ca9dfa0b..056bda94 100644 --- a/core/src/main/scala/dimwit/tensor/Tensor.scala +++ b/core/src/main/scala/dimwit/tensor/Tensor.scala @@ -1,25 +1,26 @@ package dimwit.tensor -import scala.annotation.targetName -import scala.compiletime.{erasedValue, summonFrom} +import dimwit.Prime +import dimwit.hardware.Device import dimwit.jax.Jax -import dimwit.jax.JaxDType import dimwit.jax.Jax.PyDynamic +import dimwit.jax.JaxDType +import dimwit.tensor.Label +import dimwit.tensor.Labels +import dimwit.tensor.TensorOps.IsBoolean +import dimwit.tensor.TensorOps.IsFloating +import dimwit.tensor.TensorOps.IsInteger import dimwit.tensor.TypedIndex -import dimwit.tensor.{Label, Labels, VType} +import dimwit.tensor.VType import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters -import dimwit.random.Random -import dimwit.stats.{Normal, Uniform} import me.shadaj.scalapy.readwrite.Writer + +import scala.annotation.targetName import scala.reflect.ClassTag -import scala.annotation.unchecked.uncheckedVariance -import dimwit.Prime + import ShapeTypeHelpers.AxisIndex -import dimwit.hardware.Device -import me.shadaj.scalapy.readwrite.Writer.stringWriter.given -import dimwit.tensor.TensorOps.{IsBoolean, IsInteger, IsFloating} -import DType.* +import DType._ /** A tensor with a fixed shape and data type. * diff --git a/core/src/main/scala/dimwit/tensor/TensorOps.scala b/core/src/main/scala/dimwit/tensor/TensorOps.scala index 0636a4a5..254eefa4 100644 --- a/core/src/main/scala/dimwit/tensor/TensorOps.scala +++ b/core/src/main/scala/dimwit/tensor/TensorOps.scala @@ -1,31 +1,13 @@ package dimwit.tensor -import dimwit.DType.* -import dimwit.DType.given -import dimwit.OnError -import dimwit.jax.Jax import dimwit.tensor.HasScalar import dimwit.tensor.Label import dimwit.tensor.Labels -import dimwit.tensor.ShapeTypeHelpers.* -import dimwit.tensor.TensorOps.ZipVmap.ShapesOf -import dimwit.tensor.TensorOps.ZipVmap.TensorsOf -import dimwit.tensor.TupleHelpers.* -import dimwit.tensor.tensorops.StructuralOps -import dimwit.|*| -import dimwit.|+| -import me.shadaj.scalapy.py -import me.shadaj.scalapy.py.SeqConverters -import me.shadaj.scalapy.readwrite.Reader -import me.shadaj.scalapy.readwrite.Writer +import dimwit.tensor.ShapeTypeHelpers._ +import dimwit.tensor.TupleHelpers._ import scala.annotation.implicitNotFound import scala.annotation.targetName -import scala.compiletime.ops.int.<= -import scala.util.NotGiven - -import Tuple.:* -import Tuple.++ object TensorOps: diff --git a/core/src/main/scala/dimwit/tensor/TupleHelpers.scala b/core/src/main/scala/dimwit/tensor/TupleHelpers.scala index a854c6f6..f6379d6b 100644 --- a/core/src/main/scala/dimwit/tensor/TupleHelpers.scala +++ b/core/src/main/scala/dimwit/tensor/TupleHelpers.scala @@ -1,14 +1,10 @@ package dimwit.tensor -import scala.util.NotGiven -import scala.annotation.implicitNotFound - -import scala.compiletime.{constValue, error} -import scala.compiletime.ops.string.+ -import scala.quoted.Type -import scala.quoted.Quotes -import scala.quoted.Expr import scala.compiletime.ops +import scala.quoted.Expr +import scala.quoted.Quotes +import scala.quoted.Type +import scala.util.NotGiven /* Helpers for manipulating Tuple types */ object TupleHelpers: @@ -100,8 +96,6 @@ object TupleHelpers: case _ => T *: TupleNOf[ops.int.-[N, 1], T] import dimwit.Prime - import scala.compiletime.ops.boolean.* - import scala.compiletime.ops.boolean.* type Member[X, T <: Tuple] <: Boolean = T match case EmptyTuple => false @@ -193,7 +187,6 @@ object TupleHelpers: // If the result is MissingAxis[A], it fails with your message. sealed trait CheckValid[R <: ValidationResult] - import scala.compiletime.ops.any.ToString object CheckValid: // Case 1: Success. We provide an instance, so compilation proceeds. given ok: CheckValid[AllOk] = new CheckValid[AllOk] {} diff --git a/core/src/main/scala/dimwit/tensor/VType.scala b/core/src/main/scala/dimwit/tensor/VType.scala index 78782b70..80ec6fc2 100644 --- a/core/src/main/scala/dimwit/tensor/VType.scala +++ b/core/src/main/scala/dimwit/tensor/VType.scala @@ -1,9 +1,5 @@ package dimwit.tensor -import dimwit.stats.Prob -import dimwit.stats.LogProb -import scala.compiletime.ops.double -import java.nio.ByteBuffer import dimwit.tensor.TensorOps.HasDType object VType: diff --git a/core/src/main/scala/dimwit/tensor/ValueOps.scala b/core/src/main/scala/dimwit/tensor/ValueOps.scala index f5c23767..9dcbd94e 100644 --- a/core/src/main/scala/dimwit/tensor/ValueOps.scala +++ b/core/src/main/scala/dimwit/tensor/ValueOps.scala @@ -1,12 +1,12 @@ package dimwit.tensor -import dimwit.tensor.TensorOps.IsNumber import dimwit.tensor.TensorOps.IsFloating +import dimwit.tensor.TensorOps.IsNumber import dimwit.tensor.tensorops.ElementWiseOps.add -import dimwit.tensor.tensorops.ElementWiseOps.subtract +import dimwit.tensor.tensorops.ElementWiseOps.divide import dimwit.tensor.tensorops.ElementWiseOps.multiply +import dimwit.tensor.tensorops.ElementWiseOps.subtract import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast -import dimwit.tensor.tensorops.ElementWiseOps.divide object ValueOps: diff --git a/core/src/main/scala/dimwit/tensor/tensorops/ContractionOps.scala b/core/src/main/scala/dimwit/tensor/tensorops/ContractionOps.scala index 5074592f..9193e3e6 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/ContractionOps.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/ContractionOps.scala @@ -1,29 +1,15 @@ package dimwit.tensor.tensorops -import dimwit.tensor.Tensor -import dimwit.tensor.Labels import dimwit.jax.Jax -import dimwit.tensor.DType.Bool -import dimwit.tensor.Tensor0 -import dimwit.tensor.TensorOps.IsBoolean -import dimwit.tensor.VType -import dimwit.tensor.DType.Int32 -import dimwit.tensor.DType.Float32 -import dimwit.tensor.TensorOps.IsInteger -import dimwit.tensor.TensorOps.IsFloating -import dimwit.tensor.TensorOps.IsNumber -import dimwit.tensor.Label -import dimwit.tensor.ShapeTypeHelpers.AxisRemover -import dimwit.tensor.ShapeTypeHelpers.AxesRemover import dimwit.tensor.Axis -import dimwit.tensor.ShapeTypeHelpers.UnwrapAxes -import dimwit.tensor.ShapeTypeHelpers.AxisIndex -import dimwit.tensor.ShapeTypeHelpers.AxisIndices - +import dimwit.tensor.Labels +import dimwit.tensor.ShapeTypeHelpers.AxisRemover +import dimwit.tensor.Tensor +import dimwit.tensor.TupleHelpers.PrimeConcat import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters -import me.shadaj.scalapy.readwrite.{Reader, Writer} -import dimwit.tensor.TupleHelpers.PrimeConcat +import me.shadaj.scalapy.readwrite.Writer + import scala.annotation.targetName object ContractionOps: diff --git a/core/src/main/scala/dimwit/tensor/tensorops/ConvolutionOps.scala b/core/src/main/scala/dimwit/tensor/tensorops/ConvolutionOps.scala index 9a4f84da..5f5f1fff 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/ConvolutionOps.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/ConvolutionOps.scala @@ -1,30 +1,17 @@ package dimwit.tensor.tensorops -import dimwit.tensor.Tensor -import dimwit.tensor.Labels import dimwit.jax.Jax -import dimwit.tensor.DType.Bool -import dimwit.tensor.Tensor0 -import dimwit.tensor.TensorOps.IsBoolean -import dimwit.tensor.VType -import dimwit.tensor.DType.Int32 -import dimwit.tensor.DType.Float32 -import dimwit.tensor.TensorOps.IsInteger -import dimwit.tensor.TensorOps.IsFloating -import dimwit.tensor.TensorOps.IsNumber -import dimwit.tensor.Label -import dimwit.tensor.ShapeTypeHelpers.AxisRemover -import dimwit.tensor.ShapeTypeHelpers.AxesRemover import dimwit.tensor.Axis -import dimwit.tensor.ShapeTypeHelpers.UnwrapAxes +import dimwit.tensor.AxisExtent +import dimwit.tensor.Label +import dimwit.tensor.Labels import dimwit.tensor.ShapeTypeHelpers.AxisIndex -import dimwit.tensor.ShapeTypeHelpers.AxisIndices - +import dimwit.tensor.Tensor +import dimwit.tensor.TensorOps.IsFloating +import dimwit.tensor.TensorOps.swap import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters -import me.shadaj.scalapy.readwrite.{Reader, Writer} -import dimwit.tensor.AxisExtent -import dimwit.tensor.TensorOps.swap +import me.shadaj.scalapy.readwrite.Writer object ConvolutionOps: diff --git a/core/src/main/scala/dimwit/tensor/tensorops/ElementWiseOps.scala b/core/src/main/scala/dimwit/tensor/tensorops/ElementWiseOps.scala index a7ecdbdd..1a30cba8 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/ElementWiseOps.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/ElementWiseOps.scala @@ -1,17 +1,17 @@ package dimwit.tensor.tensorops -import dimwit.tensor.Tensor -import dimwit.tensor.Labels import dimwit.jax.Jax import dimwit.tensor.DType.Bool +import dimwit.tensor.DType.Float32 +import dimwit.tensor.DType.Int32 +import dimwit.tensor.Labels +import dimwit.tensor.Tensor import dimwit.tensor.Tensor0 import dimwit.tensor.TensorOps.IsBoolean -import dimwit.tensor.VType -import dimwit.tensor.DType.Int32 -import dimwit.tensor.DType.Float32 -import dimwit.tensor.TensorOps.IsInteger import dimwit.tensor.TensorOps.IsFloating +import dimwit.tensor.TensorOps.IsInteger import dimwit.tensor.TensorOps.IsNumber +import dimwit.tensor.VType import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast object ElementWiseOps: diff --git a/core/src/main/scala/dimwit/tensor/tensorops/FunctionalOps.scala b/core/src/main/scala/dimwit/tensor/tensorops/FunctionalOps.scala index 39e0bff9..8e0aa041 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/FunctionalOps.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/FunctionalOps.scala @@ -1,34 +1,21 @@ package dimwit.tensor.tensorops -import dimwit.tensor.Tensor -import dimwit.tensor.Labels +import dimwit.OnError import dimwit.jax.Jax -import dimwit.tensor.DType.Bool -import dimwit.tensor.Tensor0 -import dimwit.tensor.TensorOps.IsBoolean -import dimwit.tensor.VType -import dimwit.tensor.DType.Int32 -import dimwit.tensor.DType.Float32 -import dimwit.tensor.TensorOps.IsInteger -import dimwit.tensor.TensorOps.IsFloating -import dimwit.tensor.TensorOps.IsNumber -import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast -import dimwit.tensor.Label -import dimwit.tensor.ShapeTypeHelpers.AxisRemover -import dimwit.tensor.ShapeTypeHelpers.AxesRemover import dimwit.tensor.Axis -import dimwit.tensor.ShapeTypeHelpers.UnwrapAxes -import dimwit.tensor.ShapeTypeHelpers.AxisIndex -import dimwit.tensor.ShapeTypeHelpers.AxisIndices - -import me.shadaj.scalapy.py -import me.shadaj.scalapy.py.SeqConverters -import me.shadaj.scalapy.readwrite.{Reader, Writer} -import dimwit.tensor.ShapeTypeHelpers.SharedAxisRemover -import dimwit.OnError +import dimwit.tensor.Label +import dimwit.tensor.Labels import dimwit.tensor.LabelsImpl +import dimwit.tensor.ShapeTypeHelpers.AxisRemover import dimwit.tensor.ShapeTypeHelpers.AxisReplacer +import dimwit.tensor.ShapeTypeHelpers.SharedAxisRemover +import dimwit.tensor.Tensor +import dimwit.tensor.Tensor0 import dimwit.tensor.tensorops.FunctionalOps.ZipVmap.TensorsOf +import me.shadaj.scalapy.py +import me.shadaj.scalapy.py.SeqConverters +import me.shadaj.scalapy.readwrite.Reader +import me.shadaj.scalapy.readwrite.Writer object FunctionalOps: // ----------------------------------------------------------- diff --git a/core/src/main/scala/dimwit/tensor/tensorops/LinearAlgebraOps.scala b/core/src/main/scala/dimwit/tensor/tensorops/LinearAlgebraOps.scala index d6fb6104..06163253 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/LinearAlgebraOps.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/LinearAlgebraOps.scala @@ -1,31 +1,19 @@ package dimwit.tensor.tensorops -import dimwit.tensor.Tensor -import dimwit.tensor.Labels import dimwit.jax.Jax -import dimwit.tensor.DType.Bool +import dimwit.tensor.Axis +import dimwit.tensor.Label +import dimwit.tensor.Labels +import dimwit.tensor.ShapeTypeHelpers.AxesRemover +import dimwit.tensor.Tensor import dimwit.tensor.Tensor0 -import dimwit.tensor.TensorOps.IsBoolean -import dimwit.tensor.VType -import dimwit.tensor.DType.Int32 -import dimwit.tensor.DType.Float32 -import dimwit.tensor.TensorOps.IsInteger +import dimwit.tensor.Tensor1 +import dimwit.tensor.Tensor2 import dimwit.tensor.TensorOps.IsFloating import dimwit.tensor.TensorOps.IsNumber -import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast -import dimwit.tensor.Label -import dimwit.tensor.ShapeTypeHelpers.AxisRemover -import dimwit.tensor.ShapeTypeHelpers.AxesRemover -import dimwit.tensor.Axis -import dimwit.tensor.ShapeTypeHelpers.UnwrapAxes -import dimwit.tensor.ShapeTypeHelpers.AxisIndex -import dimwit.tensor.ShapeTypeHelpers.AxisIndices - import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters -import me.shadaj.scalapy.readwrite.{Reader, Writer} -import dimwit.tensor.Tensor2 -import dimwit.tensor.Tensor1 +import me.shadaj.scalapy.readwrite.Writer object LinearAlgebraOps: diff --git a/core/src/main/scala/dimwit/tensor/tensorops/ReductionOps.scala b/core/src/main/scala/dimwit/tensor/tensorops/ReductionOps.scala index a529e2d4..ff3131b7 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/ReductionOps.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/ReductionOps.scala @@ -1,29 +1,22 @@ package dimwit.tensor.tensorops -import dimwit.tensor.Tensor -import dimwit.tensor.Labels import dimwit.jax.Jax -import dimwit.tensor.DType.Bool -import dimwit.tensor.Tensor0 -import dimwit.tensor.TensorOps.IsBoolean -import dimwit.tensor.VType +import dimwit.tensor.Axis import dimwit.tensor.DType.Int32 -import dimwit.tensor.DType.Float32 -import dimwit.tensor.TensorOps.IsInteger -import dimwit.tensor.TensorOps.IsFloating -import dimwit.tensor.TensorOps.IsNumber -import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast import dimwit.tensor.Label -import dimwit.tensor.ShapeTypeHelpers.AxisRemover +import dimwit.tensor.Labels import dimwit.tensor.ShapeTypeHelpers.AxesRemover -import dimwit.tensor.Axis -import dimwit.tensor.ShapeTypeHelpers.UnwrapAxes import dimwit.tensor.ShapeTypeHelpers.AxisIndex import dimwit.tensor.ShapeTypeHelpers.AxisIndices - +import dimwit.tensor.ShapeTypeHelpers.AxisRemover +import dimwit.tensor.ShapeTypeHelpers.UnwrapAxes +import dimwit.tensor.Tensor +import dimwit.tensor.Tensor0 +import dimwit.tensor.TensorOps.IsFloating +import dimwit.tensor.TensorOps.IsNumber import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters -import me.shadaj.scalapy.readwrite.{Reader, Writer} +import me.shadaj.scalapy.readwrite.Writer object ReductionOps: diff --git a/core/src/main/scala/dimwit/tensor/tensorops/StructuralOps.scala b/core/src/main/scala/dimwit/tensor/tensorops/StructuralOps.scala index de0abcbd..700d0747 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/StructuralOps.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/StructuralOps.scala @@ -1,55 +1,47 @@ package dimwit.tensor.tensorops -import dimwit.tensor.Tensor -import dimwit.tensor.Labels +import dimwit.jax.Einops import dimwit.jax.Jax -import dimwit.tensor.DType.Bool -import dimwit.tensor.Tensor0 -import dimwit.tensor.TensorOps.IsBoolean -import dimwit.tensor.VType -import dimwit.tensor.DType.Int32 -import dimwit.tensor.DType.Float32 -import dimwit.tensor.TensorOps.IsInteger -import dimwit.tensor.TensorOps.IsFloating -import dimwit.tensor.TensorOps.IsNumber -import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast -import dimwit.tensor.Label -import dimwit.tensor.ShapeTypeHelpers.AxisRemover -import dimwit.tensor.ShapeTypeHelpers.AxesRemover import dimwit.tensor.Axis -import dimwit.tensor.ShapeTypeHelpers.UnwrapAxes -import dimwit.tensor.ShapeTypeHelpers.AxisIndex -import dimwit.tensor.ShapeTypeHelpers.AxisIndices -import dimwit.{`|*|`, `|+|`} - -import me.shadaj.scalapy.py -import me.shadaj.scalapy.py.SeqConverters -import me.shadaj.scalapy.readwrite.{Reader, Writer} import dimwit.tensor.AxisAtIndex -import dimwit.tensor.AxisAtRange import dimwit.tensor.AxisAtIndices -import dimwit.tensor.AxisAtTupleIndices +import dimwit.tensor.AxisAtRange import dimwit.tensor.AxisAtTensorIndex -import scala.util.NotGiven -import scala.annotation.implicitNotFound -import dimwit.tensor.TupleHelpers -import dimwit.tensor.ShapeTypeHelpers.DimExtractor +import dimwit.tensor.AxisAtTupleIndices import dimwit.tensor.AxisExtent -import dimwit.tensor.Tensor1 -import dimwit.tensor.ShapeTypeHelpers.MergeLabels -import dimwit.tensor.ShapeTypeHelpers.AxesMerger +import dimwit.tensor.DType.Bool +import dimwit.tensor.DType.Int32 +import dimwit.tensor.Label +import dimwit.tensor.Labels import dimwit.tensor.Shape +import dimwit.tensor.ShapeTypeHelpers.AxesConditionalRemover +import dimwit.tensor.ShapeTypeHelpers.AxesMerger +import dimwit.tensor.ShapeTypeHelpers.AxisIndex +import dimwit.tensor.ShapeTypeHelpers.AxisIndices +import dimwit.tensor.ShapeTypeHelpers.AxisRemover +import dimwit.tensor.ShapeTypeHelpers.AxisReplacer import dimwit.tensor.ShapeTypeHelpers.AxisReplacerAll +import dimwit.tensor.ShapeTypeHelpers.DimExtractor +import dimwit.tensor.ShapeTypeHelpers.MergeLabels +import dimwit.tensor.ShapeTypeHelpers.UnwrapAxes +import dimwit.tensor.ShapeTypeHelpers.UnwrapDims +import dimwit.tensor.Tensor +import dimwit.tensor.Tensor0 +import dimwit.tensor.Tensor1 +import dimwit.tensor.TupleHelpers +import dimwit.tensor.TupleHelpers.StrictSubset +import dimwit.tensor.TupleHelpers.TensorEvidence.CheckValid +import dimwit.tensor.TupleHelpers.TensorEvidence.ComputeMissing import dimwit.tensor.TupleHelpers.TensorEvidence.IsPermutation import dimwit.tensor.TupleHelpers.TensorEvidence.ValidationResult -import dimwit.tensor.ShapeTypeHelpers.AxesConditionalRemover -import dimwit.tensor.TupleHelpers.TensorEvidence.ComputeMissing -import dimwit.tensor.TupleHelpers.TensorEvidence.CheckValid -import dimwit.tensor.ShapeTypeHelpers.UnwrapDims +import dimwit.|+| +import me.shadaj.scalapy.py +import me.shadaj.scalapy.py.SeqConverters +import me.shadaj.scalapy.readwrite.Reader +import me.shadaj.scalapy.readwrite.Writer -import dimwit.jax.Einops -import dimwit.tensor.TupleHelpers.StrictSubset -import dimwit.tensor.ShapeTypeHelpers.AxisReplacer +import scala.annotation.implicitNotFound +import scala.util.NotGiven // ----------------------------------------------------------- // 4. Structural Operations (Isomorphisms) diff --git a/core/src/main/scala/dimwit/tensor/tensorops/Tensor0Ops.scala b/core/src/main/scala/dimwit/tensor/tensorops/Tensor0Ops.scala index 5f8d3d0d..91250960 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/Tensor0Ops.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/Tensor0Ops.scala @@ -1,17 +1,7 @@ package dimwit.tensor.tensorops +import dimwit.tensor.DType._ import dimwit.tensor.Tensor0 -import dimwit.tensor.DType.* -import dimwit.tensor.TensorOps -import dimwit.tensor.TensorOps.IsNumber -import dimwit.tensor.Labels -import dimwit.tensor.tensorops.ElementWiseOps.add -import dimwit.tensor.Tensor -import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast -import dimwit.tensor.tensorops.ElementWiseOps.subtract -import dimwit.tensor.tensorops.ElementWiseOps.multiply -import dimwit.tensor.tensorops.ElementWiseOps.divide -import dimwit.tensor.TensorOps.IsFloating object Tensor0Ops: diff --git a/core/src/main/scala/dimwit/tensor/tensorops/Tensor1Ops.scala b/core/src/main/scala/dimwit/tensor/tensorops/Tensor1Ops.scala index bef2a6c3..caa69da4 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/Tensor1Ops.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/Tensor1Ops.scala @@ -1,26 +1,17 @@ package dimwit.tensor.tensorops -import dimwit.tensor.Tensor0 -import dimwit.tensor.DType.* -import dimwit.tensor.TensorOps -import dimwit.tensor.TensorOps.IsNumber -import dimwit.tensor.Labels -import dimwit.tensor.tensorops.ElementWiseOps.add -import dimwit.tensor.Tensor -import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast -import dimwit.tensor.tensorops.ElementWiseOps.subtract -import dimwit.tensor.tensorops.ElementWiseOps.multiply -import dimwit.tensor.tensorops.ElementWiseOps.divide -import dimwit.tensor.TensorOps.IsFloating +import dimwit.jax.Jax import dimwit.tensor.Axis +import dimwit.tensor.DType._ +import dimwit.tensor.HasScalar import dimwit.tensor.Label +import dimwit.tensor.Labels +import dimwit.tensor.Tensor +import dimwit.tensor.Tensor0 import dimwit.tensor.Tensor1 -import dimwit.tensor.HasScalar -import dimwit.jax.Jax - import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters -import me.shadaj.scalapy.readwrite.{Reader, Writer} +import me.shadaj.scalapy.readwrite.Writer object Tensor1Ops: diff --git a/core/src/main/scala/dimwit/tensor/tensorops/Tensor2Ops.scala b/core/src/main/scala/dimwit/tensor/tensorops/Tensor2Ops.scala index c33e94f8..0bf95e94 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/Tensor2Ops.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/Tensor2Ops.scala @@ -1,26 +1,9 @@ package dimwit.tensor.tensorops -import dimwit.tensor.Tensor0 -import dimwit.tensor.DType.* -import dimwit.tensor.TensorOps -import dimwit.tensor.TensorOps.IsNumber -import dimwit.tensor.Labels -import dimwit.tensor.tensorops.ElementWiseOps.add -import dimwit.tensor.Tensor -import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast -import dimwit.tensor.tensorops.ElementWiseOps.subtract -import dimwit.tensor.tensorops.ElementWiseOps.multiply -import dimwit.tensor.tensorops.ElementWiseOps.divide -import dimwit.tensor.TensorOps.IsFloating import dimwit.tensor.Axis -import dimwit.tensor.Label -import dimwit.tensor.Tensor1 import dimwit.tensor.HasScalar -import dimwit.jax.Jax - -import me.shadaj.scalapy.py -import me.shadaj.scalapy.py.SeqConverters -import me.shadaj.scalapy.readwrite.{Reader, Writer} +import dimwit.tensor.Label +import dimwit.tensor.Labels import dimwit.tensor.Tensor2 object Tensor2Ops: diff --git a/core/src/main/scala/dimwit/tensor/tensorops/Tensor3Ops.scala b/core/src/main/scala/dimwit/tensor/tensorops/Tensor3Ops.scala index 222afab6..8b705168 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/Tensor3Ops.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/Tensor3Ops.scala @@ -1,27 +1,6 @@ package dimwit.tensor.tensorops -import dimwit.tensor.Tensor0 -import dimwit.tensor.DType.* -import dimwit.tensor.TensorOps -import dimwit.tensor.TensorOps.IsNumber -import dimwit.tensor.Labels -import dimwit.tensor.tensorops.ElementWiseOps.add -import dimwit.tensor.Tensor -import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast -import dimwit.tensor.tensorops.ElementWiseOps.subtract -import dimwit.tensor.tensorops.ElementWiseOps.multiply -import dimwit.tensor.tensorops.ElementWiseOps.divide -import dimwit.tensor.TensorOps.IsFloating -import dimwit.tensor.Axis -import dimwit.tensor.Label -import dimwit.tensor.Tensor1 import dimwit.tensor.HasScalar -import dimwit.jax.Jax - -import me.shadaj.scalapy.py -import me.shadaj.scalapy.py.SeqConverters -import me.shadaj.scalapy.readwrite.{Reader, Writer} -import dimwit.tensor.Tensor2 import dimwit.tensor.Tensor3 object Tensor3Ops: diff --git a/core/src/main/scala/dimwit/tensor/tensorops/TensorOpsUtils.scala b/core/src/main/scala/dimwit/tensor/tensorops/TensorOpsUtils.scala index fa9b577b..76555e0b 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/TensorOpsUtils.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/TensorOpsUtils.scala @@ -1,10 +1,11 @@ package dimwit.tensor.tensorops -import scala.annotation.implicitNotFound import dimwit.tensor.Labels import dimwit.tensor.Tensor import dimwit.tensor.TupleHelpers.StrictSubset +import scala.annotation.implicitNotFound + object TensorOpsUtil: import dimwit.tensor.TensorOps.broadcastTo diff --git a/examples/src/main/scala/basic/KMeans.scala b/examples/src/main/scala/basic/KMeans.scala index d3b26353..2581613a 100644 --- a/examples/src/main/scala/basic/KMeans.scala +++ b/examples/src/main/scala/basic/KMeans.scala @@ -1,7 +1,7 @@ package examples.basic.kmeans -import dimwit.* import dimwit.Conversions.given +import dimwit._ import dimwit.random.Random import dimwit.stats.Normal diff --git a/examples/src/main/scala/basic/LogisticRegression.scala b/examples/src/main/scala/basic/LogisticRegression.scala index 06a69a0f..747eace1 100644 --- a/examples/src/main/scala/basic/LogisticRegression.scala +++ b/examples/src/main/scala/basic/LogisticRegression.scala @@ -1,10 +1,11 @@ package examples.basic -import dimwit.* import dimwit.Conversions.given -import dimwit.autodiff.* +import dimwit._ +import dimwit.autodiff._ +import dimwit.nn.ActivationFunctions.relu +import dimwit.nn.ActivationFunctions.sigmoid import dimwit.optimizer.GradientDescent -import dimwit.nn.ActivationFunctions.{sigmoid, relu} import dimwit.random.Random import dimwit.stats.Normal diff --git a/examples/src/main/scala/basic/SIRSimulation.scala b/examples/src/main/scala/basic/SIRSimulation.scala index 8ccf4e34..5f638c90 100644 --- a/examples/src/main/scala/basic/SIRSimulation.scala +++ b/examples/src/main/scala/basic/SIRSimulation.scala @@ -1,8 +1,6 @@ package src.main.scala.basic -import dimwit.* -import dimwit.Conversions.given -import dimwit.autodiff.* +import dimwit._ /** A simple SIR (Susceptible-Infectious-Recovered) simulation. */ diff --git a/examples/src/main/scala/complex/VariationalAutoencoder.scala b/examples/src/main/scala/complex/VariationalAutoencoder.scala index f42fa65d..c56c92c8 100644 --- a/examples/src/main/scala/complex/VariationalAutoencoder.scala +++ b/examples/src/main/scala/complex/VariationalAutoencoder.scala @@ -1,20 +1,23 @@ package examples.complex.vae -import dimwit.* import dimwit.Conversions.given -import dimwit.autodiff.* -import dimwit.autodiff.FloatTree.* -import dimwit.stats.Normal -import dimwit.random.Random -import examples.dataset.MNISTLoader +import dimwit._ +import dimwit.autodiff.FloatTree._ +import dimwit.autodiff._ import dimwit.nn.ActivationFunctions.relu -import dimwit.optimizer.GradientDescent -import dimwit.jax.Jax import dimwit.nn.ActivationFunctions.sigmoid +import dimwit.optimizer.GradientDescent +import dimwit.python.PyBridge.toPyTensor +import dimwit.random.Random import dimwit.random.Random.Key +import dimwit.stats.Normal +import examples.dataset.MNISTLoader -import MNISTLoader.{Sample, TrainSample, TestSample, Height, Width} -import dimwit.python.PyBridge.toPyTensor +import MNISTLoader.Sample +import MNISTLoader.TrainSample +import MNISTLoader.TestSample +import MNISTLoader.Height +import MNISTLoader.Width type Pixel = Height |*| Width type ReconstructedPixel = Height |*| Width diff --git a/examples/src/main/scala/dataset/MNISTLoader.scala b/examples/src/main/scala/dataset/MNISTLoader.scala index 2fc0112e..e0c2d6af 100644 --- a/examples/src/main/scala/dataset/MNISTLoader.scala +++ b/examples/src/main/scala/dataset/MNISTLoader.scala @@ -1,10 +1,9 @@ package examples.dataset -import dimwit.* import dimwit.Conversions.given - +import dimwit._ import me.shadaj.scalapy.py -import me.shadaj.scalapy.py.SeqConverters + import java.io.RandomAccessFile import scala.util.Try From a1c0d021c05746862040ed51c9a34835b53ba3d2 Mon Sep 17 00:00:00 2001 From: Marcel Luethi Date: Sat, 27 Jun 2026 21:48:15 +0200 Subject: [PATCH 4/5] add some scalac flags --- build.sbt | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/build.sbt b/build.sbt index f44e43d6..383f472d 100644 --- a/build.sbt +++ b/build.sbt @@ -34,6 +34,13 @@ ThisBuild / developers := List( ) ) +lazy val commonScalacOptions = Seq( + "-deprecation", + "-unchecked", + "-Wunused:imports", + "-explain-cyclic" +) + // Setup for Scalafix and SemanticDB inThisBuild(Seq( semanticdbEnabled := true, @@ -75,7 +82,8 @@ lazy val core = (project in file("core")) coverageFailOnMinimum := false, coverageHighlighting := true, Compile / packageSrc / publishArtifact := true, - Compile / packageDoc / publishArtifact := true + Compile / packageDoc / publishArtifact := true, + scalacOptions ++= commonScalacOptions ) // Examples subproject @@ -96,6 +104,7 @@ lazy val examples = (project in file("examples")) // Examples source directory Compile / scalaSource := baseDirectory.value, Compile / resourceDirectory := baseDirectory.value / "src" / "main" / "resources", + scalacOptions ++= commonScalacOptions, scalafmtFailOnErrors := false, javaOptions ++= { if (sys.props("os.name").toLowerCase.contains("mac")) { From 7d1087ee665092df264c5620254d365eb0d5041b Mon Sep 17 00:00:00 2001 From: Marcel Luethi Date: Sat, 27 Jun 2026 21:54:40 +0200 Subject: [PATCH 5/5] rewrite _ to * in import --- core/src/main/scala/dimwit/autodiff/FloatTree.scala | 6 +++--- core/src/main/scala/dimwit/autodiff/Grad.scala | 2 +- core/src/main/scala/dimwit/hardware/Device.scala | 2 +- core/src/main/scala/dimwit/nn/ActivationFunctions.scala | 2 +- .../main/scala/dimwit/optimizer/GradientOptimizer.scala | 8 ++++---- core/src/main/scala/dimwit/python/PyBridge.scala | 2 +- core/src/main/scala/dimwit/random/Random.scala | 4 ++-- core/src/main/scala/dimwit/stats/Distributions.scala | 2 +- .../scala/dimwit/stats/IndependentDistributions.scala | 4 ++-- .../scala/dimwit/stats/MultivariateDistributions.scala | 4 ++-- .../main/scala/dimwit/stats/UnivariateDistributions.scala | 2 +- core/src/main/scala/dimwit/tensor/Labels.scala | 2 +- core/src/main/scala/dimwit/tensor/Tensor.scala | 2 +- core/src/main/scala/dimwit/tensor/TensorOps.scala | 4 ++-- .../main/scala/dimwit/tensor/tensorops/Tensor0Ops.scala | 2 +- .../main/scala/dimwit/tensor/tensorops/Tensor1Ops.scala | 2 +- examples/src/main/scala/basic/KMeans.scala | 2 +- examples/src/main/scala/basic/LogisticRegression.scala | 4 ++-- examples/src/main/scala/basic/SIRSimulation.scala | 2 +- .../src/main/scala/complex/VariationalAutoencoder.scala | 6 +++--- examples/src/main/scala/dataset/MNISTLoader.scala | 2 +- 21 files changed, 33 insertions(+), 33 deletions(-) diff --git a/core/src/main/scala/dimwit/autodiff/FloatTree.scala b/core/src/main/scala/dimwit/autodiff/FloatTree.scala index 2d0187d6..bf0b1cba 100644 --- a/core/src/main/scala/dimwit/autodiff/FloatTree.scala +++ b/core/src/main/scala/dimwit/autodiff/FloatTree.scala @@ -1,9 +1,9 @@ package dimwit.autodiff -import dimwit.tensor.TensorOps._ -import dimwit.tensor._ +import dimwit.tensor.TensorOps.* +import dimwit.tensor.* -import scala.deriving._ +import scala.deriving.* import scala.util.NotGiven /** A marker trait for structures that are trees of floating-point tensors. diff --git a/core/src/main/scala/dimwit/autodiff/Grad.scala b/core/src/main/scala/dimwit/autodiff/Grad.scala index 83dce811..27f99640 100644 --- a/core/src/main/scala/dimwit/autodiff/Grad.scala +++ b/core/src/main/scala/dimwit/autodiff/Grad.scala @@ -1,6 +1,6 @@ package dimwit.autodiff -import dimwit._ +import dimwit.* import dimwit.jax.Jax import scala.deriving.Mirror diff --git a/core/src/main/scala/dimwit/hardware/Device.scala b/core/src/main/scala/dimwit/hardware/Device.scala index f8e7e7bf..69a01e79 100644 --- a/core/src/main/scala/dimwit/hardware/Device.scala +++ b/core/src/main/scala/dimwit/hardware/Device.scala @@ -1,6 +1,6 @@ package dimwit.hardware -import dimwit._ +import dimwit.* import me.shadaj.scalapy.py case class Device private[dimwit] (private[dimwit] val jaxDevice: py.Dynamic): diff --git a/core/src/main/scala/dimwit/nn/ActivationFunctions.scala b/core/src/main/scala/dimwit/nn/ActivationFunctions.scala index bf67152e..ba344274 100644 --- a/core/src/main/scala/dimwit/nn/ActivationFunctions.scala +++ b/core/src/main/scala/dimwit/nn/ActivationFunctions.scala @@ -4,7 +4,7 @@ import dimwit.jax.Jax import dimwit.python.PyBridge.liftPyTensor import dimwit.python.PyBridge.toPyTensor import dimwit.tensor.TensorOps.IsFloating -import dimwit.tensor._ +import dimwit.tensor.* object ActivationFunctions: diff --git a/core/src/main/scala/dimwit/optimizer/GradientOptimizer.scala b/core/src/main/scala/dimwit/optimizer/GradientOptimizer.scala index 2f9d6a4b..17a2d535 100644 --- a/core/src/main/scala/dimwit/optimizer/GradientOptimizer.scala +++ b/core/src/main/scala/dimwit/optimizer/GradientOptimizer.scala @@ -1,9 +1,9 @@ package dimwit.optimizer -import dimwit._ -import dimwit.autodiff.FloatTree._ -import dimwit.autodiff.FloatTree.ops._ -import dimwit.autodiff._ +import dimwit.* +import dimwit.autodiff.FloatTree.* +import dimwit.autodiff.FloatTree.ops.* +import dimwit.autodiff.* /** Gradient optimizer interface with functional state management. * diff --git a/core/src/main/scala/dimwit/python/PyBridge.scala b/core/src/main/scala/dimwit/python/PyBridge.scala index 89334e75..d3da1a7a 100644 --- a/core/src/main/scala/dimwit/python/PyBridge.scala +++ b/core/src/main/scala/dimwit/python/PyBridge.scala @@ -3,7 +3,7 @@ package dimwit.python import dimwit.OnError import dimwit.autodiff.TensorTree import dimwit.jax.Jax -import dimwit.tensor._ +import dimwit.tensor.* import me.shadaj.scalapy.py object PyBridge: diff --git a/core/src/main/scala/dimwit/random/Random.scala b/core/src/main/scala/dimwit/random/Random.scala index 183bf8ea..d6acbf0a 100644 --- a/core/src/main/scala/dimwit/random/Random.scala +++ b/core/src/main/scala/dimwit/random/Random.scala @@ -4,9 +4,9 @@ import dimwit.autodiff.TensorTree import dimwit.jax.Jax import dimwit.python.PyBridge.liftPyTensor import dimwit.tensor.DType.Int32 -import dimwit.tensor.TensorOps._ +import dimwit.tensor.TensorOps.* import dimwit.tensor.TupleHelpers.TupleNOf -import dimwit.tensor._ +import dimwit.tensor.* import scala.compiletime.requireConst diff --git a/core/src/main/scala/dimwit/stats/Distributions.scala b/core/src/main/scala/dimwit/stats/Distributions.scala index 629b181b..8efb99d1 100644 --- a/core/src/main/scala/dimwit/stats/Distributions.scala +++ b/core/src/main/scala/dimwit/stats/Distributions.scala @@ -1,6 +1,6 @@ package dimwit.stats -import dimwit._ +import dimwit.* import dimwit.random.Random import dimwit.tensor.TensorOps diff --git a/core/src/main/scala/dimwit/stats/IndependentDistributions.scala b/core/src/main/scala/dimwit/stats/IndependentDistributions.scala index 4f73de94..975ed4a7 100644 --- a/core/src/main/scala/dimwit/stats/IndependentDistributions.scala +++ b/core/src/main/scala/dimwit/stats/IndependentDistributions.scala @@ -1,9 +1,9 @@ package dimwit.stats import dimwit.DType.Float32 -import dimwit._ +import dimwit.* import dimwit.jax.Jax -import dimwit.jax.Jax.{scipy_stats => jstats} +import dimwit.jax.Jax.scipy_stats as jstats import dimwit.python.PyBridge.liftPyTensor import dimwit.random.Random import me.shadaj.scalapy.py diff --git a/core/src/main/scala/dimwit/stats/MultivariateDistributions.scala b/core/src/main/scala/dimwit/stats/MultivariateDistributions.scala index 1fe8b413..1c428222 100644 --- a/core/src/main/scala/dimwit/stats/MultivariateDistributions.scala +++ b/core/src/main/scala/dimwit/stats/MultivariateDistributions.scala @@ -1,8 +1,8 @@ package dimwit.stats -import dimwit._ +import dimwit.* import dimwit.jax.Jax -import dimwit.jax.Jax.{scipy_stats => jstats} +import dimwit.jax.Jax.scipy_stats as jstats import dimwit.python.PyBridge.liftPyTensor import dimwit.random.Random diff --git a/core/src/main/scala/dimwit/stats/UnivariateDistributions.scala b/core/src/main/scala/dimwit/stats/UnivariateDistributions.scala index 68e45eed..06c5b25b 100644 --- a/core/src/main/scala/dimwit/stats/UnivariateDistributions.scala +++ b/core/src/main/scala/dimwit/stats/UnivariateDistributions.scala @@ -2,7 +2,7 @@ package dimwit.stats import dimwit.DType.Float32 import dimwit.DType.Int32 -import dimwit._ +import dimwit.* import dimwit.jax.Jax import dimwit.python.PyBridge.liftPyTensor diff --git a/core/src/main/scala/dimwit/tensor/Labels.scala b/core/src/main/scala/dimwit/tensor/Labels.scala index a2009aad..d565c3c7 100644 --- a/core/src/main/scala/dimwit/tensor/Labels.scala +++ b/core/src/main/scala/dimwit/tensor/Labels.scala @@ -1,6 +1,6 @@ package dimwit.tensor -import scala.quoted._ +import scala.quoted.* import Tuple.:* diff --git a/core/src/main/scala/dimwit/tensor/Tensor.scala b/core/src/main/scala/dimwit/tensor/Tensor.scala index 056bda94..be29a81c 100644 --- a/core/src/main/scala/dimwit/tensor/Tensor.scala +++ b/core/src/main/scala/dimwit/tensor/Tensor.scala @@ -20,7 +20,7 @@ import scala.annotation.targetName import scala.reflect.ClassTag import ShapeTypeHelpers.AxisIndex -import DType._ +import DType.* /** A tensor with a fixed shape and data type. * diff --git a/core/src/main/scala/dimwit/tensor/TensorOps.scala b/core/src/main/scala/dimwit/tensor/TensorOps.scala index 254eefa4..3f7dc8d2 100644 --- a/core/src/main/scala/dimwit/tensor/TensorOps.scala +++ b/core/src/main/scala/dimwit/tensor/TensorOps.scala @@ -3,8 +3,8 @@ package dimwit.tensor import dimwit.tensor.HasScalar import dimwit.tensor.Label import dimwit.tensor.Labels -import dimwit.tensor.ShapeTypeHelpers._ -import dimwit.tensor.TupleHelpers._ +import dimwit.tensor.ShapeTypeHelpers.* +import dimwit.tensor.TupleHelpers.* import scala.annotation.implicitNotFound import scala.annotation.targetName diff --git a/core/src/main/scala/dimwit/tensor/tensorops/Tensor0Ops.scala b/core/src/main/scala/dimwit/tensor/tensorops/Tensor0Ops.scala index 91250960..32644db0 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/Tensor0Ops.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/Tensor0Ops.scala @@ -1,6 +1,6 @@ package dimwit.tensor.tensorops -import dimwit.tensor.DType._ +import dimwit.tensor.DType.* import dimwit.tensor.Tensor0 object Tensor0Ops: diff --git a/core/src/main/scala/dimwit/tensor/tensorops/Tensor1Ops.scala b/core/src/main/scala/dimwit/tensor/tensorops/Tensor1Ops.scala index caa69da4..98b20457 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/Tensor1Ops.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/Tensor1Ops.scala @@ -2,7 +2,7 @@ package dimwit.tensor.tensorops import dimwit.jax.Jax import dimwit.tensor.Axis -import dimwit.tensor.DType._ +import dimwit.tensor.DType.* import dimwit.tensor.HasScalar import dimwit.tensor.Label import dimwit.tensor.Labels diff --git a/examples/src/main/scala/basic/KMeans.scala b/examples/src/main/scala/basic/KMeans.scala index 2581613a..0d129c7b 100644 --- a/examples/src/main/scala/basic/KMeans.scala +++ b/examples/src/main/scala/basic/KMeans.scala @@ -1,7 +1,7 @@ package examples.basic.kmeans import dimwit.Conversions.given -import dimwit._ +import dimwit.* import dimwit.random.Random import dimwit.stats.Normal diff --git a/examples/src/main/scala/basic/LogisticRegression.scala b/examples/src/main/scala/basic/LogisticRegression.scala index 747eace1..1b35cf1c 100644 --- a/examples/src/main/scala/basic/LogisticRegression.scala +++ b/examples/src/main/scala/basic/LogisticRegression.scala @@ -1,8 +1,8 @@ package examples.basic import dimwit.Conversions.given -import dimwit._ -import dimwit.autodiff._ +import dimwit.* +import dimwit.autodiff.* import dimwit.nn.ActivationFunctions.relu import dimwit.nn.ActivationFunctions.sigmoid import dimwit.optimizer.GradientDescent diff --git a/examples/src/main/scala/basic/SIRSimulation.scala b/examples/src/main/scala/basic/SIRSimulation.scala index 5f638c90..1e22160f 100644 --- a/examples/src/main/scala/basic/SIRSimulation.scala +++ b/examples/src/main/scala/basic/SIRSimulation.scala @@ -1,6 +1,6 @@ package src.main.scala.basic -import dimwit._ +import dimwit.* /** A simple SIR (Susceptible-Infectious-Recovered) simulation. */ diff --git a/examples/src/main/scala/complex/VariationalAutoencoder.scala b/examples/src/main/scala/complex/VariationalAutoencoder.scala index c56c92c8..a075f996 100644 --- a/examples/src/main/scala/complex/VariationalAutoencoder.scala +++ b/examples/src/main/scala/complex/VariationalAutoencoder.scala @@ -1,9 +1,9 @@ package examples.complex.vae import dimwit.Conversions.given -import dimwit._ -import dimwit.autodiff.FloatTree._ -import dimwit.autodiff._ +import dimwit.* +import dimwit.autodiff.FloatTree.* +import dimwit.autodiff.* import dimwit.nn.ActivationFunctions.relu import dimwit.nn.ActivationFunctions.sigmoid import dimwit.optimizer.GradientDescent diff --git a/examples/src/main/scala/dataset/MNISTLoader.scala b/examples/src/main/scala/dataset/MNISTLoader.scala index e0c2d6af..2bcd8e9d 100644 --- a/examples/src/main/scala/dataset/MNISTLoader.scala +++ b/examples/src/main/scala/dataset/MNISTLoader.scala @@ -1,7 +1,7 @@ package examples.dataset import dimwit.Conversions.given -import dimwit._ +import dimwit.* import me.shadaj.scalapy.py import java.io.RandomAccessFile