From 5b039a8e7868ba2948469327d281634fb2b250f8 Mon Sep 17 00:00:00 2001
From: Marcel Luethi
Date: Thu, 18 Jun 2026 13:16:01 +0200
Subject: [PATCH] add two basic examples that illustrate named tensors

Example 1: K-Means
Example 2: SIR-Simulation
---
 examples/src/main/scala/basic/KMeans.scala    | 185 ++++++++++++++++++
 .../src/main/scala/basic/SIRSimulation.scala  | 172 +++++++++++++++++
 2 files changed, 357 insertions(+)
 create mode 100644 examples/src/main/scala/basic/KMeans.scala
 create mode 100644 examples/src/main/scala/basic/SIRSimulation.scala

diff --git a/examples/src/main/scala/basic/KMeans.scala b/examples/src/main/scala/basic/KMeans.scala
new file mode 100644
index 0000000..23c3e48
--- /dev/null
+++ b/examples/src/main/scala/basic/KMeans.scala
@@ -0,0 +1,185 @@
+package examples.basic
+
+import dimwit.*
+import dimwit.Conversions.given
+import dimwit.random.Random
+import dimwit.stats.Normal
+
+/** Example implementation of the K-Means clustering algorithm, showcasing the use
+  * of Named Tensors for clear and type-safe code.
+  *
+  * @param nClusters The number of clusters (centroids) to fit; must be positive
+  */
+class KMeans(nClusters: Int):
+
+  require(nClusters > 0, s"nClusters must be positive, got $nClusters")
+
+  import KMeans.{Point, Dim, Cluster}
+
+  /** Compute the squared Euclidean distance from each point to each centroid.
+    *
+    * @param data The data points
+    * @param centroids The current centroids
+    * @return A tensor where entry (p, c) is the squared distance from point p to centroid c.
+    */
+  def squaredDistances(
+      data: Tensor2[Point, Dim, Float32],
+      centroids: Tensor2[Cluster, Dim, Float32]
+  ): Tensor2[Point, Cluster, Float32] =
+    data.vmap(Axis[Point]) { point =>
+      centroids.vmap(Axis[Cluster]) { centroid =>
+        (point - centroid).pow(2f).sum
+      }
+    }
+
+  /** Assign each point to the nearest centroid, returning a tensor of cluster indices.
+    *
+    * @param dists The distances between each point and each centroid
+    * @return A tensor where each entry is the index of the closest cluster for that point.
+    */
+  def assign(
+      dists: Tensor2[Point, Cluster, Float32]
+  ): Tensor1[Point, Int32] =
+    dists.argmin(Axis[Cluster])
+
+  /** Update the centroids by computing the mean of the points assigned to each cluster.
+    *
+    * NOTE: a cluster with no assigned points has a member count of zero, so its centroid
+    * is the result of a division by zero (presumably NaN; confirm against the backend).
+    *
+    * @param data The original data points
+    * @param assignments The cluster index assigned to each point
+    * @return A tensor of updated centroids, one for each cluster.
+    */
+  def updateCentroids(
+      data: Tensor2[Point, Dim, Float32],
+      assignments: Tensor1[Point, Int32]
+  ): Tensor2[Cluster, Dim, Float32] =
+    val centroids = (0 until nClusters).map { c =>
+      // True for the points assigned to cluster c, false for all others.
+      val isMember = assignments.elementEquals(Tensor.like(assignments).fill(c))
+      val belongs = isMember.asFloat32 // Bool -> Float32 for arithmetic
+
+      val nMembers: Tensor0[Float32] = belongs.sum(Axis[Point])
+      // Scale every point by its membership weight, so only members contribute to the sum.
+      val clusterSum: Tensor1[Dim, Float32] =
+        zipvmap(Axis[Point])(belongs, data) { (b, point) =>
+          point *! b
+        }.sum(Axis[Point])
+
+      clusterSum /! nMembers
+    }
+
+    stack(centroids, Axis[Cluster])
+
+  /** Run the K-Means algorithm on the given data, returning the final centroids and point assignments.
+    *
+    * The initial centroids are the data points selected by the first `nClusters` entries of a
+    * random permutation of the points. Iteration stops as soon as the assignments no longer
+    * change or `maxIter` centroid updates have been performed.
+    *
+    * @param data The input data points to cluster
+    * @param maxIter The maximum number of iterations to run
+    * @param seed The random seed for initialization
+    * @return The final centroids and point assignments
+    */
+  def run(
+      data: Tensor2[Point, Dim, Float32],
+      maxIter: Int = 50,
+      seed: Int = 0
+  ): (Tensor2[Cluster, Dim, Float32], Tensor1[Point, Int32]) =
+
+    val perm = Random.permutation(data.shape.extent(Axis[Point]))(Random.Key(seed))
+    val initCentroids: Tensor2[Cluster, Dim, Float32] =
+      data
+        .take(Axis[Point])(perm.slice(Axis[Point].at(0 until nClusters)))
+        .relabel(Axis[Point].as(Axis[Cluster]))
+
+    @annotation.tailrec
+    def loop(
+        centroids: Tensor2[Cluster, Dim, Float32],
+        prevAssignments: Option[Tensor1[Point, Int32]],
+        iter: Int
+    ): (Tensor2[Cluster, Dim, Float32], Tensor1[Point, Int32]) =
+
+      val assignments = assign(squaredDistances(data, centroids))
+      // Converged once two consecutive iterations produce identical assignments.
+      val converged = prevAssignments.exists(prev => (prev === assignments).item)
+
+      if converged || iter >= maxIter then
+        (centroids, assignments)
+      else
+        loop(
+          updateCentroids(data, assignments),
+          Some(assignments),
+          iter + 1
+        )
+
+    loop(initCentroids, None, 0)
+
+object KMeans:
+
+  trait Point derives Label // N data points to cluster
+  trait Dim derives Label // D feature dimensions (coordinates of each point)
+  trait Cluster derives Label // K centroids — the result of clustering
+
+  //
+  // Example usage
+  //
+  @main def runKMeans(): Unit =
+    dimwit.initialize()
+
+    val n = 50 // points per cluster
+    val k = 3 // clusters to discover
+    val nTotal = n * k // total number of points in the dataset
+
+    val rootKey = Random.Key(1337)
+    val (key0, rest) = rootKey.split2()
+    val (key1, key2) = rest.split2()
+
+    val centers = Array(
+      Array(-3.0f, 0.0f), // cloud 0: lower-left
+      Array(3.0f, 0.0f), // cloud 1: lower-right
+      Array(0.0f, 4.0f) // cloud 2: top-center
+    )
+    def cloud(center: Array[Float], key: Random.Key): Tensor2[Point, Dim, Float32] =
+      val noise = Normal.standardNormal(Shape(Axis[Point] -> n, Axis[Dim] -> 2)).sample(key)
+      noise +! Tensor1(Axis[Dim]).fromArray(center)
+
+    val data: Tensor2[Point, Dim, Float32] =
+      concatenate(
+        concatenate(cloud(centers(0), key0), cloud(centers(1), key1), Axis[Point]),
+        cloud(centers(2), key2),
+        Axis[Point]
+      )
+
+    println(s"Dataset: $nTotal points, 2 dimensions, $k true clusters\n")
+    val kmeans = KMeans(k)
+    val (centroids, assignments) = kmeans.run(data, maxIter = 30)
+
+    val inertia: Float =
+      kmeans.squaredDistances(data, centroids).vmap(Axis[Point])(_.min).sum.item
+
+    println(f"Inertia (within-cluster sum of squared distances): $inertia%.2f\n")
+
+    // ── Report discovered clusters ─────────────────────────────────────────────
+    println("Discovered centroids (k-means) vs true centers:")
+    println(f" ${"cluster"}%-9s ${"centroid (x,y)"}%-22s ${"true center"}%-16s size")
+    println(" " + "─" * 65)
+
+    (0 until k).foreach { c =>
+      // Extract the c-th centroid row: slice along the named Cluster axis.
+      val centroid = centroids.slice(Axis[Cluster].at(c))
+      val cx = centroid.slice(Axis[Dim].at(0)).item
+      val cy = centroid.slice(Axis[Dim].at(1)).item
+
+      // Cluster size: count points whose assignment equals c.
+      // elementEquals produces [Point, Bool]; asFloat32 then sum gives a count.
+      val clusterSize =
+        assignments
+          .elementEquals(Tensor.like(assignments).fill(c))
+          .asFloat32.sum.item.toInt
+
+      val tc = centers(c)
+      println(f" cluster $c%-7d ($cx%+6.2f, $cy%+6.2f) (${tc(0)}%+6.2f, ${tc(1)}%+6.2f) $clusterSize")
+    }
diff --git a/examples/src/main/scala/basic/SIRSimulation.scala b/examples/src/main/scala/basic/SIRSimulation.scala
new file mode 100644
index 0000000..0885966
--- /dev/null
+++ b/examples/src/main/scala/basic/SIRSimulation.scala
@@ -0,0 +1,172 @@
+package examples.basic
+
+import dimwit.*
+import dimwit.Conversions.given
+import dimwit.autodiff.*
+
+/** A simple SIR (Susceptible-Infectious-Recovered) simulation.
+  */
+object SIRSimulation:
+
+  trait Time derives Label
+  trait InfectiousGroup derives Label
+  trait SusceptibleGroup derives Label
+  trait Compartment derives Label
+
+  // We store all the compartments (S, I, R) in a separate tensor dimension (Compartment) and encode them
+  // using different IDs.
+  val SIndex = 0
+  val IIndex = 1
+  val RIndex = 2
+
+  /** One step of the simulation, according to the SIR model equations
+    *
+    * @param state The current state of the system, with shape [SusceptibleGroup, Compartment]
+    * @param beta The infection rate matrix (with entry beta[h, g] controlling how strongly infectious
+    *   individuals in group h infect susceptible individuals in group g)
+    * @param gamma The recovery rate
+    * @param dt The time step
+    * @return The next state of the system
+    */
+  def step(
+      state: Tensor2[SusceptibleGroup, Compartment, Float32],
+      beta: Tensor2[InfectiousGroup, SusceptibleGroup, Float32],
+      gamma: Tensor0[Float32],
+      dt: Tensor0[Float32]
+  ): Tensor2[SusceptibleGroup, Compartment, Float32] =
+
+    val S = state.slice(Axis[Compartment].at(SIndex))
+    val I = state.slice(Axis[Compartment].at(IIndex))
+    val R = state.slice(Axis[Compartment].at(RIndex))
+
+    val N = S + I + R // All individuals in the population
+
+    // Force of infection on group g: sum over h of (I_h / N_h) * beta[h, g]
+    val infectiousFraction = (I / N).relabel(Axis[SusceptibleGroup].as(Axis[InfectiousGroup]))
+    val force = infectiousFraction.dot(Axis[InfectiousGroup])(beta)
+    val newInfections = S * force
+
+    val recoveries = I *! gamma
+
+    // compute next state
+    val SNext = S - newInfections *! dt
+    val INext = I + (newInfections - recoveries) *! dt
+    val RNext = R + recoveries *! dt
+
+    // Transpose so the result matches the declared [SusceptibleGroup, Compartment] layout.
+    stack(Seq(SNext, INext, RNext), Axis[Compartment]).transpose
+
+  /** Run n steps of the simulation, starting from the initial state
+    *
+    * @param initial The initial state of the system
+    * @param beta @see [[step]]
+    * @param gamma @see [[step]]
+    * @param dt @see [[step]]
+    * @param nSteps The number of steps to simulate
+    * @return The trajectory of the system over time; it holds the initial state followed by one
+    *   state per step, i.e. nSteps + 1 entries along the Time axis
+    */
+  def simulate(
+      initial: Tensor2[SusceptibleGroup, Compartment, Float32],
+      beta: Tensor2[InfectiousGroup, SusceptibleGroup, Float32],
+      gamma: Tensor0[Float32],
+      dt: Tensor0[Float32],
+      nSteps: Int
+  ): Tensor3[Time, SusceptibleGroup, Compartment, Float32] =
+
+    val states: IndexedSeq[Tensor2[SusceptibleGroup, Compartment, Float32]] =
+      (0 until nSteps).scanLeft(initial): (state, _) =>
+        step(state, beta, gamma, dt)
+
+    stack(states, Axis[Time])
+
+  @main def runSIRSimulation(): Unit =
+    dimwit.initialize()
+
+    /*
+     * Three groups, coded by (0, 1, 2), which is e.g.
+     * children, adults, and elderly.
+     */
+    val initial: Tensor2[SusceptibleGroup, Compartment, Float32] =
+      Tensor(Shape(Axis[SusceptibleGroup] -> 3, Axis[Compartment] -> 3)).fromFunction(index =>
+        (index(Axis[SusceptibleGroup]), index(Axis[Compartment])) match
+          // children
+          case (0, 0) => 990f
+          case (0, 1) => 10f
+          case (0, 2) => 0f
+
+          // adults
+          case (1, 0) => 1995f
+          case (1, 1) => 5f
+          case (1, 2) => 0f
+
+          // elderly
+          case (2, 0) => 1500f
+          case (2, 1) => 0f
+          case (2, 2) => 0f
+
+          case _ => 0f
+      )
+
+    /*
+     * beta(h, g) controls how strongly infectious individuals in group h
+     * infect susceptible individuals in group g.
+     */
+    val beta: Tensor2[InfectiousGroup, SusceptibleGroup, Float32] =
+      Tensor(Shape(Axis[InfectiousGroup] -> 3, Axis[SusceptibleGroup] -> 3)).fromFunction(index =>
+        (index(Axis[InfectiousGroup]), index(Axis[SusceptibleGroup])) match
+          // infectious children -> susceptible children/adults/elderly
+          case (0, 0) => 0.40f
+          case (0, 1) => 0.20f
+          case (0, 2) => 0.10f
+
+          // infectious adults -> susceptible children/adults/elderly
+          case (1, 0) => 0.20f
+          case (1, 1) => 0.30f
+          case (1, 2) => 0.15f
+
+          // infectious elderly -> susceptible children/adults/elderly
+          case (2, 0) => 0.10f
+          case (2, 1) => 0.15f
+          case (2, 2) => 0.20f
+
+          case _ => 0f
+      )
+
+    val gamma = Tensor0(0.1f)
+    val dt = Tensor0(0.1f)
+    val nSteps = 160
+
+    val trajectory =
+      SIRSimulation.simulate(
+        initial = initial,
+        beta = beta,
+        gamma = gamma,
+        dt = dt,
+        nSteps = nSteps
+      )
+
+    /*
+     * Total infected population over time:
+     *
+     *   I_total(t) = sum_g I_g(t)
+     */
+    val infectedOverTime: Tensor1[Time, Float32] =
+      trajectory
+        .slice(Axis[Compartment].at(SIRSimulation.IIndex))
+        .sum(Axis[SusceptibleGroup])
+
+    // The trajectory includes the initial state, so index nSteps is the final state.
+    println(s"I(0) = ${infectedOverTime.slice(Axis[Time].at(0))}")
+    println(s"I(mid) = ${infectedOverTime.slice(Axis[Time].at(nSteps / 2))}")
+    println(s"I(end) = ${infectedOverTime.slice(Axis[Time].at(nSteps))}")
+
+    /*
+     * Infected population in each group at final time:
+     */
+    val finalInfectedByGroup: Tensor1[SusceptibleGroup, Float32] =
+      trajectory
+        .slice(Axis[Time].at(nSteps))
+        .slice(Axis[Compartment].at(SIRSimulation.IIndex))
+
+    println(s"Final infected by group: $finalInfectedByGroup")