Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions core/src/main/scala/dimwit/tensor/TensorOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,16 @@ object TensorOps:
def median[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.median(t.jaxValue, axis = ev.index))
def median[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.median(t.jaxValue, axis = ev.indices.toPythonProxy))

// --- Mean ---
def nanmean: Tensor0[V] = Tensor0(Jax.jnp.nanmean(t.jaxValue))
def nanmean[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.nanmean(t.jaxValue, axis = ev.index))
def nanmean[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.nanmean(t.jaxValue, axis = ev.indices.toPythonProxy))

// --- Median ---
def nanmedian: Tensor0[V] = Tensor0(Jax.jnp.nanmedian(t.jaxValue))
def nanmedian[L: Label](axis: Axis[L])(using ev: AxisRemover[T, L], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.nanmedian(t.jaxValue, axis = ev.index))
def nanmedian[Inputs <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs]], l: Labels[ev.RemainingAxes]): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.nanmedian(t.jaxValue, axis = ev.indices.toPythonProxy))

end Reduction

object Contraction:
Expand Down
294 changes: 165 additions & 129 deletions examples/src/main/scala/basic/KMeans.scala
Original file line number Diff line number Diff line change
@@ -1,174 +1,210 @@
package examples.basic
package examples.basic.kmeans

import dimwit.*
import dimwit.Conversions.given
import dimwit.random.Random
import dimwit.stats.Normal

/** Example implementation of the K-Means clustering algorithm, showcasing the use
* of Named Tensors for clear and type-safe code.
*/
class KMeans(nClusters: Int):
trait Point derives Label // N data points to cluster
trait Dim derives Label // D feature dimensions (coordinates of each point)
trait Cluster derives Label // K centroids — the result of clustering

import KMeans.{Point, Dim, Cluster}
trait CenterBasedClustering(
val nClusters: Int
):

