Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,22 @@ ThisBuild / developers := List(
)
)

lazy val commonScalacOptions = Seq(
"-deprecation",
"-unchecked",
"-Wunused:imports",
"-explain-cyclic"
)

// Setup for Scalafix and SemanticDB
inThisBuild(Seq(
semanticdbEnabled := true,
semanticdbVersion := scalafixSemanticdb.revision
))

ThisBuild / scalafixDependencies +=
"com.github.liancheng" %% "organize-imports" % "0.6.0"

addCommandAlias("testAndCoverage", "; clean; coverage; test; coverageReport")

lazy val root = (project in file("."))
Expand Down Expand Up @@ -66,7 +82,8 @@ lazy val core = (project in file("core"))
coverageFailOnMinimum := false,
coverageHighlighting := true,
Compile / packageSrc / publishArtifact := true,
Compile / packageDoc / publishArtifact := true
Compile / packageDoc / publishArtifact := true,
scalacOptions ++= commonScalacOptions
)

// Examples subproject
Expand All @@ -87,6 +104,7 @@ lazy val examples = (project in file("examples"))
// Examples source directory
Compile / scalaSource := baseDirectory.value,
Compile / resourceDirectory := baseDirectory.value / "src" / "main" / "resources",
scalacOptions ++= commonScalacOptions,
scalafmtFailOnErrors := false,
javaOptions ++= {
if (sys.props("os.name").toLowerCase.contains("mac")) {
Expand Down
6 changes: 3 additions & 3 deletions core/src/main/scala/dimwit/MemoryHelper.scala
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
package dimwit

import me.shadaj.scalapy.py
import dimwit.autodiff.TensorTree
import me.shadaj.scalapy.py

private[dimwit] object MemoryHelper:

def withLocalCleanup(f: => Unit): Unit =
private[dimwit] def withLocalCleanupImpl(f: => Unit): Unit =
py.local:
f

def withLocalCleanup[A: TensorTree](f: => A): A =
private[dimwit] def withLocalCleanupImpl[A: TensorTree](f: => A): A =
val lifeRaft = me.shadaj.scalapy.py.Dynamic.global.list()
py.local:
val res = f
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/dimwit/OnError.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package dimwit

import java.io.StringWriter
import java.io.PrintWriter
import java.io.StringWriter

object OnError:

Expand Down
9 changes: 4 additions & 5 deletions core/src/main/scala/dimwit/autodiff/Autodiff.scala
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
package dimwit.autodiff

import dimwit.OnError
import dimwit.tensor.{Tensor, Tensor0, Tensor1, Tensor2, Shape}
import dimwit.tensor.ShapeTypeHelpers.AxisIndices
import dimwit.tensor.TupleHelpers.PrimeConcatType
import dimwit.jax.Jax
import me.shadaj.scalapy.py
import dimwit.tensor.Label
import dimwit.tensor.Tensor
import dimwit.tensor.Tensor0
import dimwit.tensor.TensorOps.IsFloating
import dimwit.tensor.TupleHelpers.PrimeConcatType
import me.shadaj.scalapy.py

object Autodiff:

Expand Down
4 changes: 2 additions & 2 deletions core/src/main/scala/dimwit/autodiff/FloatTree.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package dimwit.autodiff

import dimwit.tensor.*
import dimwit.tensor.TensorOps.*
import dimwit.tensor.*

import scala.deriving.*
import scala.compiletime.*
import scala.util.NotGiven

/** A marker trait for structures that are trees of floating-point tensors.
Expand Down
1 change: 1 addition & 0 deletions core/src/main/scala/dimwit/autodiff/Grad.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package dimwit.autodiff

import dimwit.*
import dimwit.jax.Jax

import scala.deriving.Mirror

/** Type-level tag marking a parameter structure as gradients.
Expand Down
6 changes: 3 additions & 3 deletions core/src/main/scala/dimwit/autodiff/TensorTree.scala
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package dimwit.autodiff

import dimwit.jax.Jax
import dimwit.tensor.*
import dimwit.random.Random
import scala.deriving.*
import scala.compiletime.*
import scala.deriving.*

import me.shadaj.scalapy.py
import me.shadaj.scalapy.py.SeqConverters
import dimwit.jax.Jax

/** A typeclass for structures that can be represented as a tree of tensors,
* which can be mapped over. Most often, a tensor tree is used to structure
Expand Down
5 changes: 0 additions & 5 deletions core/src/main/scala/dimwit/hardware/Device.scala
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
package dimwit.hardware

import dimwit.*
import dimwit.tensor.TupleHelpers
import dimwit.jax.Jax
import me.shadaj.scalapy.py
import me.shadaj.scalapy.py.SeqConverters
import me.shadaj.scalapy.readwrite.Writer
import me.shadaj.scalapy.interpreter.PyValue

case class Device private[dimwit] (private[dimwit] val jaxDevice: py.Dynamic):

Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/dimwit/hardware/DeviceBackend.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package dimwit.hardware

import me.shadaj.scalapy.py
import dimwit.jax.Jax
import me.shadaj.scalapy.py

enum DeviceBackend(private[dimwit] val jaxName: String):

Expand Down
1 change: 0 additions & 1 deletion core/src/main/scala/dimwit/jax/Einops.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package dimwit.jax

import me.shadaj.scalapy.py
import me.shadaj.scalapy.py.SeqConverters
import me.shadaj.scalapy.py.PyQuote

object Einops:

Expand Down
7 changes: 3 additions & 4 deletions core/src/main/scala/dimwit/jax/Jax.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
package dimwit.jax

import me.shadaj.scalapy.py
import me.shadaj.scalapy.py.SeqConverters
import me.shadaj.scalapy.py.PyQuote
import dimwit.hardware.{Device, DeviceBackend}
import dimwit.hardware.Device
import dimwit.hardware.DeviceBackend
import dimwit.python.PythonSetup
import me.shadaj.scalapy.py

object Jax:

Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/dimwit/jax/JaxDType.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package dimwit.jax

import me.shadaj.scalapy.py
import dimwit.tensor.DType
import me.shadaj.scalapy.py

object JaxDType:

Expand Down
11 changes: 5 additions & 6 deletions core/src/main/scala/dimwit/jax/Jit.scala
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
package dimwit.jax

import dimwit.tensor.{Tensor, Shape, Labels}
import dimwit.jax.{Jax, JaxDType}
import dimwit.OnError
import dimwit.autodiff.TensorTree
import dimwit.jax.Jax
import dimwit.jax.Jax.PyDynamic
import me.shadaj.scalapy.py
import me.shadaj.scalapy.py.SeqConverters
import dimwit.jax.Jax.PyDynamic
import me.shadaj.scalapy.py.PythonException
import dimwit.OnError

import scala.annotation.targetName

object Jit:
Expand Down Expand Up @@ -142,7 +141,7 @@ object JitDefault:

object EagerCleanup:

import dimwit.MemoryHelper.withLocalCleanup
import dimwit.withLocalCleanup

def eagerCleanup[T1, R: TensorTree](f: T1 => R): T1 => R = (t1) =>
withLocalCleanup:
Expand Down
7 changes: 4 additions & 3 deletions core/src/main/scala/dimwit/nn/ActivationFunctions.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package dimwit.nn

import dimwit.tensor.*
import dimwit.tensor.TensorOps.IsFloating
import dimwit.jax.Jax
import dimwit.python.PyBridge.{liftPyTensor, toPyTensor}
import dimwit.python.PyBridge.liftPyTensor
import dimwit.python.PyBridge.toPyTensor
import dimwit.tensor.TensorOps.IsFloating
import dimwit.tensor.*

object ActivationFunctions:

Expand Down
4 changes: 1 addition & 3 deletions core/src/main/scala/dimwit/optimizer/GradientOptimizer.scala
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
package dimwit.optimizer

import dimwit.*
import dimwit.autodiff.FloatTree.ops.*
import dimwit.autodiff.FloatTree.*
import dimwit.autodiff.FloatTree.ops.*
import dimwit.autodiff.*
import dimwit.jax.Jax
import dimwit.jax.Jit

/** Gradient optimizer interface with functional state management.
*
Expand Down
23 changes: 19 additions & 4 deletions core/src/main/scala/dimwit/package.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import scala.annotation.targetName

import dimwit.jax.Jax
import dimwit.tensor.{Axis, AxisExtent, AxisSelector, AxisAtIndex, AxisAtRange, AxisAtIndices, AxisAtTensorIndex}
import dimwit.tensor.Axis
import dimwit.tensor.AxisAtIndex
import dimwit.tensor.AxisAtIndices
import dimwit.tensor.AxisAtRange
import dimwit.tensor.AxisAtTensorIndex
import dimwit.tensor.AxisExtent
import dimwit.tensor.AxisSelector

import scala.annotation.targetName
package object dimwit:

import scala.compiletime.ops.string.+
Expand Down Expand Up @@ -93,7 +99,16 @@ package object dimwit:
// export some stats types
export dimwit.stats.{Prob, LogProb}
export dimwit.stats.{Distribution, IndependentDistribution, MultivariateDistribution, UnivariateDistribution}
export dimwit.MemoryHelper.withLocalCleanup

/** Memory management helpfer making sure
* all python objects allocated ar freed
* after the function is executed.
*/
def withLocalCleanup(f: => Unit): Unit =
MemoryHelper.withLocalCleanupImpl(f)

def withLocalCleanup[A: TensorTree](f: => A): A =
MemoryHelper.withLocalCleanupImpl(f)

/** Explicitly configures the Python environment before any ScalaPy call.
* Call this function at the start of your program (before any `py.*` call)
Expand Down
6 changes: 3 additions & 3 deletions core/src/main/scala/dimwit/python/PyBridge.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package dimwit.python

import dimwit.tensor.*
import dimwit.jax.Jax
import dimwit.autodiff.TensorTree
import dimwit.OnError
import dimwit.autodiff.TensorTree
import dimwit.jax.Jax
import dimwit.tensor.*
import me.shadaj.scalapy.py

object PyBridge:
Expand Down
4 changes: 3 additions & 1 deletion core/src/main/scala/dimwit/python/PythonSetup.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package dimwit.python

import me.shadaj.scalapy.py
import scala.sys.process.{Process, ProcessLogger}

import scala.sys.process.Process
import scala.sys.process.ProcessLogger

/** Manages Python environment setup for DimWit.
*
Expand Down
13 changes: 6 additions & 7 deletions core/src/main/scala/dimwit/random/Random.scala
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
package dimwit.random

import dimwit.tensor.*
import dimwit.tensor.DType.Int32
import dimwit.tensor.TensorOps.*
import dimwit.jax.{Jax, JaxDType}
import dimwit.autodiff.TensorTree
import me.shadaj.scalapy.py.SeqConverters
import dimwit.jax.Jax
import dimwit.python.PyBridge.liftPyTensor
import scala.compiletime.{requireConst, constValue, ops}
import Tuple.Size
import dimwit.tensor.DType.Int32
import dimwit.tensor.TensorOps.*
import dimwit.tensor.TupleHelpers.TupleNOf
import dimwit.tensor.*

import scala.compiletime.requireConst

/** JAX-based random number generation with proper key management.
*
Expand Down
3 changes: 0 additions & 3 deletions core/src/main/scala/dimwit/stats/Distributions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@ package dimwit.stats

import dimwit.*
import dimwit.random.Random
import dimwit.jax.Jax
import dimwit.jax.Jax.scipy_stats as jstats
import dimwit.jax.Jax.PyDynamic
import dimwit.tensor.TensorOps

opaque type LogProb = Float32
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
package dimwit.stats

import dimwit.*
import dimwit.DType.Float32
import dimwit.jax.Jax.scipy_stats as jstats
import dimwit.*
import dimwit.jax.Jax
import dimwit.jax.Jax.PyDynamic
import dimwit.jax.Jax.scipy_stats as jstats
import dimwit.python.PyBridge.liftPyTensor
import dimwit.random.Random
import me.shadaj.scalapy.py
import me.shadaj.scalapy.py.SeqConverters
import dimwit.random.Random
import dimwit.python.PyBridge.liftPyTensor

/** Normal (Gaussian) distribution */
class Normal[T <: Tuple: Labels, V: IsFloating](val loc: Tensor[T, V], val scale: Tensor[T, V]) extends IndependentDistribution[T, V]:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
package dimwit.stats

import dimwit.*
import dimwit.random.Random
import dimwit.jax.Jax
import dimwit.jax.Jax.scipy_stats as jstats
import dimwit.jax.Jax.PyDynamic
import dimwit.python.PyBridge.liftPyTensor
import me.shadaj.scalapy.readwrite.Reader
import dimwit.random.Random

/** Distribution over a vector of random variables.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package dimwit.stats

import dimwit.DType.Float32
import dimwit.DType.Int32
import dimwit.*
import dimwit.DType.{Int32, Float32}
import dimwit.random.*
import dimwit.jax.Jax
import dimwit.python.PyBridge.liftPyTensor

Expand Down
3 changes: 2 additions & 1 deletion core/src/main/scala/dimwit/tensor/ArrayReader.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package dimwit.tensor

import me.shadaj.scalapy.py

import java.nio.ByteBuffer
import java.nio.ByteOrder
import me.shadaj.scalapy.py

/** Utility object for reading flat arrays of scalar values from JAX tensors.
*/
Expand Down
12 changes: 6 additions & 6 deletions core/src/main/scala/dimwit/tensor/ArrayWriter.scala
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
package dimwit.tensor

import java.nio.ByteBuffer
import java.util.Base64
import java.nio.ByteOrder
import dimwit.jax.Jax
import dimwit.tensor.TensorOps.IsBoolean
import dimwit.tensor.TensorOps.IsFloating
import dimwit.tensor.TensorOps.IsInteger
import me.shadaj.scalapy.py
import me.shadaj.scalapy.py.SeqConverters
import me.shadaj.scalapy.readwrite.Writer
import me.shadaj.scalapy.interpreter.PyValue
import dimwit.jax.Jax
import dimwit.tensor.TensorOps.{IsBoolean, IsInteger, IsFloating}

import java.util.Base64

object ArrayWriter:

Expand Down
Loading
Loading