diff --git a/build.sbt b/build.sbt index fc2f4ec6..383f472d 100644 --- a/build.sbt +++ b/build.sbt @@ -34,6 +34,22 @@ ThisBuild / developers := List( ) ) +lazy val commonScalacOptions = Seq( + "-deprecation", + "-unchecked", + "-Wunused:imports", + "-explain-cyclic" +) + +// Setup for Scalafix and SemanticDB +inThisBuild(Seq( + semanticdbEnabled := true, + semanticdbVersion := scalafixSemanticdb.revision +)) + +ThisBuild / scalafixDependencies += + "com.github.liancheng" %% "organize-imports" % "0.6.0" + addCommandAlias("testAndCoverage", "; clean; coverage; test; coverageReport") lazy val root = (project in file(".")) @@ -66,7 +82,8 @@ lazy val core = (project in file("core")) coverageFailOnMinimum := false, coverageHighlighting := true, Compile / packageSrc / publishArtifact := true, - Compile / packageDoc / publishArtifact := true + Compile / packageDoc / publishArtifact := true, + scalacOptions ++= commonScalacOptions ) // Examples subproject @@ -87,6 +104,7 @@ lazy val examples = (project in file("examples")) // Examples source directory Compile / scalaSource := baseDirectory.value, Compile / resourceDirectory := baseDirectory.value / "src" / "main" / "resources", + scalacOptions ++= commonScalacOptions, scalafmtFailOnErrors := false, javaOptions ++= { if (sys.props("os.name").toLowerCase.contains("mac")) { diff --git a/core/src/main/scala/dimwit/MemoryHelper.scala b/core/src/main/scala/dimwit/MemoryHelper.scala index f59f3e63..ac7604a9 100644 --- a/core/src/main/scala/dimwit/MemoryHelper.scala +++ b/core/src/main/scala/dimwit/MemoryHelper.scala @@ -1,15 +1,15 @@ package dimwit -import me.shadaj.scalapy.py import dimwit.autodiff.TensorTree +import me.shadaj.scalapy.py private[dimwit] object MemoryHelper: - def withLocalCleanup(f: => Unit): Unit = + private[dimwit] def withLocalCleanupImpl(f: => Unit): Unit = py.local: f - def withLocalCleanup[A: TensorTree](f: => A): A = + private[dimwit] def withLocalCleanupImpl[A: TensorTree](f: => A): A = val lifeRaft = me.shadaj.scalapy.py.Dynamic.global.list() py.local: val res = f diff --git a/core/src/main/scala/dimwit/OnError.scala b/core/src/main/scala/dimwit/OnError.scala index 5060f042..9d06a575 100644 --- a/core/src/main/scala/dimwit/OnError.scala +++ b/core/src/main/scala/dimwit/OnError.scala @@ -1,7 +1,7 @@ package dimwit -import java.io.StringWriter import java.io.PrintWriter +import java.io.StringWriter object OnError: diff --git a/core/src/main/scala/dimwit/autodiff/Autodiff.scala b/core/src/main/scala/dimwit/autodiff/Autodiff.scala index b26028da..871733c1 100644 --- a/core/src/main/scala/dimwit/autodiff/Autodiff.scala +++ b/core/src/main/scala/dimwit/autodiff/Autodiff.scala @@ -1,13 +1,12 @@ package dimwit.autodiff import dimwit.OnError -import dimwit.tensor.{Tensor, Tensor0, Tensor1, Tensor2, Shape} -import dimwit.tensor.ShapeTypeHelpers.AxisIndices -import dimwit.tensor.TupleHelpers.PrimeConcatType import dimwit.jax.Jax -import me.shadaj.scalapy.py -import dimwit.tensor.Label +import dimwit.tensor.Tensor +import dimwit.tensor.Tensor0 import dimwit.tensor.TensorOps.IsFloating +import dimwit.tensor.TupleHelpers.PrimeConcatType +import me.shadaj.scalapy.py object Autodiff: diff --git a/core/src/main/scala/dimwit/autodiff/FloatTree.scala b/core/src/main/scala/dimwit/autodiff/FloatTree.scala index 32128c15..bf0b1cba 100644 --- a/core/src/main/scala/dimwit/autodiff/FloatTree.scala +++ b/core/src/main/scala/dimwit/autodiff/FloatTree.scala @@ -1,9 +1,9 @@ package dimwit.autodiff -import dimwit.tensor.* import dimwit.tensor.TensorOps.* +import dimwit.tensor.* + import scala.deriving.* -import scala.compiletime.* import scala.util.NotGiven /** A marker trait for structures that are trees of floating-point tensors. diff --git a/core/src/main/scala/dimwit/autodiff/Grad.scala b/core/src/main/scala/dimwit/autodiff/Grad.scala index aec27598..27f99640 100644 --- a/core/src/main/scala/dimwit/autodiff/Grad.scala +++ b/core/src/main/scala/dimwit/autodiff/Grad.scala @@ -2,6 +2,7 @@ package dimwit.autodiff import dimwit.* import dimwit.jax.Jax + import scala.deriving.Mirror /** Type-level tag marking a parameter structure as gradients. diff --git a/core/src/main/scala/dimwit/autodiff/TensorTree.scala b/core/src/main/scala/dimwit/autodiff/TensorTree.scala index 3e403ee0..90fcfb81 100644 --- a/core/src/main/scala/dimwit/autodiff/TensorTree.scala +++ b/core/src/main/scala/dimwit/autodiff/TensorTree.scala @@ -1,12 +1,12 @@ package dimwit.autodiff +import dimwit.jax.Jax import dimwit.tensor.* -import dimwit.random.Random -import scala.deriving.* import scala.compiletime.* +import scala.deriving.* + import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters -import dimwit.jax.Jax /** A typeclass for structures that can be represented as a tree of tensors, * which can be mapped over. Most often, a tensor tree is used to structure diff --git a/core/src/main/scala/dimwit/hardware/Device.scala b/core/src/main/scala/dimwit/hardware/Device.scala index caf90085..69a01e79 100644 --- a/core/src/main/scala/dimwit/hardware/Device.scala +++ b/core/src/main/scala/dimwit/hardware/Device.scala @@ -1,12 +1,7 @@ package dimwit.hardware import dimwit.* -import dimwit.tensor.TupleHelpers -import dimwit.jax.Jax import me.shadaj.scalapy.py -import me.shadaj.scalapy.py.SeqConverters -import me.shadaj.scalapy.readwrite.Writer -import me.shadaj.scalapy.interpreter.PyValue case class Device private[dimwit] (private[dimwit] val jaxDevice: py.Dynamic): diff --git a/core/src/main/scala/dimwit/hardware/DeviceBackend.scala b/core/src/main/scala/dimwit/hardware/DeviceBackend.scala index 0e6bcbc0..0aa5715a 100644 --- a/core/src/main/scala/dimwit/hardware/DeviceBackend.scala +++ b/core/src/main/scala/dimwit/hardware/DeviceBackend.scala @@ -1,7 +1,7 @@ package dimwit.hardware -import me.shadaj.scalapy.py import dimwit.jax.Jax +import me.shadaj.scalapy.py enum DeviceBackend(private[dimwit] val jaxName: String): diff --git a/core/src/main/scala/dimwit/jax/Einops.scala b/core/src/main/scala/dimwit/jax/Einops.scala index 5043a8e7..18d8b607 100644 --- a/core/src/main/scala/dimwit/jax/Einops.scala +++ b/core/src/main/scala/dimwit/jax/Einops.scala @@ -2,7 +2,6 @@ package dimwit.jax import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters -import me.shadaj.scalapy.py.PyQuote object Einops: diff --git a/core/src/main/scala/dimwit/jax/Jax.scala b/core/src/main/scala/dimwit/jax/Jax.scala index ead7d4a6..e50e04d7 100644 --- a/core/src/main/scala/dimwit/jax/Jax.scala +++ b/core/src/main/scala/dimwit/jax/Jax.scala @@ -1,10 +1,9 @@ package dimwit.jax -import me.shadaj.scalapy.py -import me.shadaj.scalapy.py.SeqConverters -import me.shadaj.scalapy.py.PyQuote -import dimwit.hardware.{Device, DeviceBackend} +import dimwit.hardware.Device +import dimwit.hardware.DeviceBackend import dimwit.python.PythonSetup +import me.shadaj.scalapy.py object Jax: diff --git a/core/src/main/scala/dimwit/jax/JaxDType.scala b/core/src/main/scala/dimwit/jax/JaxDType.scala index 34d9ee0a..ffefedf7 100644 --- a/core/src/main/scala/dimwit/jax/JaxDType.scala +++ b/core/src/main/scala/dimwit/jax/JaxDType.scala @@ -1,7 +1,7 @@ package dimwit.jax -import me.shadaj.scalapy.py import dimwit.tensor.DType +import me.shadaj.scalapy.py object JaxDType: diff --git a/core/src/main/scala/dimwit/jax/Jit.scala b/core/src/main/scala/dimwit/jax/Jit.scala index 665b335f..c10c4293 100644 --- a/core/src/main/scala/dimwit/jax/Jit.scala +++ b/core/src/main/scala/dimwit/jax/Jit.scala @@ -1,13 +1,12 @@ package dimwit.jax -import dimwit.tensor.{Tensor, Shape, Labels} -import dimwit.jax.{Jax, JaxDType} +import dimwit.OnError import dimwit.autodiff.TensorTree +import dimwit.jax.Jax +import dimwit.jax.Jax.PyDynamic import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters -import dimwit.jax.Jax.PyDynamic -import me.shadaj.scalapy.py.PythonException -import dimwit.OnError + import scala.annotation.targetName object Jit: @@ -142,7 +141,7 @@ object JitDefault: object EagerCleanup: - import dimwit.MemoryHelper.withLocalCleanup + import dimwit.withLocalCleanup def eagerCleanup[T1, R: TensorTree](f: T1 => R): T1 => R = (t1) => withLocalCleanup: diff --git a/core/src/main/scala/dimwit/nn/ActivationFunctions.scala b/core/src/main/scala/dimwit/nn/ActivationFunctions.scala index 542419f4..ba344274 100644 --- a/core/src/main/scala/dimwit/nn/ActivationFunctions.scala +++ b/core/src/main/scala/dimwit/nn/ActivationFunctions.scala @@ -1,9 +1,10 @@ package dimwit.nn -import dimwit.tensor.* -import dimwit.tensor.TensorOps.IsFloating import dimwit.jax.Jax -import dimwit.python.PyBridge.{liftPyTensor, toPyTensor} +import dimwit.python.PyBridge.liftPyTensor +import dimwit.python.PyBridge.toPyTensor +import dimwit.tensor.TensorOps.IsFloating +import dimwit.tensor.* object ActivationFunctions: diff --git a/core/src/main/scala/dimwit/optimizer/GradientOptimizer.scala b/core/src/main/scala/dimwit/optimizer/GradientOptimizer.scala index 2ffaa87e..17a2d535 100644 --- a/core/src/main/scala/dimwit/optimizer/GradientOptimizer.scala +++ b/core/src/main/scala/dimwit/optimizer/GradientOptimizer.scala @@ -1,11 +1,9 @@ package dimwit.optimizer import dimwit.* -import dimwit.autodiff.FloatTree.ops.* import dimwit.autodiff.FloatTree.* +import dimwit.autodiff.FloatTree.ops.* import dimwit.autodiff.* -import dimwit.jax.Jax -import dimwit.jax.Jit /** Gradient optimizer interface with functional state management. * diff --git a/core/src/main/scala/dimwit/package.scala b/core/src/main/scala/dimwit/package.scala index 168f26b8..5242a5a1 100644 --- a/core/src/main/scala/dimwit/package.scala +++ b/core/src/main/scala/dimwit/package.scala @@ -1,7 +1,13 @@ -import scala.annotation.targetName - import dimwit.jax.Jax -import dimwit.tensor.{Axis, AxisExtent, AxisSelector, AxisAtIndex, AxisAtRange, AxisAtIndices, AxisAtTensorIndex} +import dimwit.tensor.Axis +import dimwit.tensor.AxisAtIndex +import dimwit.tensor.AxisAtIndices +import dimwit.tensor.AxisAtRange +import dimwit.tensor.AxisAtTensorIndex +import dimwit.tensor.AxisExtent +import dimwit.tensor.AxisSelector + +import scala.annotation.targetName package object dimwit: import scala.compiletime.ops.string.+ @@ -93,7 +99,16 @@ package object dimwit: // export some stats types export dimwit.stats.{Prob, LogProb} export dimwit.stats.{Distribution, IndependentDistribution, MultivariateDistribution, UnivariateDistribution} - export dimwit.MemoryHelper.withLocalCleanup + + /** Memory management helpfer making sure + * all python objects allocated ar freed + * after the function is executed. + */ + def withLocalCleanup(f: => Unit): Unit = + MemoryHelper.withLocalCleanupImpl(f) + + def withLocalCleanup[A: TensorTree](f: => A): A = + MemoryHelper.withLocalCleanupImpl(f) /** Explicitly configures the Python environment before any ScalaPy call. * Call this function at the start of your program (before any `py.*` call) diff --git a/core/src/main/scala/dimwit/python/PyBridge.scala b/core/src/main/scala/dimwit/python/PyBridge.scala index 2a9ce942..d3da1a7a 100644 --- a/core/src/main/scala/dimwit/python/PyBridge.scala +++ b/core/src/main/scala/dimwit/python/PyBridge.scala @@ -1,9 +1,9 @@ package dimwit.python -import dimwit.tensor.* -import dimwit.jax.Jax -import dimwit.autodiff.TensorTree import dimwit.OnError +import dimwit.autodiff.TensorTree +import dimwit.jax.Jax +import dimwit.tensor.* import me.shadaj.scalapy.py object PyBridge: diff --git a/core/src/main/scala/dimwit/python/PythonSetup.scala b/core/src/main/scala/dimwit/python/PythonSetup.scala index 1eee00a3..cf19e6c6 100644 --- a/core/src/main/scala/dimwit/python/PythonSetup.scala +++ b/core/src/main/scala/dimwit/python/PythonSetup.scala @@ -1,7 +1,9 @@ package dimwit.python import me.shadaj.scalapy.py -import scala.sys.process.{Process, ProcessLogger} + +import scala.sys.process.Process +import scala.sys.process.ProcessLogger /** Manages Python environment setup for DimWit. * diff --git a/core/src/main/scala/dimwit/random/Random.scala b/core/src/main/scala/dimwit/random/Random.scala index 492c22d5..d6acbf0a 100644 --- a/core/src/main/scala/dimwit/random/Random.scala +++ b/core/src/main/scala/dimwit/random/Random.scala @@ -1,15 +1,14 @@ package dimwit.random -import dimwit.tensor.* -import dimwit.tensor.DType.Int32 -import dimwit.tensor.TensorOps.* -import dimwit.jax.{Jax, JaxDType} import dimwit.autodiff.TensorTree -import me.shadaj.scalapy.py.SeqConverters +import dimwit.jax.Jax import dimwit.python.PyBridge.liftPyTensor -import scala.compiletime.{requireConst, constValue, ops} -import Tuple.Size +import dimwit.tensor.DType.Int32 +import dimwit.tensor.TensorOps.* import dimwit.tensor.TupleHelpers.TupleNOf +import dimwit.tensor.* + +import scala.compiletime.requireConst /** JAX-based random number generation with proper key management. * diff --git a/core/src/main/scala/dimwit/stats/Distributions.scala b/core/src/main/scala/dimwit/stats/Distributions.scala index a20f7090..8efb99d1 100644 --- a/core/src/main/scala/dimwit/stats/Distributions.scala +++ b/core/src/main/scala/dimwit/stats/Distributions.scala @@ -2,9 +2,6 @@ package dimwit.stats import dimwit.* import dimwit.random.Random -import dimwit.jax.Jax -import dimwit.jax.Jax.scipy_stats as jstats -import dimwit.jax.Jax.PyDynamic import dimwit.tensor.TensorOps opaque type LogProb = Float32 diff --git a/core/src/main/scala/dimwit/stats/IndependentDistributions.scala b/core/src/main/scala/dimwit/stats/IndependentDistributions.scala index cccdaf0c..975ed4a7 100644 --- a/core/src/main/scala/dimwit/stats/IndependentDistributions.scala +++ b/core/src/main/scala/dimwit/stats/IndependentDistributions.scala @@ -1,14 +1,13 @@ package dimwit.stats -import dimwit.* import dimwit.DType.Float32 -import dimwit.jax.Jax.scipy_stats as jstats +import dimwit.* import dimwit.jax.Jax -import dimwit.jax.Jax.PyDynamic +import dimwit.jax.Jax.scipy_stats as jstats +import dimwit.python.PyBridge.liftPyTensor +import dimwit.random.Random import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters -import dimwit.random.Random -import dimwit.python.PyBridge.liftPyTensor /** Normal (Gaussian) distribution */ class Normal[T <: Tuple: Labels, V: IsFloating](val loc: Tensor[T, V], val scale: Tensor[T, V]) extends IndependentDistribution[T, V]: diff --git a/core/src/main/scala/dimwit/stats/MultivariateDistributions.scala b/core/src/main/scala/dimwit/stats/MultivariateDistributions.scala index ed66bc7b..1c428222 100644 --- a/core/src/main/scala/dimwit/stats/MultivariateDistributions.scala +++ b/core/src/main/scala/dimwit/stats/MultivariateDistributions.scala @@ -1,12 +1,10 @@ package dimwit.stats import dimwit.* -import dimwit.random.Random import dimwit.jax.Jax import dimwit.jax.Jax.scipy_stats as jstats -import dimwit.jax.Jax.PyDynamic import dimwit.python.PyBridge.liftPyTensor -import me.shadaj.scalapy.readwrite.Reader +import dimwit.random.Random /** Distribution over a vector of random variables. */ diff --git a/core/src/main/scala/dimwit/stats/UnivariateDistributions.scala b/core/src/main/scala/dimwit/stats/UnivariateDistributions.scala index 3bf66df5..06c5b25b 100644 --- a/core/src/main/scala/dimwit/stats/UnivariateDistributions.scala +++ b/core/src/main/scala/dimwit/stats/UnivariateDistributions.scala @@ -1,8 +1,8 @@ package dimwit.stats +import dimwit.DType.Float32 +import dimwit.DType.Int32 import dimwit.* -import dimwit.DType.{Int32, Float32} -import dimwit.random.* import dimwit.jax.Jax import dimwit.python.PyBridge.liftPyTensor diff --git a/core/src/main/scala/dimwit/tensor/ArrayReader.scala b/core/src/main/scala/dimwit/tensor/ArrayReader.scala index 42b0b2b2..a06b4adc 100644 --- a/core/src/main/scala/dimwit/tensor/ArrayReader.scala +++ b/core/src/main/scala/dimwit/tensor/ArrayReader.scala @@ -1,8 +1,9 @@ package dimwit.tensor +import me.shadaj.scalapy.py + import java.nio.ByteBuffer import java.nio.ByteOrder -import me.shadaj.scalapy.py /** Utility object for reading flat arrays of scalar values from JAX tensors. */ diff --git a/core/src/main/scala/dimwit/tensor/ArrayWriter.scala b/core/src/main/scala/dimwit/tensor/ArrayWriter.scala index fcd4524e..bb453fcb 100644 --- a/core/src/main/scala/dimwit/tensor/ArrayWriter.scala +++ b/core/src/main/scala/dimwit/tensor/ArrayWriter.scala @@ -1,14 +1,14 @@ package dimwit.tensor -import java.nio.ByteBuffer -import java.util.Base64 -import java.nio.ByteOrder +import dimwit.jax.Jax +import dimwit.tensor.TensorOps.IsBoolean +import dimwit.tensor.TensorOps.IsFloating +import dimwit.tensor.TensorOps.IsInteger import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters import me.shadaj.scalapy.readwrite.Writer -import me.shadaj.scalapy.interpreter.PyValue -import dimwit.jax.Jax -import dimwit.tensor.TensorOps.{IsBoolean, IsInteger, IsFloating} + +import java.util.Base64 object ArrayWriter: diff --git a/core/src/main/scala/dimwit/tensor/Axis.scala b/core/src/main/scala/dimwit/tensor/Axis.scala index 597f1063..583d482e 100644 --- a/core/src/main/scala/dimwit/tensor/Axis.scala +++ b/core/src/main/scala/dimwit/tensor/Axis.scala @@ -1,9 +1,8 @@ package dimwit.tensor -import dimwit.|*| import dimwit.tensor.DType.Int32 +import dimwit.|*| -import scala.compiletime.{constValue, erasedValue, summonInline} import ShapeTypeHelpers.AxisIndex /** Instances of this class represent an axis in a tensor with a specific label `L`. diff --git a/core/src/main/scala/dimwit/tensor/DType.scala b/core/src/main/scala/dimwit/tensor/DType.scala index 8365a456..d0f26f0d 100644 --- a/core/src/main/scala/dimwit/tensor/DType.scala +++ b/core/src/main/scala/dimwit/tensor/DType.scala @@ -1,10 +1,13 @@ package dimwit.tensor import dimwit.jax.JaxDType +import dimwit.tensor.HasScalar +import dimwit.tensor.TensorOps.IsBoolean +import dimwit.tensor.TensorOps.IsFloating +import dimwit.tensor.TensorOps.IsInteger +import me.shadaj.scalapy.py + import java.nio.ByteBuffer import java.nio.ByteOrder -import me.shadaj.scalapy.py -import dimwit.tensor.TensorOps.{IsFloating, IsInteger, IsBoolean} -import dimwit.tensor.HasScalar object DType: diff --git a/core/src/main/scala/dimwit/tensor/HasScalar.scala b/core/src/main/scala/dimwit/tensor/HasScalar.scala index 727f4f2f..66d89202 100644 --- a/core/src/main/scala/dimwit/tensor/HasScalar.scala +++ b/core/src/main/scala/dimwit/tensor/HasScalar.scala @@ -1,8 +1,9 @@ package dimwit.tensor -import scala.annotation.implicitNotFound import me.shadaj.scalapy.py +import scala.annotation.implicitNotFound + @implicitNotFound( "No Scala type mapping for DType ${V}. Supported: Bool, Int8, Int16, Int32, Int64, Float32, Float64." ) diff --git a/core/src/main/scala/dimwit/tensor/Labels.scala b/core/src/main/scala/dimwit/tensor/Labels.scala index 38902a61..d565c3c7 100644 --- a/core/src/main/scala/dimwit/tensor/Labels.scala +++ b/core/src/main/scala/dimwit/tensor/Labels.scala @@ -1,9 +1,8 @@ package dimwit.tensor -import scala.compiletime.* import scala.quoted.* + import Tuple.:* -import dimwit.tensor.ShapeTypeHelpers.MergeLabels @scala.annotation.implicitNotFound(""" An axis label ${T} was given or inferred, which does not have a Label instance. diff --git a/core/src/main/scala/dimwit/tensor/Shape.scala b/core/src/main/scala/dimwit/tensor/Shape.scala index dec02a9f..da1d4f09 100644 --- a/core/src/main/scala/dimwit/tensor/Shape.scala +++ b/core/src/main/scala/dimwit/tensor/Shape.scala @@ -1,9 +1,11 @@ package dimwit.tensor -import scala.collection.View.Empty +import dimwit.tensor.Label +import dimwit.tensor.Labels + import scala.annotation.publicInBinary + import ShapeTypeHelpers.AxisIndex -import dimwit.tensor.{Labels, Label} /** Represents the Shape of a tensor. Conceptually, a shape is an order list of AxisExtents, * where each AxisExtent is a label associated with a size. diff --git a/core/src/main/scala/dimwit/tensor/ShapeTypeHelpers.scala b/core/src/main/scala/dimwit/tensor/ShapeTypeHelpers.scala index 9093a667..4192921d 100644 --- a/core/src/main/scala/dimwit/tensor/ShapeTypeHelpers.scala +++ b/core/src/main/scala/dimwit/tensor/ShapeTypeHelpers.scala @@ -1,8 +1,9 @@ package dimwit.tensor import scala.annotation.implicitNotFound +import scala.compiletime.erasedValue +import scala.compiletime.summonInline import scala.util.NotGiven -import scala.compiletime.{constValue, erasedValue, summonInline} /* Helpers for tracking Tensor Shape types across various operations */ object ShapeTypeHelpers: diff --git a/core/src/main/scala/dimwit/tensor/Tensor.scala b/core/src/main/scala/dimwit/tensor/Tensor.scala index ca9dfa0b..be29a81c 100644 --- a/core/src/main/scala/dimwit/tensor/Tensor.scala +++ b/core/src/main/scala/dimwit/tensor/Tensor.scala @@ -1,24 +1,25 @@ package dimwit.tensor -import scala.annotation.targetName -import scala.compiletime.{erasedValue, summonFrom} +import dimwit.Prime +import dimwit.hardware.Device import dimwit.jax.Jax -import dimwit.jax.JaxDType import dimwit.jax.Jax.PyDynamic +import dimwit.jax.JaxDType +import dimwit.tensor.Label +import dimwit.tensor.Labels +import dimwit.tensor.TensorOps.IsBoolean +import dimwit.tensor.TensorOps.IsFloating +import dimwit.tensor.TensorOps.IsInteger import dimwit.tensor.TypedIndex -import dimwit.tensor.{Label, Labels, VType} +import dimwit.tensor.VType import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters -import dimwit.random.Random -import dimwit.stats.{Normal, Uniform} import me.shadaj.scalapy.readwrite.Writer + +import scala.annotation.targetName import scala.reflect.ClassTag -import scala.annotation.unchecked.uncheckedVariance -import dimwit.Prime + import ShapeTypeHelpers.AxisIndex -import dimwit.hardware.Device -import me.shadaj.scalapy.readwrite.Writer.stringWriter.given -import dimwit.tensor.TensorOps.{IsBoolean, IsInteger, IsFloating} import DType.* /** A tensor with a fixed shape and data type. diff --git a/core/src/main/scala/dimwit/tensor/TensorOps.scala b/core/src/main/scala/dimwit/tensor/TensorOps.scala index 0636a4a5..3f7dc8d2 100644 --- a/core/src/main/scala/dimwit/tensor/TensorOps.scala +++ b/core/src/main/scala/dimwit/tensor/TensorOps.scala @@ -1,31 +1,13 @@ package dimwit.tensor -import dimwit.DType.* -import dimwit.DType.given -import dimwit.OnError -import dimwit.jax.Jax import dimwit.tensor.HasScalar import dimwit.tensor.Label import dimwit.tensor.Labels import dimwit.tensor.ShapeTypeHelpers.* -import dimwit.tensor.TensorOps.ZipVmap.ShapesOf -import dimwit.tensor.TensorOps.ZipVmap.TensorsOf import dimwit.tensor.TupleHelpers.* -import dimwit.tensor.tensorops.StructuralOps -import dimwit.|*| -import dimwit.|+| -import me.shadaj.scalapy.py -import me.shadaj.scalapy.py.SeqConverters -import me.shadaj.scalapy.readwrite.Reader -import me.shadaj.scalapy.readwrite.Writer import scala.annotation.implicitNotFound import scala.annotation.targetName -import scala.compiletime.ops.int.<= -import scala.util.NotGiven - -import Tuple.:* -import Tuple.++ object TensorOps: diff --git a/core/src/main/scala/dimwit/tensor/TupleHelpers.scala b/core/src/main/scala/dimwit/tensor/TupleHelpers.scala index a854c6f6..f6379d6b 100644 --- a/core/src/main/scala/dimwit/tensor/TupleHelpers.scala +++ b/core/src/main/scala/dimwit/tensor/TupleHelpers.scala @@ -1,14 +1,10 @@ package dimwit.tensor -import scala.util.NotGiven -import scala.annotation.implicitNotFound - -import scala.compiletime.{constValue, error} -import scala.compiletime.ops.string.+ -import scala.quoted.Type -import scala.quoted.Quotes -import scala.quoted.Expr import scala.compiletime.ops +import scala.quoted.Expr +import scala.quoted.Quotes +import scala.quoted.Type +import scala.util.NotGiven /* Helpers for manipulating Tuple types */ object TupleHelpers: @@ -100,8 +96,6 @@ object TupleHelpers: case _ => T *: TupleNOf[ops.int.-[N, 1], T] import dimwit.Prime - import scala.compiletime.ops.boolean.* - import scala.compiletime.ops.boolean.* type Member[X, T <: Tuple] <: Boolean = T match case EmptyTuple => false @@ -193,7 +187,6 @@ object TupleHelpers: // If the result is MissingAxis[A], it fails with your message. sealed trait CheckValid[R <: ValidationResult] - import scala.compiletime.ops.any.ToString object CheckValid: // Case 1: Success. We provide an instance, so compilation proceeds. given ok: CheckValid[AllOk] = new CheckValid[AllOk] {} diff --git a/core/src/main/scala/dimwit/tensor/VType.scala b/core/src/main/scala/dimwit/tensor/VType.scala index 78782b70..80ec6fc2 100644 --- a/core/src/main/scala/dimwit/tensor/VType.scala +++ b/core/src/main/scala/dimwit/tensor/VType.scala @@ -1,9 +1,5 @@ package dimwit.tensor -import dimwit.stats.Prob -import dimwit.stats.LogProb -import scala.compiletime.ops.double -import java.nio.ByteBuffer import dimwit.tensor.TensorOps.HasDType object VType: diff --git a/core/src/main/scala/dimwit/tensor/ValueOps.scala b/core/src/main/scala/dimwit/tensor/ValueOps.scala index f5c23767..9dcbd94e 100644 --- a/core/src/main/scala/dimwit/tensor/ValueOps.scala +++ b/core/src/main/scala/dimwit/tensor/ValueOps.scala @@ -1,12 +1,12 @@ package dimwit.tensor -import dimwit.tensor.TensorOps.IsNumber import dimwit.tensor.TensorOps.IsFloating +import dimwit.tensor.TensorOps.IsNumber import dimwit.tensor.tensorops.ElementWiseOps.add -import dimwit.tensor.tensorops.ElementWiseOps.subtract +import dimwit.tensor.tensorops.ElementWiseOps.divide import dimwit.tensor.tensorops.ElementWiseOps.multiply +import dimwit.tensor.tensorops.ElementWiseOps.subtract import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast -import dimwit.tensor.tensorops.ElementWiseOps.divide object ValueOps: diff --git a/core/src/main/scala/dimwit/tensor/tensorops/ContractionOps.scala b/core/src/main/scala/dimwit/tensor/tensorops/ContractionOps.scala index 5074592f..9193e3e6 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/ContractionOps.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/ContractionOps.scala @@ -1,29 +1,15 @@ package dimwit.tensor.tensorops -import dimwit.tensor.Tensor -import dimwit.tensor.Labels import dimwit.jax.Jax -import dimwit.tensor.DType.Bool -import dimwit.tensor.Tensor0 -import dimwit.tensor.TensorOps.IsBoolean -import dimwit.tensor.VType -import dimwit.tensor.DType.Int32 -import dimwit.tensor.DType.Float32 -import dimwit.tensor.TensorOps.IsInteger -import dimwit.tensor.TensorOps.IsFloating -import dimwit.tensor.TensorOps.IsNumber -import dimwit.tensor.Label -import dimwit.tensor.ShapeTypeHelpers.AxisRemover -import dimwit.tensor.ShapeTypeHelpers.AxesRemover import dimwit.tensor.Axis -import dimwit.tensor.ShapeTypeHelpers.UnwrapAxes -import dimwit.tensor.ShapeTypeHelpers.AxisIndex -import dimwit.tensor.ShapeTypeHelpers.AxisIndices - +import dimwit.tensor.Labels +import dimwit.tensor.ShapeTypeHelpers.AxisRemover +import dimwit.tensor.Tensor +import dimwit.tensor.TupleHelpers.PrimeConcat import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters -import me.shadaj.scalapy.readwrite.{Reader, Writer} -import dimwit.tensor.TupleHelpers.PrimeConcat +import me.shadaj.scalapy.readwrite.Writer + import scala.annotation.targetName object ContractionOps: diff --git a/core/src/main/scala/dimwit/tensor/tensorops/ConvolutionOps.scala b/core/src/main/scala/dimwit/tensor/tensorops/ConvolutionOps.scala index 9a4f84da..5f5f1fff 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/ConvolutionOps.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/ConvolutionOps.scala @@ -1,30 +1,17 @@ package dimwit.tensor.tensorops -import dimwit.tensor.Tensor -import dimwit.tensor.Labels import dimwit.jax.Jax -import dimwit.tensor.DType.Bool -import dimwit.tensor.Tensor0 -import dimwit.tensor.TensorOps.IsBoolean -import dimwit.tensor.VType -import dimwit.tensor.DType.Int32 -import dimwit.tensor.DType.Float32 -import dimwit.tensor.TensorOps.IsInteger -import dimwit.tensor.TensorOps.IsFloating -import dimwit.tensor.TensorOps.IsNumber -import dimwit.tensor.Label -import dimwit.tensor.ShapeTypeHelpers.AxisRemover -import dimwit.tensor.ShapeTypeHelpers.AxesRemover import dimwit.tensor.Axis -import dimwit.tensor.ShapeTypeHelpers.UnwrapAxes +import dimwit.tensor.AxisExtent +import dimwit.tensor.Label +import dimwit.tensor.Labels import dimwit.tensor.ShapeTypeHelpers.AxisIndex -import dimwit.tensor.ShapeTypeHelpers.AxisIndices - +import dimwit.tensor.Tensor +import dimwit.tensor.TensorOps.IsFloating +import dimwit.tensor.TensorOps.swap import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters -import me.shadaj.scalapy.readwrite.{Reader, Writer} -import dimwit.tensor.AxisExtent -import dimwit.tensor.TensorOps.swap +import me.shadaj.scalapy.readwrite.Writer object ConvolutionOps: diff --git a/core/src/main/scala/dimwit/tensor/tensorops/ElementWiseOps.scala b/core/src/main/scala/dimwit/tensor/tensorops/ElementWiseOps.scala index a7ecdbdd..1a30cba8 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/ElementWiseOps.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/ElementWiseOps.scala @@ -1,17 +1,17 @@ package dimwit.tensor.tensorops -import dimwit.tensor.Tensor -import dimwit.tensor.Labels import dimwit.jax.Jax import dimwit.tensor.DType.Bool +import dimwit.tensor.DType.Float32 +import dimwit.tensor.DType.Int32 +import dimwit.tensor.Labels +import dimwit.tensor.Tensor import dimwit.tensor.Tensor0 import dimwit.tensor.TensorOps.IsBoolean -import dimwit.tensor.VType -import dimwit.tensor.DType.Int32 -import dimwit.tensor.DType.Float32 -import dimwit.tensor.TensorOps.IsInteger import dimwit.tensor.TensorOps.IsFloating +import dimwit.tensor.TensorOps.IsInteger import dimwit.tensor.TensorOps.IsNumber +import dimwit.tensor.VType import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast object ElementWiseOps: diff --git a/core/src/main/scala/dimwit/tensor/tensorops/FunctionalOps.scala b/core/src/main/scala/dimwit/tensor/tensorops/FunctionalOps.scala index 39e0bff9..8e0aa041 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/FunctionalOps.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/FunctionalOps.scala @@ -1,34 +1,21 @@ package dimwit.tensor.tensorops -import dimwit.tensor.Tensor -import dimwit.tensor.Labels +import dimwit.OnError import dimwit.jax.Jax -import dimwit.tensor.DType.Bool -import dimwit.tensor.Tensor0 -import dimwit.tensor.TensorOps.IsBoolean -import dimwit.tensor.VType -import dimwit.tensor.DType.Int32 -import dimwit.tensor.DType.Float32 -import dimwit.tensor.TensorOps.IsInteger -import dimwit.tensor.TensorOps.IsFloating -import dimwit.tensor.TensorOps.IsNumber -import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast -import dimwit.tensor.Label -import dimwit.tensor.ShapeTypeHelpers.AxisRemover -import dimwit.tensor.ShapeTypeHelpers.AxesRemover import dimwit.tensor.Axis -import dimwit.tensor.ShapeTypeHelpers.UnwrapAxes -import dimwit.tensor.ShapeTypeHelpers.AxisIndex -import dimwit.tensor.ShapeTypeHelpers.AxisIndices - -import me.shadaj.scalapy.py -import me.shadaj.scalapy.py.SeqConverters -import me.shadaj.scalapy.readwrite.{Reader, Writer} -import dimwit.tensor.ShapeTypeHelpers.SharedAxisRemover -import dimwit.OnError +import dimwit.tensor.Label +import dimwit.tensor.Labels import dimwit.tensor.LabelsImpl +import dimwit.tensor.ShapeTypeHelpers.AxisRemover import dimwit.tensor.ShapeTypeHelpers.AxisReplacer +import dimwit.tensor.ShapeTypeHelpers.SharedAxisRemover +import dimwit.tensor.Tensor +import dimwit.tensor.Tensor0 import dimwit.tensor.tensorops.FunctionalOps.ZipVmap.TensorsOf +import me.shadaj.scalapy.py +import me.shadaj.scalapy.py.SeqConverters +import me.shadaj.scalapy.readwrite.Reader +import me.shadaj.scalapy.readwrite.Writer object FunctionalOps: // ----------------------------------------------------------- diff --git a/core/src/main/scala/dimwit/tensor/tensorops/LinearAlgebraOps.scala b/core/src/main/scala/dimwit/tensor/tensorops/LinearAlgebraOps.scala index d6fb6104..06163253 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/LinearAlgebraOps.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/LinearAlgebraOps.scala @@ -1,31 +1,19 @@ package dimwit.tensor.tensorops -import dimwit.tensor.Tensor -import dimwit.tensor.Labels import dimwit.jax.Jax -import dimwit.tensor.DType.Bool +import dimwit.tensor.Axis +import dimwit.tensor.Label +import dimwit.tensor.Labels +import dimwit.tensor.ShapeTypeHelpers.AxesRemover +import dimwit.tensor.Tensor import dimwit.tensor.Tensor0 -import dimwit.tensor.TensorOps.IsBoolean -import dimwit.tensor.VType -import dimwit.tensor.DType.Int32 -import dimwit.tensor.DType.Float32 -import dimwit.tensor.TensorOps.IsInteger +import dimwit.tensor.Tensor1 +import dimwit.tensor.Tensor2 import dimwit.tensor.TensorOps.IsFloating import dimwit.tensor.TensorOps.IsNumber -import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast -import dimwit.tensor.Label -import dimwit.tensor.ShapeTypeHelpers.AxisRemover -import dimwit.tensor.ShapeTypeHelpers.AxesRemover -import dimwit.tensor.Axis -import dimwit.tensor.ShapeTypeHelpers.UnwrapAxes -import dimwit.tensor.ShapeTypeHelpers.AxisIndex -import dimwit.tensor.ShapeTypeHelpers.AxisIndices - import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters -import me.shadaj.scalapy.readwrite.{Reader, Writer} -import dimwit.tensor.Tensor2 -import dimwit.tensor.Tensor1 +import me.shadaj.scalapy.readwrite.Writer object LinearAlgebraOps: diff --git a/core/src/main/scala/dimwit/tensor/tensorops/ReductionOps.scala b/core/src/main/scala/dimwit/tensor/tensorops/ReductionOps.scala index a529e2d4..ff3131b7 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/ReductionOps.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/ReductionOps.scala @@ -1,29 +1,22 @@ package dimwit.tensor.tensorops -import dimwit.tensor.Tensor -import dimwit.tensor.Labels import dimwit.jax.Jax -import dimwit.tensor.DType.Bool -import dimwit.tensor.Tensor0 -import dimwit.tensor.TensorOps.IsBoolean -import dimwit.tensor.VType +import dimwit.tensor.Axis import dimwit.tensor.DType.Int32 -import dimwit.tensor.DType.Float32 -import dimwit.tensor.TensorOps.IsInteger -import dimwit.tensor.TensorOps.IsFloating -import dimwit.tensor.TensorOps.IsNumber -import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast import dimwit.tensor.Label -import dimwit.tensor.ShapeTypeHelpers.AxisRemover +import dimwit.tensor.Labels import dimwit.tensor.ShapeTypeHelpers.AxesRemover -import dimwit.tensor.Axis -import dimwit.tensor.ShapeTypeHelpers.UnwrapAxes import dimwit.tensor.ShapeTypeHelpers.AxisIndex import dimwit.tensor.ShapeTypeHelpers.AxisIndices - +import dimwit.tensor.ShapeTypeHelpers.AxisRemover +import dimwit.tensor.ShapeTypeHelpers.UnwrapAxes +import dimwit.tensor.Tensor +import dimwit.tensor.Tensor0 +import dimwit.tensor.TensorOps.IsFloating +import dimwit.tensor.TensorOps.IsNumber import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters -import me.shadaj.scalapy.readwrite.{Reader, Writer} +import me.shadaj.scalapy.readwrite.Writer object ReductionOps: diff --git a/core/src/main/scala/dimwit/tensor/tensorops/StructuralOps.scala b/core/src/main/scala/dimwit/tensor/tensorops/StructuralOps.scala index de0abcbd..700d0747 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/StructuralOps.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/StructuralOps.scala @@ -1,55 +1,47 @@ package dimwit.tensor.tensorops -import dimwit.tensor.Tensor -import dimwit.tensor.Labels +import dimwit.jax.Einops import dimwit.jax.Jax -import dimwit.tensor.DType.Bool -import dimwit.tensor.Tensor0 -import dimwit.tensor.TensorOps.IsBoolean -import dimwit.tensor.VType -import dimwit.tensor.DType.Int32 -import dimwit.tensor.DType.Float32 -import dimwit.tensor.TensorOps.IsInteger -import dimwit.tensor.TensorOps.IsFloating -import dimwit.tensor.TensorOps.IsNumber -import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast -import dimwit.tensor.Label -import dimwit.tensor.ShapeTypeHelpers.AxisRemover -import dimwit.tensor.ShapeTypeHelpers.AxesRemover import dimwit.tensor.Axis -import dimwit.tensor.ShapeTypeHelpers.UnwrapAxes -import dimwit.tensor.ShapeTypeHelpers.AxisIndex -import dimwit.tensor.ShapeTypeHelpers.AxisIndices -import dimwit.{`|*|`, `|+|`} - -import me.shadaj.scalapy.py -import me.shadaj.scalapy.py.SeqConverters -import me.shadaj.scalapy.readwrite.{Reader, Writer} import dimwit.tensor.AxisAtIndex -import dimwit.tensor.AxisAtRange import dimwit.tensor.AxisAtIndices -import dimwit.tensor.AxisAtTupleIndices +import dimwit.tensor.AxisAtRange import dimwit.tensor.AxisAtTensorIndex -import scala.util.NotGiven -import scala.annotation.implicitNotFound -import dimwit.tensor.TupleHelpers -import dimwit.tensor.ShapeTypeHelpers.DimExtractor +import dimwit.tensor.AxisAtTupleIndices import dimwit.tensor.AxisExtent -import dimwit.tensor.Tensor1 -import dimwit.tensor.ShapeTypeHelpers.MergeLabels -import dimwit.tensor.ShapeTypeHelpers.AxesMerger +import dimwit.tensor.DType.Bool +import dimwit.tensor.DType.Int32 +import dimwit.tensor.Label +import dimwit.tensor.Labels import dimwit.tensor.Shape +import dimwit.tensor.ShapeTypeHelpers.AxesConditionalRemover +import dimwit.tensor.ShapeTypeHelpers.AxesMerger +import dimwit.tensor.ShapeTypeHelpers.AxisIndex +import dimwit.tensor.ShapeTypeHelpers.AxisIndices +import dimwit.tensor.ShapeTypeHelpers.AxisRemover +import dimwit.tensor.ShapeTypeHelpers.AxisReplacer import dimwit.tensor.ShapeTypeHelpers.AxisReplacerAll +import dimwit.tensor.ShapeTypeHelpers.DimExtractor +import dimwit.tensor.ShapeTypeHelpers.MergeLabels +import dimwit.tensor.ShapeTypeHelpers.UnwrapAxes +import dimwit.tensor.ShapeTypeHelpers.UnwrapDims +import dimwit.tensor.Tensor +import dimwit.tensor.Tensor0 +import dimwit.tensor.Tensor1 +import dimwit.tensor.TupleHelpers +import dimwit.tensor.TupleHelpers.StrictSubset +import dimwit.tensor.TupleHelpers.TensorEvidence.CheckValid +import dimwit.tensor.TupleHelpers.TensorEvidence.ComputeMissing import dimwit.tensor.TupleHelpers.TensorEvidence.IsPermutation import dimwit.tensor.TupleHelpers.TensorEvidence.ValidationResult -import dimwit.tensor.ShapeTypeHelpers.AxesConditionalRemover -import dimwit.tensor.TupleHelpers.TensorEvidence.ComputeMissing -import dimwit.tensor.TupleHelpers.TensorEvidence.CheckValid -import dimwit.tensor.ShapeTypeHelpers.UnwrapDims +import dimwit.|+| +import me.shadaj.scalapy.py +import me.shadaj.scalapy.py.SeqConverters +import me.shadaj.scalapy.readwrite.Reader +import me.shadaj.scalapy.readwrite.Writer -import dimwit.jax.Einops -import dimwit.tensor.TupleHelpers.StrictSubset -import dimwit.tensor.ShapeTypeHelpers.AxisReplacer +import scala.annotation.implicitNotFound +import scala.util.NotGiven // ----------------------------------------------------------- // 4. Structural Operations (Isomorphisms) diff --git a/core/src/main/scala/dimwit/tensor/tensorops/Tensor0Ops.scala b/core/src/main/scala/dimwit/tensor/tensorops/Tensor0Ops.scala index 5f8d3d0d..32644db0 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/Tensor0Ops.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/Tensor0Ops.scala @@ -1,17 +1,7 @@ package dimwit.tensor.tensorops -import dimwit.tensor.Tensor0 import dimwit.tensor.DType.* -import dimwit.tensor.TensorOps -import dimwit.tensor.TensorOps.IsNumber -import dimwit.tensor.Labels -import dimwit.tensor.tensorops.ElementWiseOps.add -import dimwit.tensor.Tensor -import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast -import dimwit.tensor.tensorops.ElementWiseOps.subtract -import dimwit.tensor.tensorops.ElementWiseOps.multiply -import dimwit.tensor.tensorops.ElementWiseOps.divide -import dimwit.tensor.TensorOps.IsFloating +import dimwit.tensor.Tensor0 object Tensor0Ops: diff --git a/core/src/main/scala/dimwit/tensor/tensorops/Tensor1Ops.scala b/core/src/main/scala/dimwit/tensor/tensorops/Tensor1Ops.scala index bef2a6c3..98b20457 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/Tensor1Ops.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/Tensor1Ops.scala @@ -1,26 +1,17 @@ package dimwit.tensor.tensorops -import dimwit.tensor.Tensor0 +import dimwit.jax.Jax +import dimwit.tensor.Axis import dimwit.tensor.DType.* -import dimwit.tensor.TensorOps -import dimwit.tensor.TensorOps.IsNumber +import dimwit.tensor.HasScalar +import dimwit.tensor.Label import dimwit.tensor.Labels -import dimwit.tensor.tensorops.ElementWiseOps.add import dimwit.tensor.Tensor -import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast -import dimwit.tensor.tensorops.ElementWiseOps.subtract -import dimwit.tensor.tensorops.ElementWiseOps.multiply -import dimwit.tensor.tensorops.ElementWiseOps.divide -import dimwit.tensor.TensorOps.IsFloating -import dimwit.tensor.Axis -import dimwit.tensor.Label +import dimwit.tensor.Tensor0 import dimwit.tensor.Tensor1 -import dimwit.tensor.HasScalar -import dimwit.jax.Jax - import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters -import me.shadaj.scalapy.readwrite.{Reader, Writer} +import me.shadaj.scalapy.readwrite.Writer object Tensor1Ops: diff --git a/core/src/main/scala/dimwit/tensor/tensorops/Tensor2Ops.scala b/core/src/main/scala/dimwit/tensor/tensorops/Tensor2Ops.scala index c33e94f8..0bf95e94 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/Tensor2Ops.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/Tensor2Ops.scala @@ -1,26 +1,9 @@ package dimwit.tensor.tensorops -import dimwit.tensor.Tensor0 -import dimwit.tensor.DType.* -import dimwit.tensor.TensorOps -import dimwit.tensor.TensorOps.IsNumber -import dimwit.tensor.Labels -import dimwit.tensor.tensorops.ElementWiseOps.add -import dimwit.tensor.Tensor -import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast -import dimwit.tensor.tensorops.ElementWiseOps.subtract -import dimwit.tensor.tensorops.ElementWiseOps.multiply -import dimwit.tensor.tensorops.ElementWiseOps.divide -import dimwit.tensor.TensorOps.IsFloating import dimwit.tensor.Axis -import dimwit.tensor.Label -import dimwit.tensor.Tensor1 import dimwit.tensor.HasScalar -import dimwit.jax.Jax - -import me.shadaj.scalapy.py -import me.shadaj.scalapy.py.SeqConverters -import me.shadaj.scalapy.readwrite.{Reader, Writer} +import dimwit.tensor.Label +import dimwit.tensor.Labels import dimwit.tensor.Tensor2 object Tensor2Ops: diff --git a/core/src/main/scala/dimwit/tensor/tensorops/Tensor3Ops.scala b/core/src/main/scala/dimwit/tensor/tensorops/Tensor3Ops.scala index 222afab6..8b705168 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/Tensor3Ops.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/Tensor3Ops.scala @@ -1,27 +1,6 @@ package dimwit.tensor.tensorops -import dimwit.tensor.Tensor0 -import dimwit.tensor.DType.* -import dimwit.tensor.TensorOps -import dimwit.tensor.TensorOps.IsNumber -import dimwit.tensor.Labels -import dimwit.tensor.tensorops.ElementWiseOps.add -import dimwit.tensor.Tensor -import dimwit.tensor.tensorops.TensorOpsUtil.Broadcast -import dimwit.tensor.tensorops.ElementWiseOps.subtract -import dimwit.tensor.tensorops.ElementWiseOps.multiply -import dimwit.tensor.tensorops.ElementWiseOps.divide -import dimwit.tensor.TensorOps.IsFloating -import dimwit.tensor.Axis -import dimwit.tensor.Label -import dimwit.tensor.Tensor1 import dimwit.tensor.HasScalar -import dimwit.jax.Jax - -import me.shadaj.scalapy.py -import me.shadaj.scalapy.py.SeqConverters -import me.shadaj.scalapy.readwrite.{Reader, Writer} -import dimwit.tensor.Tensor2 import dimwit.tensor.Tensor3 object Tensor3Ops: diff --git a/core/src/main/scala/dimwit/tensor/tensorops/TensorOpsUtils.scala b/core/src/main/scala/dimwit/tensor/tensorops/TensorOpsUtils.scala index fa9b577b..76555e0b 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/TensorOpsUtils.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/TensorOpsUtils.scala @@ -1,10 +1,11 @@ package dimwit.tensor.tensorops -import scala.annotation.implicitNotFound import dimwit.tensor.Labels import dimwit.tensor.Tensor import dimwit.tensor.TupleHelpers.StrictSubset +import scala.annotation.implicitNotFound + object TensorOpsUtil: import dimwit.tensor.TensorOps.broadcastTo diff --git a/examples/src/main/scala/basic/KMeans.scala b/examples/src/main/scala/basic/KMeans.scala index d3b26353..0d129c7b 100644 --- a/examples/src/main/scala/basic/KMeans.scala +++ b/examples/src/main/scala/basic/KMeans.scala @@ -1,7 +1,7 @@ package examples.basic.kmeans -import dimwit.* import dimwit.Conversions.given +import dimwit.* import dimwit.random.Random import dimwit.stats.Normal diff --git a/examples/src/main/scala/basic/LogisticRegression.scala b/examples/src/main/scala/basic/LogisticRegression.scala index 06a69a0f..1b35cf1c 100644 --- a/examples/src/main/scala/basic/LogisticRegression.scala +++ b/examples/src/main/scala/basic/LogisticRegression.scala @@ -1,10 +1,11 @@ package examples.basic -import dimwit.* import dimwit.Conversions.given +import dimwit.* import dimwit.autodiff.* +import dimwit.nn.ActivationFunctions.relu +import dimwit.nn.ActivationFunctions.sigmoid import dimwit.optimizer.GradientDescent -import dimwit.nn.ActivationFunctions.{sigmoid, relu} import dimwit.random.Random import dimwit.stats.Normal diff --git a/examples/src/main/scala/basic/SIRSimulation.scala b/examples/src/main/scala/basic/SIRSimulation.scala index 8ccf4e34..1e22160f 100644 --- a/examples/src/main/scala/basic/SIRSimulation.scala +++ b/examples/src/main/scala/basic/SIRSimulation.scala @@ -1,8 +1,6 @@ package src.main.scala.basic import dimwit.* -import dimwit.Conversions.given -import dimwit.autodiff.* /** A simple SIR (Susceptible-Infectious-Recovered) simulation. */ diff --git a/examples/src/main/scala/complex/VariationalAutoencoder.scala b/examples/src/main/scala/complex/VariationalAutoencoder.scala index f42fa65d..a075f996 100644 --- a/examples/src/main/scala/complex/VariationalAutoencoder.scala +++ b/examples/src/main/scala/complex/VariationalAutoencoder.scala @@ -1,20 +1,23 @@ package examples.complex.vae -import dimwit.* import dimwit.Conversions.given -import dimwit.autodiff.* +import dimwit.* import dimwit.autodiff.FloatTree.* -import dimwit.stats.Normal -import dimwit.random.Random -import examples.dataset.MNISTLoader +import dimwit.autodiff.* import dimwit.nn.ActivationFunctions.relu -import dimwit.optimizer.GradientDescent -import dimwit.jax.Jax import dimwit.nn.ActivationFunctions.sigmoid +import dimwit.optimizer.GradientDescent +import dimwit.python.PyBridge.toPyTensor +import dimwit.random.Random import dimwit.random.Random.Key +import dimwit.stats.Normal +import examples.dataset.MNISTLoader -import MNISTLoader.{Sample, TrainSample, TestSample, Height, Width} -import dimwit.python.PyBridge.toPyTensor +import MNISTLoader.Sample +import MNISTLoader.TrainSample +import MNISTLoader.TestSample +import MNISTLoader.Height +import MNISTLoader.Width type Pixel = Height |*| Width type ReconstructedPixel = Height |*| Width diff --git a/examples/src/main/scala/dataset/MNISTLoader.scala b/examples/src/main/scala/dataset/MNISTLoader.scala index 2fc0112e..2bcd8e9d 100644 --- a/examples/src/main/scala/dataset/MNISTLoader.scala +++ b/examples/src/main/scala/dataset/MNISTLoader.scala @@ -1,10 +1,9 @@ package examples.dataset -import dimwit.* import dimwit.Conversions.given - +import dimwit.* import me.shadaj.scalapy.py -import me.shadaj.scalapy.py.SeqConverters + import java.io.RandomAccessFile import scala.util.Try diff --git a/project/.scalafix.conf b/project/.scalafix.conf new file mode 100644 index 00000000..f184450f --- /dev/null +++ b/project/.scalafix.conf @@ -0,0 +1,21 @@ +rules = [ + OrganizeImports +] + +OrganizeImports { + groups = [ + "re:java\\..*" + "re:scala\\..*" + "re:javax\\..*" + "re:org\\.scala-lang\\..*" + "*" + ] + + importSelectorsOrder = Ascii + removeUnused = true + targetDialect = Scala3 + blankLines = Auto + expandRelative = false + groupSeparately = [ByTypeGivens] + groupedImports = Merge +} diff --git a/project/plugins.sbt b/project/plugins.sbt index 47b4a6e5..a1e890ba 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -3,4 +3,5 @@ addSbtPlugin("org.scalameta" % "sbt-mdoc" % "2.8.2") addSbtPlugin("org.scoverage" % "sbt-scoverage" % "2.3.0") addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "3.11.3") addSbtPlugin("com.github.sbt" % "sbt-pgp" % "2.3.1") +addSbtPlugin("ch.epfl.scala" % "sbt-scalafix" % "0.14.7") libraryDependencies += "ai.kien" %% "python-native-libs" % "0.2.2"