diff --git a/core/src/main/scala/dimwit/tensor/TensorOps.scala b/core/src/main/scala/dimwit/tensor/TensorOps.scala index d16d665..c86b6af 100644 --- a/core/src/main/scala/dimwit/tensor/TensorOps.scala +++ b/core/src/main/scala/dimwit/tensor/TensorOps.scala @@ -252,6 +252,16 @@ object TensorOps: def median[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.median(t.jaxValue, axis = ev.index)) def median[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.median(t.jaxValue, axis = ev.indices.toPythonProxy)) + // --- Mean --- + def nanmean: Tensor0[V] = Tensor0(Jax.jnp.nanmean(t.jaxValue)) + def nanmean[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.nanmean(t.jaxValue, axis = ev.index)) + def nanmean[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.nanmean(t.jaxValue, axis = ev.indices.toPythonProxy)) + + // --- Median --- + def nanmedian: Tensor0[V] = Tensor0(Jax.jnp.nanmedian(t.jaxValue)) + def nanmedian[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.nanmedian(t.jaxValue, axis = ev.index)) + def nanmedian[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.nanmedian(t.jaxValue, axis = ev.indices.toPythonProxy)) + end Reduction object Contraction: diff --git a/examples/src/main/scala/basic/KMeans.scala b/examples/src/main/scala/basic/KMeans.scala index 23c3e48..d3b2635 100644 --- a/examples/src/main/scala/basic/KMeans.scala +++ b/examples/src/main/scala/basic/KMeans.scala @@ -1,174 +1,210 @@ -package examples.basic +package examples.basic.kmeans import dimwit.* import dimwit.Conversions.given import dimwit.random.Random import dimwit.stats.Normal -/** Example implementation of the K-Means clustering algorithm, showcasing the use - * of Named Tensors for clear and type-safe code. - */ -class KMeans(nClusters: Int): +trait Point derives Label // N data points to cluster +trait Dim derives Label // D feature dimensions (coordinates of each point) +trait Cluster derives Label // K centroids — the result of clustering - import KMeans.{Point, Dim, Cluster} +trait CenterBasedClustering( + val nClusters: Int +): - /** Compute the squared Euclidean distance from each point to each centroid. - * @return A tensor where entry (p, c) is the squared distance from point p to centroid c. + /** Computes the distance between a point and a cluster center. + * @param point A single data point + * @param clusterCenter A single cluster center + * @return The distance between the point and the cluster center. */ - def squaredDistances( - data: Tensor2[Point, Dim, Float32], - centroids: Tensor2[Cluster, Dim, Float32] + def distance(point: Tensor1[Dim, Float32], clusterCenter: Tensor1[Dim, Float32]): Tensor0[Float32] + + /** Computes the new center given the clusterPoints + * @param clusterPoints Points assigned to a cluster, NaN for points not assigned to this cluster. + * @return The new cluster center computed from the assigned points. + */ + def calcCenter(clusterPoints: Tensor2[Point, Dim, Float32]): Tensor1[Dim, Float32] + + /** Compute the pairwise distances between each point and each center. + * @param points The data points + * @param center The cluster centers + * @return A tensor of shape [Point, Cluster] containing the distances from each point to each cluster center. + */ + def pairwiseDistances( + points: Tensor2[Point, Dim, Float32], + centers: Tensor2[Cluster, Dim, Float32] ): Tensor2[Point, Cluster, Float32] = - data.vmap(Axis[Point]) { point => - centroids.vmap(Axis[Cluster]) { centroid => - (point - centroid).pow(2f).sum - } - } + points.vmap(Axis[Point]): point => + centers.vmap(Axis[Cluster]): center => + distance(point, center) - /** Assign each point to the nearest centroid, returning a tensor of cluster indices. - * - * @param dists The distances between each point and its cluster - * @return A tensor where each entry is the index of the closest cluster for that point. + /** Assign each point to the nearest center. + * @param points The data points. + * @param center The cluster centers. */ - def assign( - dists: Tensor2[Point, Cluster, Float32] + def assignPoints( + points: Tensor2[Point, Dim, Float32], + centers: Tensor2[Cluster, Dim, Float32] ): Tensor1[Point, Int32] = - dists.argmin(Axis[Cluster]) + val distances = pairwiseDistances(points, centers) + distances.argmin(Axis[Cluster]) - /** Update the centroids by computing the mean of the points assigned to each cluster. + /** Update the centers by computing the mean of the points assigned to each cluster. * - * @param data The original data points + * @param points The data points * @param assignments The cluster index assigned to each point - * @param nPoints The total number of data points - * @param nClusters The total number of clusters (centroids) - * @return A tensor of updated centroids, one for each cluster. + * @return The new cluster centers */ - def updateCentroids( - data: Tensor2[Point, Dim, Float32], + private def calcCenters( + points: Tensor2[Point, Dim, Float32], assignments: Tensor1[Point, Int32] ): Tensor2[Cluster, Dim, Float32] = - val nPoints = data.shape.extent(Axis[Point]).size + val centers = (0 until nClusters).map: clusterIndex => + val clusterMembershipMask = assignments.elementEquals(Tensor.like(assignments).fill(clusterIndex)) + calcCenter(where(clusterMembershipMask.broadcastTo(points.shape), points, Tensor.like(points).fill(Float.NaN))) - val centroids = (0 until nClusters).map { c => + stack(centers, Axis[Cluster]) - val isMember = assignments.elementEquals(Tensor.like(assignments).fill(c)) - val belongs = isMember.asFloat32 // Bool -> Float32 for arithmetic + /** Perform one iteration of the EM algorithm: assign points to nearest centers (E-step) and then update centers (M-step). + * @param points The data points + * @param centers The current cluster centers + * @return The updated cluster centers + */ + private def emStep( + points: Tensor2[Point, Dim, Float32], + centers: Tensor2[Cluster, Dim, Float32] + ): Tensor2[Cluster, Dim, Float32] = + val assignments = assignPoints(points, centers) // E-step: Assign points to nearest centers + calcCenters(points, assignments) // M-step: Update centers based on assignments - val nMembers: Tensor0[Float32] = belongs.sum(Axis[Point]) - val clusterSum: Tensor1[Dim, Float32] = - zipvmap(Axis[Point])(belongs, data) { (b, point) => - point *! b - }.sum(Axis[Point]) + /** Run the trajectory of the K-Means algorithm + * @param points The data points + * @param centers The initial cluster centers (starting point) + * @return Cluster centers at each iteration, ending with the final converged centers. + */ + def run( + points: Tensor2[Point, Dim, Float32], + centers: Tensor2[Cluster, Dim, Float32] + ): LazyList[Tensor2[Cluster, Dim, Float32]] = + val nextCenters = emStep(points, centers) + val converged = nextCenters.elementEquals(centers).all.item + centers #:: (if converged then LazyList.empty else run(points, nextCenters)) + + /** Main entry point to run K-Means clustering. + * @param points The data points + * @param key A random key for random center initialization + * @param maxIterations Maximum number of iterations to run (in case of non-convergence) + * @return The final cluster centers (after convergence or reaching max iterations) + */ + def apply( + points: Tensor2[Point, Dim, Float32], + key: Random.Key, + maxIterations: Int + ): Tensor2[Cluster, Dim, Float32] = + val steps = run(points, initializeCenters(points, key)) + steps.take(maxIterations).last - clusterSum /! nMembers - } + /** Initialize cluster centers by randomly selecting points from the dataset. + * @param points The data points + * @param key A random key for random selection + * @return Initial cluster centers randomly chosen from the data points + */ + private def initializeCenters( + points: Tensor2[Point, Dim, Float32], + key: Random.Key + ): Tensor2[Cluster, Dim, Float32] = + val perm = Random.permutation(points.shape.extent(Axis[Point]))(key) + points.take(Axis[Point])(perm.slice(Axis[Point].at(0 until nClusters))).relabel(Axis[Point].as(Axis[Cluster])) - stack(centroids, Axis[Cluster]) +/** Example implementation of the K-Means clustering algorithm, showcasing the use + * of Named Tensors for clear and type-safe code. + */ +class KMeans(nClusters: Int) extends CenterBasedClustering(nClusters): - /** Run the K-Means algorithm on the given data, returning the final centroids and point assignments. - * @param data The input data points to cluster - * @param maxIter The maximum number of iterations to run - * @param seed The random seed for initialization - * @return The final centroids and point assignments - */ - def run( - data: Tensor2[Point, Dim, Float32], - maxIter: Int = 50, - seed: Int = 0 - ): (Tensor2[Cluster, Dim, Float32], Tensor1[Point, Int32]) = - - val nPoints = data.shape.extent(Axis[Point]).size - - val perm = Random.permutation(data.shape.extent(Axis[Point]))(Random.Key(seed)) - val initCentroids: Tensor2[Cluster, Dim, Float32] = - data - .take(Axis[Point])(perm.slice(Axis[Point].at(0 until nClusters))) - .relabel(Axis[Point].as(Axis[Cluster])) - - @annotation.tailrec - def loop( - centroids: Tensor2[Cluster, Dim, Float32], - prevAssignments: Option[Tensor1[Point, Int32]], - iter: Int - ): (Tensor2[Cluster, Dim, Float32], Tensor1[Point, Int32]) = - - val assignments = assign(squaredDistances(data, centroids)) - val converged = prevAssignments.exists(prev => (prev === assignments).item) - - if converged || iter >= maxIter then - (centroids, assignments) - else - loop( - updateCentroids(data, assignments), - Some(assignments), - iter + 1 - ) - - loop(initCentroids, None, 0) + override def distance(point: Tensor1[Dim, Float32], centroid: Tensor1[Dim, Float32]): Tensor0[Float32] = + (point - centroid).pow(2).sum -object KMeans: + override def calcCenter(clusterPoints: Tensor2[Point, Dim, Float32]): Tensor1[Dim, Float32] = + clusterPoints.nanmean(Axis[Point]) + +class KMedians(nClusters: Int) extends CenterBasedClustering(nClusters): + + override def distance(point: Tensor1[Dim, Float32], centroid: Tensor1[Dim, Float32]): Tensor0[Float32] = + (point - centroid).abs.sum - trait Point derives Label // N data points to cluster - trait Dim derives Label // D feature dimensions (coordinates of each point) - trait Cluster derives Label // K centroids — the result of clustering + override def calcCenter(clusterPoints: Tensor2[Point, Dim, Float32]): Tensor1[Dim, Float32] = + clusterPoints.nanmedian(Axis[Point]) + +object KMeans: // // Example usage // - @main def runKMeans(): Unit = - dimwit.initialize() + @main def run(): Unit = val n = 50 // points per cluster val k = 3 // clusters to discover - val rootKey = Random.Key(1337) - val (key0, rest) = rootKey.split2() - val (key1, key2) = rest.split2() - - val centers = Array( + val trueCenters = Array( Array(-3.0f, 0.0f), // cloud 0: lower-left Array(3.0f, 0.0f), // cloud 1: lower-right Array(0.0f, 4.0f) // cloud 2: top-center ) - def cloud(center: Array[Float], key: Random.Key): Tensor2[Point, Dim, Float32] = - val noise = Normal.standardNormal(Shape(Axis[Point] -> n, Axis[Dim] -> 2)).sample(key) - noise +! Tensor1(Axis[Dim]).fromArray(center) - val data: Tensor2[Point, Dim, Float32] = + def evaluate(alg: CenterBasedClustering, name: String, points: Tensor2[Point, Dim, Float32], trainKey: Random.Key): Unit = + val centers = alg(points, trainKey, maxIterations = 50) + val assignments = alg.assignPoints(points, centers) + + val inertia = alg.pairwiseDistances(points, centers).vmap(Axis[Point])(_.min).sum.item + + println(f"Inertia (within-cluster sum of distances): $inertia%.2f\n") + + // ── Report discovered clusters ───────────────────────────────────────────── + val nTotal = n * k + println(s"Discovered centers (${name}) vs true centers:") + println(f" ${"cluster"}%-9s ${"centroid (x,y)"}%-22s ${"true center"}%-16s size") + println(" " + "─" * 65) + + (0 until k).foreach: c => + // Extract the c-th centroid row: slice along the named Cluster axis. + val center = centers.slice(Axis[Cluster].at(c)) + val cx = center.slice(Axis[Dim].at(0)).item + val cy = center.slice(Axis[Dim].at(1)).item + + // Cluster size: count points whose assignment equals c. + // elementEquals produces [Point, Bool]; asFloat32 then sum gives a count. + val clusterSize = + assignments + .elementEquals(Tensor(Shape(Axis[Point] -> nTotal)).fill(c)) + .asFloat32.sum.item.toInt + + val tc = trueCenters(c) + println(f" cluster $c%-7d ($cx%+6.2f, $cy%+6.2f) (${tc(0)}%+6.2f, ${tc(1)}%+6.2f) $clusterSize") + + def prepareData(dataKey: Random.Key): Tensor2[Point, Dim, Float32] = + def cloud(center: Array[Float], key: Random.Key): Tensor2[Point, Dim, Float32] = + val noise = Normal.standardNormal(Shape(Axis[Point] -> n, Axis[Dim] -> center.length)).sample(key) + Tensor1(Axis[Dim]).fromArray(center) +! noise + + val (key0, key1, key2) = dataKey.splitToTuple(3) concatenate( - concatenate(cloud(centers(0), key0), cloud(centers(1), key1), Axis[Point]), - cloud(centers(2), key2), + List( + cloud(trueCenters(0), key0), + cloud(trueCenters(1), key1), + cloud(trueCenters(2), key2) + ), Axis[Point] ) + dimwit.initialize() + + val rootKey = Random.Key(1337) + val (dataKey, trainKey) = rootKey.split2() + + val points = prepareData(dataKey) + println(s"Dataset: ${n * k} points, 2 dimensions, $k true clusters\n") - val kmeans = KMeans(k) - val (centroids, assignments) = kmeans.run(data, maxIter = 30) - - val inertia: Float = - kmeans.squaredDistances(data, centroids).vmap(Axis[Point])(_.min).sum.item - - println(f"Inertia (within-cluster sum of squared distances): $inertia%.2f\n") - - // ── Report discovered clusters ───────────────────────────────────────────── - val nTotal = n * k - println("Discovered centroids (k-means) vs true centers:") - println(f" ${"cluster"}%-9s ${"centroid (x,y)"}%-22s ${"true center"}%-16s size") - println(" " + "─" * 65) - - (0 until k).foreach { c => - // Extract the c-th centroid row: slice along the named Cluster axis. - val centroid = centroids.slice(Axis[Cluster].at(c)) - val cx = centroid.slice(Axis[Dim].at(0)).item - val cy = centroid.slice(Axis[Dim].at(1)).item - - // Cluster size: count points whose assignment equals c. - // elementEquals produces [Point, Bool]; asFloat32 then sum gives a count. - val clusterSize = - assignments - .elementEquals(Tensor(Shape(Axis[Point] -> nTotal)).fill(c)) - .asFloat32.sum.item.toInt - - val tc = centers(c) - println(f" cluster $c%-7d ($cx%+6.2f, $cy%+6.2f) (${tc(0)}%+6.2f, ${tc(1)}%+6.2f) $clusterSize") - } + val (kmeansKey, kmediansKey) = trainKey.split2() + evaluate(new KMeans(k), "K-Means", points, kmeansKey) + evaluate(new KMedians(k), "K-Median", points, kmediansKey) diff --git a/examples/src/main/scala/basic/SIRSimulation.scala b/examples/src/main/scala/basic/SIRSimulation.scala index 0885966..8ccf4e3 100644 --- a/examples/src/main/scala/basic/SIRSimulation.scala +++ b/examples/src/main/scala/basic/SIRSimulation.scala @@ -9,50 +9,44 @@ import dimwit.autodiff.* object SIRSimulation: trait Time derives Label - trait InfectiousGroup derives Label - trait SusceptibleGroup derives Label - trait Compartment derives Label + trait Group derives Label - // We store all the compartments (S, I, R) in a separate tensor dimension (Compartment) and encode them - // using different IDs. - val SIndex = 0 - val IIndex = 1 - val RIndex = 2 + // Explicit state representation replaces the Compartment dimension + case class SIRState( + S: Tensor1[Group, Float32], + I: Tensor1[Group, Float32], + R: Tensor1[Group, Float32] + ): + lazy val N: Tensor1[Group, Float32] = S + I + R - /** One step of the simulation, according to the SIR model equations + /** One step of the simulation, according to the SIR model equations * - * @param state The current state of the system, with shape [SusceptibleGroup, Compartment] - * @param beta The infection rate matrix (with entry beta[h, g] controlling how strongly infectious - * individuals in group h infect susceptible individuals in group g) - * @param gamma The recovery rate - * @param dt The time step - * @return The next state of the system + * @param state The current state of the system + * @param beta The infection rate matrix (with entry beta[h, g] controlling how strongly infectious + * individuals in group h infect susceptible individuals in group g) + * @param gamma The recovery rate + * @param dt The time step + * @return The next state of the system */ def step( - state: Tensor2[SusceptibleGroup, Compartment, Float32], - beta: Tensor2[InfectiousGroup, SusceptibleGroup, Float32], + beta: Tensor2[Group, Prime[Group], Float32], gamma: Tensor0[Float32], dt: Tensor0[Float32] - ): Tensor2[SusceptibleGroup, Compartment, Float32] = + )(state: SIRState): SIRState = + import state.{S, I, R, N} - val S = state.slice(Axis[Compartment].at(SIndex)) - val I = state.slice(Axis[Compartment].at(IIndex)) - val R = state.slice(Axis[Compartment].at(RIndex)) - - val N = S + I + R // All individuals in the population - - val infectiousFraction = (I / N).relabel(Axis[SusceptibleGroup].as(Axis[InfectiousGroup])) - val force = infectiousFraction.dot(Axis[InfectiousGroup])(beta) - val newInfections = S * force + val infectiousFraction = I / N + val force = infectiousFraction.dot(Axis[Group])(beta) + val transmissions = S * force.dropPrimes val recoveries = I *! gamma // compute next state - val SNext = S - newInfections *! dt - val INext = I + (newInfections - recoveries) *! dt + val SNext = S - transmissions *! dt + val INext = I + (transmissions - recoveries) *! dt val RNext = R + recoveries *! dt - stack(Seq(SNext, INext, RNext), Axis[Compartment]).transpose + SIRState(SNext, INext, RNext) /** run n steps of the simulation, starting from the initial state * @@ -64,59 +58,45 @@ object SIRSimulation: * @return The trajectory of the system over time */ def simulate( - initial: Tensor2[SusceptibleGroup, Compartment, Float32], - beta: Tensor2[InfectiousGroup, SusceptibleGroup, Float32], + initial: SIRState, + beta: Tensor2[Group, Prime[Group], Float32], gamma: Tensor0[Float32], - dt: Tensor0[Float32], - nSteps: Int - ): Tensor3[Time, SusceptibleGroup, Compartment, Float32] = - - val states: IndexedSeq[Tensor2[SusceptibleGroup, Compartment, Float32]] = - (0 until nSteps).scanLeft(initial): (state, _) => - step(state, beta, gamma, dt) - - stack(states, Axis[Time]) + dt: Tensor0[Float32] + ): LazyList[SIRState] = + LazyList.iterate(initial)(step(beta, gamma, dt)) @main def runSIRSimulation(): Unit = dimwit.initialize() - val susceptibleGroupDim = Axis[SusceptibleGroup] -> 3 - val infectiousGroupDim = Axis[InfectiousGroup] -> 3 - val compartmentDim = Axis[Compartment] -> 3 + val groupDim = Axis[Group] -> 3 - /* - * Three groups, coded by (0, 1, 2), which is e.g. - * children, adults, and elderly. - */ - val initial: Tensor2[SusceptibleGroup, Compartment, Float32] = - Tensor(Shape(Axis[SusceptibleGroup] -> 3, Axis[Compartment] -> 3)).fromFunction(index => - (index(Axis[SusceptibleGroup]), index(Axis[Compartment])) match - // children - case (0, 0) => 990f - case (0, 1) => 10f - case (0, 2) => 0f - - // adults - case (1, 0) => 1995f - case (1, 1) => 5f - case (1, 2) => 0f - - // elderly - case (2, 0) => 1500f - case (2, 1) => 0f - case (2, 2) => 0f - - case _ => 0f - ) + // Setup initial state of the system + + val initialS = Tensor(Shape(groupDim)).fromFunction(index => + index(Axis[Group]) match + case 0 => 990f + case 1 => 1995f + case 2 => 1500f + ) + val initialI = Tensor(Shape(groupDim)).fromFunction(index => + index(Axis[Group]) match + case 0 => 10f + case 1 => 5f + case 2 => 0f + ) + val initialR = Tensor(Shape(groupDim)).fromFunction(_ => 0f) + val initial = SIRState(initialS, initialI, initialR) + + // Setup parameters for the simulation /* - * * beta(h, g) controls how strongly infectious individuals in group h * infect susceptible individuals in group g. */ - val beta: Tensor2[InfectiousGroup, SusceptibleGroup, Float32] = - Tensor(Shape(Axis[InfectiousGroup] -> 3, Axis[SusceptibleGroup] -> 3)).fromFunction(index => - (index(Axis[InfectiousGroup]), index(Axis[SusceptibleGroup])) match // infectious children -> susceptible children/adults/elderly + val beta: Tensor2[Group, Prime[Group], Float32] = + Tensor(Shape(groupDim, Axis[Prime[Group]] -> groupDim.size)).fromFunction(index => + (index(Axis[Group]), index(Axis[Prime[Group]])) match + // infectious children -> susceptible children/adults/elderly case (0, 0) => 0.40f case (0, 1) => 0.20f case (0, 2) => 0.10f @@ -130,42 +110,31 @@ object SIRSimulation: case (2, 0) => 0.10f case (2, 1) => 0.15f case (2, 2) => 0.20f - case _ => 0f - ) + case (_, _) => throw new IllegalArgumentException("Invalid group indices") + ) val gamma = Tensor0(0.1f) val dt = Tensor0(0.1f) - val nSteps = 160 + // Run the simulation + + val nSteps = 160 val trajectory = SIRSimulation.simulate( initial = initial, beta = beta, gamma = gamma, - dt = dt, - nSteps = nSteps - ) + dt = dt + ).take(nSteps + 1).toList - /* - * Total infected population over time: - * - * I_total(t) = sum_g I_g(t) - */ - val infectedOverTime: Tensor1[Time, Float32] = - trajectory - .slice(Axis[Compartment].at(SIRSimulation.IIndex)) - .sum(Axis[SusceptibleGroup]) + // Report the results + val infectedOverTime: Tensor1[Time, Float32] = + stack(trajectory.map(_.I), Axis[Time]).sum(Axis[Group]) println(s"I(0) = ${infectedOverTime.slice(Axis[Time].at(0))}") println(s"I(mid) = ${infectedOverTime.slice(Axis[Time].at(nSteps / 2))}") println(s"I(end) = ${infectedOverTime.slice(Axis[Time].at(nSteps))}") - /* - * Infected population in each group at final time: - */ - val finalInfectedByGroup: Tensor1[SusceptibleGroup, Float32] = - trajectory - .slice(Axis[Time].at(nSteps)) - .slice(Axis[Compartment].at(SIRSimulation.IIndex)) - + val finalInfectedByGroup: Tensor1[Group, Float32] = + stack(trajectory.map(_.I), Axis[Time]).slice(Axis[Time].at(nSteps)) println(s"Final infected by group: $finalInfectedByGroup")