/** Compute the squared Euclidean distance from each point to each centroid.
* @return A tensor where entry (p, c) is the squared distance from point p to centroid c.
/** Computes the distance between a point and a cluster center.
* @param point A single data point
* @param clusterCenter A single cluster center
* @return The distance between the point and the cluster center.
*/
def squaredDistances(
data: Tensor2[Point, Dim, Float32],
centroids: Tensor2[Cluster, Dim, Float32]
def distance(point: Tensor1[Dim, Float32], clusterCenter: Tensor1[Dim, Float32]): Tensor0[Float32]

/** Computes the new center given the clusterPoints
* @param clusterPoints Points assigned to a cluster, NaN for points not assigned to this cluster.
* @return The new cluster center computed from the assigned points.
*/
def calcCenter(clusterPoints: Tensor2[Point, Dim, Float32]): Tensor1[Dim, Float32]

/** Compute the pairwise distances between each point and each center.
* @param points The data points
* @param center The cluster centers
* @return A tensor of shape [Point, Cluster] containing the distances from each point to each cluster center.
*/
def pairwiseDistances(
points: Tensor2[Point, Dim, Float32],
centers: Tensor2[Cluster, Dim, Float32]
): Tensor2[Point, Cluster, Float32] =
data.vmap(Axis[Point]) { point =>
centroids.vmap(Axis[Cluster]) { centroid =>
(point - centroid).pow(2f).sum
}
}
points.vmap(Axis[Point]): point =>
centers.vmap(Axis[Cluster]): center =>
distance(point, center)

/** Assign each point to the nearest centroid, returning a tensor of cluster indices.
*
* @param dists The distances between each point and its cluster
* @return A tensor where each entry is the index of the closest cluster for that point.
/** Assign each point to the nearest center.
* @param points The data points.
* @param center The cluster centers.
*/
def assign(
dists: Tensor2[Point, Cluster, Float32]
def assignPoints(
points: Tensor2[Point, Dim, Float32],
centers: Tensor2[Cluster, Dim, Float32]
): Tensor1[Point, Int32] =
dists.argmin(Axis[Cluster])
val distances = pairwiseDistances(points, centers)
distances.argmin(Axis[Cluster])

/** Update the centroids by computing the mean of the points assigned to each cluster.
/** Update the centers by computing the mean of the points assigned to each cluster.
*
* @param data The original data points
* @param points The data points
* @param assignments The cluster index assigned to each point
* @param nPoints The total number of data points
* @param nClusters The total number of clusters (centroids)
* @return A tensor of updated centroids, one for each cluster.
* @return The new cluster centers
*/
def updateCentroids(
data: Tensor2[Point, Dim, Float32],
private def calcCenters(
points: Tensor2[Point, Dim, Float32],
assignments: Tensor1[Point, Int32]
): Tensor2[Cluster, Dim, Float32] =
val nPoints = data.shape.extent(Axis[Point]).size
val centers = (0 until nClusters).map: clusterIndex =>
val clusterMembershipMask = assignments.elementEquals(Tensor.like(assignments).fill(clusterIndex))
calcCenter(where(clusterMembershipMask.broadcastTo(points.shape), points, Tensor.like(points).fill(Float.NaN)))

val centroids = (0 until nClusters).map { c =>
stack(centers, Axis[Cluster])

val isMember = assignments.elementEquals(Tensor.like(assignments).fill(c))
val belongs = isMember.asFloat32 // Bool -> Float32 for arithmetic
/** Perform one iteration of the EM algorithm: assign points to nearest centers (E-step) and then update centers (M-step).
* @param points The data points
* @param centers The current cluster centers
* @return The updated cluster centers
*/
private def emStep(
points: Tensor2[Point, Dim, Float32],
centers: Tensor2[Cluster, Dim, Float32]
): Tensor2[Cluster, Dim, Float32] =
val assignments = assignPoints(points, centers) // E-step: Assign points to nearest centers
calcCenters(points, assignments) // M-step: Update centers based on assignments

val nMembers: Tensor0[Float32] = belongs.sum(Axis[Point])
val clusterSum: Tensor1[Dim, Float32] =
zipvmap(Axis[Point])(belongs, data) { (b, point) =>
point *! b
}.sum(Axis[Point])
/** Run the trajectory of the K-Means algorithm
* @param points The data points
* @param centers The initial cluster centers (starting point)
* @return Cluster centers at each iteration, ending with the final converged centers.
*/
def run(
points: Tensor2[Point, Dim, Float32],
centers: Tensor2[Cluster, Dim, Float32]
): LazyList[Tensor2[Cluster, Dim, Float32]] =
val nextCenters = emStep(points, centers)
val converged = nextCenters.elementEquals(centers).all.item
centers #:: (if converged then LazyList.empty else run(points, nextCenters))

/** Main entry point to run K-Means clustering.
* @param points The data points
* @param key A random key for random center initialization
* @param maxIterations Maximum number of iterations to run (in case of non-convergence)
* @return The final cluster centers (after convergence or reaching max iterations)
*/
def apply(
points: Tensor2[Point, Dim, Float32],
key: Random.Key,
maxIterations: Int
): Tensor2[Cluster, Dim, Float32] =
val steps = run(points, initializeCenters(points, key))
steps.take(maxIterations).last

clusterSum /! nMembers
}
/** Initialize cluster centers by randomly selecting points from the dataset.
* @param points The data points
* @param key A random key for random selection
* @return Initial cluster centers randomly chosen from the data points
*/
private def initializeCenters(
points: Tensor2[Point, Dim, Float32],
key: Random.Key
): Tensor2[Cluster, Dim, Float32] =
val perm = Random.permutation(points.shape.extent(Axis[Point]))(key)
points.take(Axis[Point])(perm.slice(Axis[Point].at(0 until nClusters))).relabel(Axis[Point].as(Axis[Cluster]))

stack(centroids, Axis[Cluster])
/** Example implementation of the K-Means clustering algorithm, showcasing the use
* of Named Tensors for clear and type-safe code.
*/
class KMeans(nClusters: Int) extends CenterBasedClustering(nClusters):

/** Run the K-Means algorithm on the given data, returning the final centroids and point assignments.
* @param data The input data points to cluster
* @param maxIter The maximum number of iterations to run
* @param seed The random seed for initialization
* @return The final centroids and point assignments
*/
def run(
data: Tensor2[Point, Dim, Float32],
maxIter: Int = 50,
seed: Int = 0
): (Tensor2[Cluster, Dim, Float32], Tensor1[Point, Int32]) =

val nPoints = data.shape.extent(Axis[Point]).size

val perm = Random.permutation(data.shape.extent(Axis[Point]))(Random.Key(seed))
val initCentroids: Tensor2[Cluster, Dim, Float32] =
data
.take(Axis[Point])(perm.slice(Axis[Point].at(0 until nClusters)))
.relabel(Axis[Point].as(Axis[Cluster]))

@annotation.tailrec
def loop(
centroids: Tensor2[Cluster, Dim, Float32],
prevAssignments: Option[Tensor1[Point, Int32]],
iter: Int
): (Tensor2[Cluster, Dim, Float32], Tensor1[Point, Int32]) =

val assignments = assign(squaredDistances(data, centroids))
val converged = prevAssignments.exists(prev => (prev === assignments).item)

if converged || iter >= maxIter then
(centroids, assignments)
else
loop(
updateCentroids(data, assignments),
Some(assignments),
iter + 1
)

loop(initCentroids, None, 0)
override def distance(point: Tensor1[Dim, Float32], centroid: Tensor1[Dim, Float32]): Tensor0[Float32] =
(point - centroid).pow(2).sum

object KMeans:
override def calcCenter(clusterPoints: Tensor2[Point, Dim, Float32]): Tensor1[Dim, Float32] =
clusterPoints.nanmean(Axis[Point])

class KMedians(nClusters: Int) extends CenterBasedClustering(nClusters):

override def distance(point: Tensor1[Dim, Float32], centroid: Tensor1[Dim, Float32]): Tensor0[Float32] =
(point - centroid).abs.sum

trait Point derives Label // N data points to cluster
trait Dim derives Label // D feature dimensions (coordinates of each point)
trait Cluster derives Label // K centroids — the result of clustering
override def calcCenter(clusterPoints: Tensor2[Point, Dim, Float32]): Tensor1[Dim, Float32] =
clusterPoints.nanmedian(Axis[Point])

object KMeans:

//
// Example usage
//
@main def runKMeans(): Unit =
dimwit.initialize()
@main def run(): Unit =

val n = 50 // points per cluster
val k = 3 // clusters to discover

val rootKey = Random.Key(1337)
val (key0, rest) = rootKey.split2()
val (key1, key2) = rest.split2()

val centers = Array(
val trueCenters = Array(
Array(-3.0f, 0.0f), // cloud 0: lower-left
Array(3.0f, 0.0f), // cloud 1: lower-right
Array(0.0f, 4.0f) // cloud 2: top-center
)
def cloud(center: Array[Float], key: Random.Key): Tensor2[Point, Dim, Float32] =
val noise = Normal.standardNormal(Shape(Axis[Point] -> n, Axis[Dim] -> 2)).sample(key)
noise +! Tensor1(Axis[Dim]).fromArray(center)

val data: Tensor2[Point, Dim, Float32] =
def evaluate(alg: CenterBasedClustering, name: String, points: Tensor2[Point, Dim, Float32], trainKey: Random.Key): Unit =
val centers = alg(points, trainKey, maxIterations = 50)
val assignments = alg.assignPoints(points, centers)

val inertia = alg.pairwiseDistances(points, centers).vmap(Axis[Point])(_.min).sum.item

println(f"Inertia (within-cluster sum of distances): $inertia%.2f\n")

// ── Report discovered clusters ─────────────────────────────────────────────
val nTotal = n * k
println(s"Discovered centers (${name}) vs true centers:")
println(f" ${"cluster"}%-9s ${"centroid (x,y)"}%-22s ${"true center"}%-16s size")
println(" " + "─" * 65)

(0 until k).foreach: c =>
// Extract the c-th centroid row: slice along the named Cluster axis.
val center = centers.slice(Axis[Cluster].at(c))
val cx = center.slice(Axis[Dim].at(0)).item
val cy = center.slice(Axis[Dim].at(1)).item

// Cluster size: count points whose assignment equals c.
// elementEquals produces [Point, Bool]; asFloat32 then sum gives a count.
val clusterSize =
assignments
.elementEquals(Tensor(Shape(Axis[Point] -> nTotal)).fill(c))
.asFloat32.sum.item.toInt

val tc = trueCenters(c)
println(f" cluster $c%-7d ($cx%+6.2f, $cy%+6.2f) (${tc(0)}%+6.2f, ${tc(1)}%+6.2f) $clusterSize")

def prepareData(dataKey: Random.Key): Tensor2[Point, Dim, Float32] =
def cloud(center: Array[Float], key: Random.Key): Tensor2[Point, Dim, Float32] =
val noise = Normal.standardNormal(Shape(Axis[Point] -> n, Axis[Dim] -> center.length)).sample(key)
Tensor1(Axis[Dim]).fromArray(center) +! noise

val (key0, key1, key2) = dataKey.splitToTuple(3)
concatenate(
concatenate(cloud(centers(0), key0), cloud(centers(1), key1), Axis[Point]),
cloud(centers(2), key2),
List(
cloud(trueCenters(0), key0),
cloud(trueCenters(1), key1),
cloud(trueCenters(2), key2)
),
Axis[Point]
)

dimwit.initialize()

val rootKey = Random.Key(1337)
val (dataKey, trainKey) = rootKey.split2()

val points = prepareData(dataKey)

println(s"Dataset: ${n * k} points, 2 dimensions, $k true clusters\n")
val kmeans = KMeans(k)
val (centroids, assignments) = kmeans.run(data, maxIter = 30)

val inertia: Float =
kmeans.squaredDistances(data, centroids).vmap(Axis[Point])(_.min).sum.item

println(f"Inertia (within-cluster sum of squared distances): $inertia%.2f\n")

// ── Report discovered clusters ─────────────────────────────────────────────
val nTotal = n * k
println("Discovered centroids (k-means) vs true centers:")
println(f" ${"cluster"}%-9s ${"centroid (x,y)"}%-22s ${"true center"}%-16s size")
println(" " + "─" * 65)

(0 until k).foreach { c =>
// Extract the c-th centroid row: slice along the named Cluster axis.
val centroid = centroids.slice(Axis[Cluster].at(c))
val cx = centroid.slice(Axis[Dim].at(0)).item
val cy = centroid.slice(Axis[Dim].at(1)).item

// Cluster size: count points whose assignment equals c.
// elementEquals produces [Point, Bool]; asFloat32 then sum gives a count.
val clusterSize =
assignments
.elementEquals(Tensor(Shape(Axis[Point] -> nTotal)).fill(c))
.asFloat32.sum.item.toInt

val tc = centers(c)
println(f" cluster $c%-7d ($cx%+6.2f, $cy%+6.2f) (${tc(0)}%+6.2f, ${tc(1)}%+6.2f) $clusterSize")
}
val (kmeansKey, kmediansKey) = trainKey.split2()
evaluate(new KMeans(k), "K-Means", points, kmeansKey)
evaluate(new KMedians(k), "K-Median", points, kmediansKey)
Loading
Loading