diff --git a/core/src/main/scala/dimwit/linalg/LinearAlgebra.scala b/core/src/main/scala/dimwit/linalg/LinearAlgebra.scala new file mode 100644 index 0000000..5360dc9 --- /dev/null +++ b/core/src/main/scala/dimwit/linalg/LinearAlgebra.scala @@ -0,0 +1,203 @@ +package dimwit.linalg + +import dimwit.jax.Jax +import dimwit.tensor.Axis +import dimwit.tensor.Label +import dimwit.tensor.Labels +import dimwit.tensor.ShapeTypeHelpers.AxesRemover +import dimwit.tensor.Tensor +import dimwit.tensor.Tensor0 +import dimwit.tensor.Tensor1 +import dimwit.tensor.Tensor2 +import dimwit.tensor.TensorOps.IsFloating +import dimwit.tensor.TensorOps.IsNumber +import me.shadaj.scalapy.py +import me.shadaj.scalapy.py.SeqConverters + +/** Common linear algebra operations. + */ +object LinearAlgebra: + + enum VectorNormType: + case L1 + case L2 + case Ord(p: Double) + case Inf + + enum MatrixNormType: + case Frobenius + case Nuclear + case Spectral + case One + case Inf + + enum QRMode: + case Reduced + case Complete + + /** Computes the determinant of the tensor `t` along the specified axes (L1, L2) + * + * @param t The input tensor from which to compute the determinant. + * @param axis1 The first axis along which to compute the determinant. + * @param axis2 The second axis along which to compute the determinant. + * @return A new tensor with the determinant computed, where the two specified axes are removed + */ + def det[T <: Tuple: Labels, L1: Label, L2: Label, V: IsFloating](t: Tensor[T, V], axis1: Axis[L1], axis2: Axis[L2])(using + ev: AxesRemover[T, (L1, L2)], + labels: Labels[ev.RemainingAxes] + ): Tensor[ev.RemainingAxes, V] = + // JAX det only works on the last two axes (-2, -1). We must move the user's selected axes to the end. + val moved = Jax.jnp.moveaxis( + t.jaxValue, + source = ev.indices.toPythonProxy, + destination = Seq(-2, -1).toPythonProxy + ) + Tensor(Jax.jnp.linalg.det(moved)) + + /** Extracts the diagonal along the given two axes (with optional offset), + * replacing them by a new 1D axis labeled L1. + * + * @param t The input tensor from which to extract the diagonal. + * @param axis1 The first axis along which to extract the diagonal. + * @param axis2 The second axis along which to extract the diagonal. + * @param offset The offset of the diagonal from the main diagonal. Positive values indicate diagonals above the main diagonal, while negative values indicate diagonals below it. + * @return A new tensor with the diagonal extracted, where the two specified axes are replaced by a new 1D axis labeled L1. + */ + def diagonal[T <: Tuple, L1: Label, L2: Label, V](t: Tensor[T, V], axis1: Axis[L1], axis2: Axis[L2], offset: Int = 0)(using + ev: AxesRemover[T, (L1, L2)], + labels: Labels[ev.RemainingAxes] + ): Tensor[ev.RemainingAxes *: L1 *: EmptyTuple, V] = + Tensor(Jax.jnp.diagonal(t.jaxValue, offset = offset, axis1 = ev.indices(0), axis2 = ev.indices(1))) + + /** Computes the inverse of the tensor t along the last two axes. + * The first axes of the tensors are preserved, while the last + * two axes are replaced by their inverses. + * The tensor must be square along the last two axes. + * + * @return a new tensor with the same shape as t, but with the last two axes replaced by their inverses. + */ + def inv[T <: Tuple: Labels, V: IsFloating](t: Tensor[T, V]): Tensor[T, V] = Tensor(Jax.jnp.linalg.inv(t.jaxValue)) + + /** Computes the trace of the tensor `t` along the specified axes (L1, L2) with an optional offset. + * The resulting tensor has the two specified axes removed, and the remaining axes are preserved. + * + * @param t The input tensor from which to compute the trace. + * @param axis1 The first axis along which to compute the trace. + * @param axis2 The second axis along which to compute the trace. + * @param offset The offset of the diagonal from the main diagonal. Positive values indicate diagonals above the main diagonal, while negative values indicate diagonals below it. + * + * @return A new tensor with the trace computed, where the two specified axes are removed, and the remaining axes are preserved. + */ + def trace[T <: Tuple, L1: Label, L2: Label, V: IsNumber](t: Tensor[T, V], axis1: Axis[L1], axis2: Axis[L2], offset: Int = 0)(using + ev: AxesRemover[T, (L1, L2)], + labels: Labels[ev.RemainingAxes] + ): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.trace(t.jaxValue, offset = offset, axis1 = ev.indices(0), axis2 = ev.indices(1))) + + /** Computes the vector norm of the tensor `t` based on the specified `normType`. + * + * @param t The input tensor for which to compute the norm. + * @param normType The type of norm to compute (L1, L2, Ord(p), or Inf). + * @return A new 0-D tensor containing the computed norm of the input tensor. + */ + def norm[L1: Label, V: IsFloating](t: Tensor1[L1, V], normType: VectorNormType): Tensor0[V] = + normType match + case VectorNormType.L1 => Tensor0(Jax.jnp.linalg.norm(t.jaxValue, ord = 1)) + case VectorNormType.L2 => Tensor0(Jax.jnp.linalg.norm(t.jaxValue, ord = 2)) + case VectorNormType.Ord(p) => Tensor0(Jax.jnp.linalg.norm(t.jaxValue, ord = p)) + case VectorNormType.Inf => Tensor0(Jax.jnp.linalg.norm(t.jaxValue, ord = Jax.jnp.inf)) + + /** Computes the matrix norm of the tensor `t` based on the specified `normType`. + * + * @param t The input tensor for which to compute the norm. + * @param normType The type of norm to compute (Frobenius, Nuclear, Spectral, One, or Inf). + * @return A new 0-D tensor containing the computed norm of the input tensor. + */ + def norm[L1: Label, L2: Label, V: IsFloating](t: Tensor2[L1, L2, V], normType: MatrixNormType): Tensor0[V] = + normType match + case MatrixNormType.Frobenius => Tensor0(Jax.jnp.linalg.norm(t.jaxValue, ord = "fro")) + case MatrixNormType.Nuclear => Tensor0(Jax.jnp.linalg.norm(t.jaxValue, ord = "nuc")) + case MatrixNormType.Spectral => Tensor0(Jax.jnp.linalg.norm(t.jaxValue, ord = 2)) + case MatrixNormType.One => Tensor0(Jax.jnp.linalg.norm(t.jaxValue, ord = 1)) + case MatrixNormType.Inf => Tensor0(Jax.jnp.linalg.norm(t.jaxValue, ord = Jax.jnp.inf)) + + /** Computes the element-wise norm of the tensor `t` l2 norm along the last axis. + * @param t The input tensor for which to compute the norm. + * + * @return A new 0-D tensor containing the computed norm of the input tensor. + */ + def norm[T <: Tuple, V: IsFloating](t: Tensor[T, V]): Tensor0[V] = + Tensor0(Jax.jnp.linalg.norm(t.jaxValue)) + + /** Cholesky factorization. + * + * @param t The input tensor to be factorized. It must be a symmetric positive-definite matrix. + * @param upper If true, the upper-triangular Cholesky factor is returned + * @param symmetrizeInput If true, the input matrix is symmetrized before factorization to ensure numerical stability. + * @return a triangular matrix representing the cholesky factor + * + * @see [[https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linalg.cholesky.html#jax.numpy.linalg.cholesky JAX documentation]] for more details on the underlying implementation. + */ + def cholesky[L1: Label, L2: Label, V: IsFloating](t: Tensor2[L1, L2, V], upper: Boolean = false, symmetrizeInput: Boolean = true): Tensor2[L1, L2, V] = + Tensor(Jax.jnp.linalg.cholesky(t.jaxValue, upper = upper, symmetrize_input = symmetrizeInput)) + + /** Computes the QR factorization of the tensor `t`. + * + * @param t The input tensor to be factorized. It must be a 2D matrix. + * @param mode The mode of the QR factorization (Reduced or Complete). + * @return A tuple containing two tensors: the orthogonal matrix Q and the upper-triangular matrix R. + * + * @see [[https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linalg.qr.html#jax.numpy.linalg.qr JAX documentation]] for more details on the underlying implementation. + */ + def qr[L1: Label, L2: Label, V: IsFloating](t: Tensor2[L1, L2, V], mode: QRMode = QRMode.Reduced): (q: Tensor2[L1, L2, V], r: Tensor2[L1, L2, V]) = + val qr = Jax.jnp.linalg.qr( + t.jaxValue, + mode = mode match + case QRMode.Reduced => "reduced" + case QRMode.Complete => "complete" + ) + (q = Tensor(qr.bracketAccess(0)), r = Tensor(qr.bracketAccess(1))) + + /** Computes the eigenvalues and eigenvectors of a symmetric matrix `t`. + * @param t The input tensor representing a symmetric matrix. + * @param upper If true, the upper-triangular part of the matrix is used. + * @param symmetrizeInput If true, the input matrix is symmetrized before computation to ensure numerical stability. + * @return A tuple containing two tensors: the eigenvalues and the corresponding eigenvectors of the input matrix. + * + * @see [[https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linalg.eigh.html#jax.numpy.linalg.eigh JAX documentation]] for more details on the underlying implementation. + */ + def eigh[L1: Label, L2: Label, V: IsFloating](t: Tensor2[L1, L2, V], upper: Boolean = false, symmetrizeInput: Boolean = true) + : (eigenvalues: Tensor1[L1, V], eigenvectors: Tensor2[L1, L2, V]) = + + val ret = Jax.jnp.linalg.eigh(t.jaxValue, UPLO = if upper then "U" else "L", symmetrize_input = symmetrizeInput) + val eigenvalues: Tensor1[L1, V] = Tensor(ret.bracketAccess(0)) + val eigenvectors: Tensor2[L1, L2, V] = Tensor(ret.bracketAccess(1)) + (eigenvalues = eigenvalues, eigenvectors = eigenvectors) + + /** Computes the singular value decomposition (SVD) of the tensor `t`. + * + * @param t The input tensor to be decomposed. + * @param fullMatrices If true, compute the full-sized U and Vh matrices; if false, compute the reduced-sized matrices. + * @param hermitian If true, the input is a Hermitian matrix. + * @return A tuple containing three tensors: the left singular vectors U, the singular values S, and the right singular vectors Vh. + * + * @see [[https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linalg.svd.html#jax.numpy.linalg.svd JAX documentation]] for more details on the underlying implementation. + */ + def svd[L1: Label, L2: Label, V: IsFloating](t: Tensor2[L1, L2, V], fullMatrices: Boolean = false, hermitian: Boolean = false) + : (U: Tensor2[L1, L2, V], S: Tensor1[L1, V], Vh: Tensor2[L1, L2, V]) = + + val ret = Jax.jnp.linalg.svd(t.jaxValue, full_matrices = fullMatrices, hermitian = hermitian) + val u: Tensor2[L1, L2, V] = Tensor(ret.bracketAccess(0)) + val s: Tensor1[L1, V] = Tensor(ret.bracketAccess(1)) + val vh: Tensor2[L1, L2, V] = Tensor(ret.bracketAccess(2)) + (U = u, S = s, Vh = vh) + + /** Solves the linear equation Ax = b for x, where A is a square matrix and b is a vector. + * + * @param a The input tensor representing the square matrix A. + * @param b The input tensor representing the vector b. + * @return A new tensor containing the solution vector x. + * + * @see [[https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linalg.solve.html#jax.numpy.linalg.solve JAX documentation]] for more details on the underlying implementation. + */ + def solve[L1: Label, L2: Label, V: IsFloating](a: Tensor2[L1, L2, V], b: Tensor1[L1, V]): Tensor1[L2, V] = + Tensor(Jax.jnp.linalg.solve(a.jaxValue, b.jaxValue)) diff --git a/core/src/main/scala/dimwit/tensor/tensorops/LinearAlgebraOps.scala b/core/src/main/scala/dimwit/tensor/tensorops/LinearAlgebraOps.scala index db52ae1..6b67a42 100644 --- a/core/src/main/scala/dimwit/tensor/tensorops/LinearAlgebraOps.scala +++ b/core/src/main/scala/dimwit/tensor/tensorops/LinearAlgebraOps.scala @@ -1,6 +1,6 @@ package dimwit.tensor.tensorops -import dimwit.jax.Jax +import dimwit.linalg.LinearAlgebra import dimwit.tensor.Axis import dimwit.tensor.Label import dimwit.tensor.Labels @@ -11,9 +11,6 @@ import dimwit.tensor.Tensor1 import dimwit.tensor.Tensor2 import dimwit.tensor.TensorOps.IsFloating import dimwit.tensor.TensorOps.IsNumber -import me.shadaj.scalapy.py -import me.shadaj.scalapy.py.SeqConverters -import me.shadaj.scalapy.readwrite.Writer object LinearAlgebraOps: @@ -22,33 +19,25 @@ object LinearAlgebraOps: /** Extracts the diagonal along the given two axes (with optional offset), * replacing them by a new 1D axis labeled L1. * - * @param axis1 The first axis along which to extract the diagonal. - * @param axis2 The second axis along which to extract the diagonal. - * @param offset The offset of the diagonal from the main diagonal. Positive values indicate diagonals above the main diagonal, while negative values indicate diagonals below it. - * @return A new tensor with the diagonal extracted, where the two specified axes are replaced by a new 1D axis labeled L1. + * @see [[LinearAlgebra.diagonal]] for details */ def diagonal[L1: Label, L2: Label](axis1: Axis[L1], axis2: Axis[L2], offset: Int = 0)(using ev: AxesRemover[T, (L1, L2)], labels: Labels[ev.RemainingAxes] - ): Tensor[ev.RemainingAxes *: L1 *: EmptyTuple, V] = - Tensor(Jax.jnp.diagonal(t.jaxValue, offset = offset, axis1 = ev.indices(0), axis2 = ev.indices(1))) + ): Tensor[ev.RemainingAxes *: L1 *: EmptyTuple, V] = LinearAlgebra.diagonal(t, axis1, axis2, offset) extension [L1: Label, L2: Label, V](t: Tensor2[L1, L2, V]) /** return the diagonal of the tensor `t` along the specified axes. - * The resulting 1D tensor has a single axis labeled L1, representing the diagonal index over the original (L1, L2) axes. - * - * @return A new tensor1 with representing the diagonal. It uses the Label of the first axis (L1) as the label for the resulting 1D tensor. + * @see [[LinearAlgebra.diagonal]] for details */ - def diagonal: Tensor1[L1, V] = t.diagonal(0) + def diagonal: Tensor1[L1, V] = LinearAlgebra.diagonal(t, Axis[L1], Axis[L2]).asInstanceOf[Tensor[Tuple1[L1], V]] /** return the diagonal of the tensor `t` along the specified axes. - * The resulting 1D tensor has a single axis labeled L1, representing the diagonal index over the original (L1, L2) axes. - * - * @param offset The offset of the diagonal from the main diagonal. Positive values indicate diagonals above the main diagonal, while negative values indicate diagonals below it. - * @return A new tensor1 with representing the diagonal. It uses the Label of the first axis (L1) as the label for the resulting 1D tensor. + * @see [[LinearAlgebra.diagonal]] for details */ - def diagonal(offset: Int): Tensor1[L1, V] = Tensor(Jax.jnp.diagonal(t.jaxValue, offset = offset)) + def diagonal(offset: Int): Tensor1[L1, V] = + LinearAlgebra.diagonal(t, Axis[L1], Axis[L2], offset).asInstanceOf[Tensor[Tuple1[L1], V]] // --------------------------------------------------------- // IsNumber operations (IsFloat or IsInt) @@ -56,66 +45,48 @@ object LinearAlgebraOps: extension [T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V]) - /** Computes the trace of the tensor `t` along the specified axes (L1, L2) with an optional offset. - * The resulting tensor has the two specified axes removed, and the remaining axes are preserved. - * - * @param axis1 The first axis along which to compute the trace. - * @param axis2 The second axis along which to compute the trace. - * @param offset The offset of the diagonal from the main diagonal. Positive values indicate diagonals above the main diagonal, while negative values indicate diagonals below it. + /** Computes the trace of the tensor. * - * @return A new tensor with the trace computed, where the two specified axes are removed, and the remaining axes are preserved. + * @see [[LinearAlgebra.trace]] for details */ def trace[L1: Label, L2: Label](axis1: Axis[L1], axis2: Axis[L2], offset: Int = 0)(using ev: AxesRemover[T, (L1, L2)], labels: Labels[ev.RemainingAxes] - ): Tensor[ev.RemainingAxes, V] = Tensor(Jax.jnp.trace(t.jaxValue, offset = offset, axis1 = ev.indices(0), axis2 = ev.indices(1))) + ): Tensor[ev.RemainingAxes, V] = LinearAlgebra.trace(t, axis1, axis2, offset) extension [L1: Label, L2: Label, V: IsNumber](t: Tensor2[L1, L2, V]) /** Computes the trace of the tensor + * + * @see [[LinearAlgebra.trace]] for details */ def trace: Tensor0[V] = t.trace(0) - /** Computes the trace of the tensor with an optional offset. - * - * @param offset The offset of the diagonal from the main diagonal. Positive values indicate diagonals above the main diagonal, - * while negative values indicate diagonals below it. - */ - def trace(offset: Int): Tensor0[V] = Tensor0(Jax.jnp.trace(t.jaxValue, offset = offset)) + /** Computes the trace. @see [[LinearAlgebra.trace]] for details */ + def trace(offset: Int): Tensor0[V] = LinearAlgebra.trace(t, Axis[L1], Axis[L2], offset) extension [T <: Tuple: Labels, V: IsFloating](t: Tensor[T, V]) - /** Computes the L2 norm of the tensor t. + /** Computes the element wise L2 norm of the tensor t. + * + * @see [[LinearAlgebra.norm]] for details */ - def norm: Tensor0[V] = Tensor0(Jax.jnp.linalg.norm(t.jaxValue)) + def norm: Tensor0[V] = LinearAlgebra.norm(t) - /** Computes the inverse of the tensor t along the last two axes. - * The first axes of the tensors are preserved, while the last - * two axes are replaced by their inverses. - * The tensor must be square along the last two axes. - * - * @return a new tensor with the same shape as t, but with the last two axes replaced by their inverses. + /** @see [[LinearAlgebra.inv]] for details */ - def inv: Tensor[T, V] = Tensor(Jax.jnp.linalg.inv(t.jaxValue)) + def inv: Tensor[T, V] = LinearAlgebra.inv(t) /** Computes the determinant of the tensor `t` along the specified axes (L1, L2) - * - * @param axis1 The first axis along which to compute the determinant. - * @param axis2 The second axis along which to compute the determinant. - * @return A new tensor with the determinant computed, where the two specified axes are removed + * @see [[LinearAlgebra.det]] for details */ def det[L1: Label, L2: Label](axis1: Axis[L1], axis2: Axis[L2])(using ev: AxesRemover[T, (L1, L2)], labels: Labels[ev.RemainingAxes] - ): Tensor[ev.RemainingAxes, V] = - // JAX det only works on the last two axes (-2, -1). We must move the user's selected axes to the end. - val moved = Jax.jnp.moveaxis( - t.jaxValue, - source = ev.indices.toPythonProxy, - destination = Seq(-2, -1).toPythonProxy - ) - Tensor(Jax.jnp.linalg.det(moved)) + ): Tensor[ev.RemainingAxes, V] = LinearAlgebra.det(t, axis1, axis2) extension [L1: Label, L2: Label, V: IsFloating](t: Tensor2[L1, L2, V]) - /** computes the determinant of the 2-D tensor t */ - def det: Tensor0[V] = Tensor0(Jax.jnp.linalg.det(t.jaxValue)) + /** computes the determinant of the 2-D tensor t + * @see [[LinearAlgebra.det]] for details + */ + def det: Tensor0[V] = LinearAlgebra.det(t, Axis[L1], Axis[L2]) diff --git a/core/src/test/scala/dimwit/linalg/LinearAlgebraTests.scala b/core/src/test/scala/dimwit/linalg/LinearAlgebraTests.scala new file mode 100644 index 0000000..dbaac38 --- /dev/null +++ b/core/src/test/scala/dimwit/linalg/LinearAlgebraTests.scala @@ -0,0 +1,169 @@ +package dimwit.linalg + +import dimwit.* + +class LinearAlgebraTests extends DimwitTest: + + describe("Vector norms"): + val v = Tensor1(Axis[A]).fromArray(Array(3.0f, 4.0f)) + + it("L1 norm"): + LinearAlgebra.norm(v, LinearAlgebra.VectorNormType.L1).item shouldBe 7.0f +- 1e-5f + + it("L2 norm"): + LinearAlgebra.norm(v, LinearAlgebra.VectorNormType.L2).item shouldBe 5.0f +- 1e-5f + + it("Ord(1) norm equals L1"): + LinearAlgebra.norm(v, LinearAlgebra.VectorNormType.Ord(1)).item shouldBe + LinearAlgebra.norm(v, LinearAlgebra.VectorNormType.L1).item +- 1e-5f + + it("Ord(3) norm"): + // (3^3 + 4^3)^(1/3) = (27 + 64)^(1/3) = 91^(1/3) ≈ 4.4979 + LinearAlgebra.norm(v, LinearAlgebra.VectorNormType.Ord(3)).item shouldBe + Math.pow(91.0, 1.0 / 3.0).toFloat +- 1e-4f + + it("Inf norm (max abs value)"): + LinearAlgebra.norm(v, LinearAlgebra.VectorNormType.Inf).item shouldBe 4.0f +- 1e-5f + + describe("Matrix norms"): + // [[3, 0], [4, 0]]: easy to reason about column/row sums + val m = Tensor2(Axis[A], Axis[B]).fromArray( + Array(Array(3.0f, 0.0f), Array(4.0f, 0.0f)) + ) + + it("Frobenius norm"): + // sqrt(3^2 + 4^2) = 5 + LinearAlgebra.norm(m, LinearAlgebra.MatrixNormType.Frobenius).item shouldBe 5.0f +- 1e-5f + + it("Nuclear norm"): + // singular values of [[3,0],[4,0]] are 5 and 0; nuclear = sum = 5 + LinearAlgebra.norm(m, LinearAlgebra.MatrixNormType.Nuclear).item shouldBe 5.0f +- 1e-5f + + it("Spectral norm (ord=2)"): + // largest singular value = 5 + LinearAlgebra.norm(m, LinearAlgebra.MatrixNormType.Spectral).item shouldBe 5.0f +- 1e-5f + + it("One norm (max absolute column sum)"): + // col 0 sum = 3+4=7, col 1 sum = 0 → 7 + LinearAlgebra.norm(m, LinearAlgebra.MatrixNormType.One).item shouldBe 7.0f +- 1e-5f + + it("Inf norm (max absolute row sum)"): + // row 0 sum = 3, row 1 sum = 4 → 4 + LinearAlgebra.norm(m, LinearAlgebra.MatrixNormType.Inf).item shouldBe 4.0f +- 1e-5f + + describe("Cholesky factorization"): + + val lower = LinearAlgebra.cholesky( + Tensor2(Axis[A], Axis[Prime[A]]).fromArray(Array(Array(4.0f, 0f), Array(2.0f, 3.0f))), + upper = false + ) + val spd = lower.dot(Axis[Prime[A]])(lower) // make it symmetric positive-definite + + it("Lower-triangular factor has exact values"): + val L = LinearAlgebra.cholesky(spd, upper = false) + L should approxEqual(lower, tolerance = 1e-5f) + + it("correctly reconstructs the original matrix"): + val L = LinearAlgebra.cholesky(spd, upper = false, symmetrizeInput = false) + val reconstructed = L.dot(Axis[Prime[A]])(L) + reconstructed should approxEqual(spd, tolerance = 1e-5f) + + // Shared diagonal test matrix for eigh/svd: [[3, 0], [0, 5]] + // Eigenvalues (ascending): [3, 5]; singular values (descending): [5, 3] + val diagMat = Tensor2(Axis[A], Axis[Prime[A]]).fromArray( + Array(Array(3.0f, 0.0f), Array(0.0f, 5.0f)) + ) + // identity as Tensor2[A, Prime[A]] — the type produced by contracting Prime[A] from a Tensor2[A, Prime[A]] with itself + val identityAP = Tensor2(Axis[A], Axis[Prime[A]]).fromArray( + Array(Array(1.0f, 0.0f), Array(0.0f, 1.0f)) + ) + + describe("Eigendecomposition (eigh)"): + + it("eigenvalues of a diagonal matrix are its diagonal entries (ascending)"): + val (eigenvalues, _) = LinearAlgebra.eigh(diagMat) + eigenvalues should approxEqual( + Tensor1(Axis[A]).fromArray(Array(3.0f, 5.0f)), + tolerance = 1e-5f + ) + + it("eigenvalues sum equals trace"): + val (eigenvalues, _) = LinearAlgebra.eigh(diagMat) + eigenvalues.sum.item shouldBe diagMat.sum.item +- 1e-4f + + it("eigenvectors of a diagonal matrix are the standard basis (up to sign)"): + val (_, eigenvectors) = LinearAlgebra.eigh(diagMat) + // |V| should be identity (sign-agnostic) + val absEigvecs = eigenvectors.abs + val expected = Tensor2(Axis[A], Axis[Prime[A]]).fromArray( + Array(Array(1.0f, 0.0f), Array(0.0f, 1.0f)) + ) + absEigvecs should approxEqual(expected, tolerance = 1e-5f) + + it("eigenvectors are orthonormal: V @ V^T = I"): + val (_, eigenvectors) = LinearAlgebra.eigh(diagMat) + val vvt = eigenvectors.dot(Axis[Prime[A]])(eigenvectors) + vvt should approxEqual(identityAP, tolerance = 1e-5f) + + describe("QR factorization"): + + // Non-trivial 2×2 matrix; expected properties are sign-agnostic + val qrMat = Tensor2(Axis[A], Axis[Prime[A]]).fromArray( + Array(Array(3.0f, 2.0f), Array(4.0f, 1.0f)) + ) + + it("Q is (column-)orthonormal: Q @ Q^T = I"): + val (q, _) = LinearAlgebra.qr(qrMat) + val qqt = q.dot(Axis[Prime[A]])(q) + qqt should approxEqual(identityAP, tolerance = 1e-5f) + + it("R is upper triangular: lower-left element is zero"): + val (_, r) = LinearAlgebra.qr(qrMat) + r.slice(Axis[A].at(1)).slice(Axis[Prime[A]].at(0)).item shouldBe 0.0f +- 1e-5f + + it("Frobenius norm is preserved: ||A||_F = ||R||_F (since Q is orthogonal)"): + val (_, r) = LinearAlgebra.qr(qrMat) + LinearAlgebra.norm(r, LinearAlgebra.MatrixNormType.Frobenius).item shouldBe + LinearAlgebra.norm(qrMat, LinearAlgebra.MatrixNormType.Frobenius).item +- 1e-4f + + describe("Singular value decomposition (SVD)"): + + it("singular values of a diagonal matrix are its diagonal entries (descending)"): + val (_, s, _) = LinearAlgebra.svd(diagMat) + s should approxEqual( + Tensor1(Axis[A]).fromArray(Array(5.0f, 3.0f)), + tolerance = 1e-5f + ) + + it("singular values sum equals nuclear norm"): + val (_, s, _) = LinearAlgebra.svd(diagMat) + s.sum.item shouldBe + LinearAlgebra.norm(diagMat, LinearAlgebra.MatrixNormType.Nuclear).item +- 1e-4f + + it("largest singular value equals spectral norm"): + val (_, s, _) = LinearAlgebra.svd(diagMat) + s.max.item shouldBe + LinearAlgebra.norm(diagMat, LinearAlgebra.MatrixNormType.Spectral).item +- 1e-4f + + it("U is orthonormal: U @ U^T = I"): + val (u, _, _) = LinearAlgebra.svd(diagMat) + // u: Tensor2[A, Prime[A]], contracting Prime[A] gives Tensor2[A, Prime[A]] + val uut = u.dot(Axis[Prime[A]])(u) + uut should approxEqual(identityAP, tolerance = 1e-5f) + + it("Vh is orthonormal: Vh @ Vh^T = I"): + val (_, _, vh) = LinearAlgebra.svd(diagMat) + // vh: Tensor2[A, Prime[A]], contracting Prime[A] gives Tensor2[A, Prime[A]] + val vhvht = vh.dot(Axis[Prime[A]])(vh) + vhvht should approxEqual(identityAP, tolerance = 1e-5f) + + describe("Linear solve (Ax = b)"): + // A = [[2, 1], [1, 3]], b = [5, 10] → exact solution x = [1, 3] + val solveA = Tensor2(Axis[A], Axis[Prime[A]]).fromArray( + Array(Array(2.0f, 1.0f), Array(1.0f, 3.0f)) + ) + val solveB = Tensor1(Axis[A]).fromArray(Array(5.0f, 10.0f)) + + it("solution satisfies A x = b"): + val x = LinearAlgebra.solve(solveA, solveB) + x should approxEqual(Tensor1(Axis[Prime[A]]).fromArray(Array(1.0f, 3.0f)), tolerance = 1e-5f)