Core API Reference

Core modules provide the building blocks used by all models: distance functions, loss functions, initializers, and utilities.

Distance Functions

JAX-based distance functions for Prosemble.

This module provides GPU-accelerated, vectorized distance computations using JAX. All functions are JIT-compiled for maximum performance.

Mathematical Background

Distance metrics are fundamental to prototype-based learning algorithms. This implementation focuses on:

  1. Batch/matrix operations (no Python loops)

  2. GPU compatibility

  3. JIT compilation for speed

  4. Numerical stability

Author: Prosemble Contributors
License: MIT

prosemble.core.distance.euclidean_distance_matrix(X, Y)

Compute pairwise Euclidean distances between rows of X and Y.

Uses the expansion trick: \(\|x - y\|^2 = \|x\|^2 + \|y\|^2 - 2 x^T y\).

Formula: \(D_{ij} = \|X_i - Y_j\| = \sqrt{\sum_k (X_{ik} - Y_{jk})^2}\)

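Parameters:
  • X (array of shape (n, d)) – First set of points.

  • Y (array of shape (m, d)) – Second set of points.

Returns:

D – Array of shape (n, m) where \(D_{ij} = \|X_i - Y_j\|\)

Return type:

array of shape (n, m)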

Complexity:

Time: O(nmd) - single matrix multiplication
Space: O(nm) - output matrix

Example

>>> X = jnp.array([[0, 0], [1, 1], [2, 2]])
>>> Y = jnp.array([[0, 0], [3, 3]])
>>> D = euclidean_distance_matrix(X, Y)
>>> D.shape
(3, 2)
>>> D[0, 0]  # Distance from X[0] to Y[0]
0.0
>>> D[2, 1]  # Distance from X[2] to Y[1]
1.414...

Notes

  • Numerically stable: Uses maximum(D_sq, 0) to avoid sqrt of negatives

  • GPU-compatible: All operations are JAX primitives

  • JIT-compiled: First call compiles, subsequent calls are fast
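A minimal JAX sketch of the expansion trick together with the clamp mentioned in the Notes. The name euclidean_matrix_sketch is hypothetical; this illustrates the technique under those assumptions and is not the library's source:

```python
# Illustrative sketch; not prosemble's implementation.
import jax
import jax.numpy as jnp


@jax.jit
def euclidean_matrix_sketch(X, Y):
    """Pairwise Euclidean distances via ||x||^2 + ||y||^2 - 2 x.y."""
    X = jnp.asarray(X, dtype=jnp.float32)
    Y = jnp.asarray(Y, dtype=jnp.float32)
    x_sq = jnp.sum(X**2, axis=1)[:, None]  # (n, 1) squared row norms of X
    y_sq = jnp.sum(Y**2, axis=1)[None, :]  # (1, m) squared row norms of Y
    d_sq = x_sq + y_sq - 2.0 * X @ Y.T     # (n, m); one matmul does the O(nmd) work
    # Rounding can push d_sq slightly below zero for near-identical rows, so clamp first.
    return jnp.sqrt(jnp.maximum(d_sq, 0.0))


X = jnp.array([[0, 0], [1, 1], [2, 2]])
Y = jnp.array([[0, 0], [3, 3]])
D = euclidean_matrix_sketch(X, Y)
```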

prosemble.core.distance.squared_euclidean_distance_matrix(X, Y)

Compute pairwise squared Euclidean distances.

More efficient than euclidean_distance_matrix(X, Y)**2 because it avoids the sqrt operation entirely.

Formula: \(D^2_{ij} = \|X_i - Y_j\|^2 = \sum_k (X_{ik} - Y_{jk})^2\)

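Parameters:
  • X (array of shape (n, d)) – First set of points.

  • Y (array of shape (m, d)) – Second set of points.

Returns:

D – Array of shape (n, m) where \(D^2_{ij} = \|X_i - Y_j\|^2\)

Return type:

array of shape (n, m)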

Complexity:

Time: O(nmd)
Space: O(nm)

Example

>>> X = jnp.array([[0, 0], [1, 1]])
>>> Y = jnp.array([[0, 0], [2, 2]])
>>> D_sq = squared_euclidean_distance_matrix(X, Y)
>>> D_sq[1, 1]  # Squared distance from [1,1] to [2,2]
2.0

Notes

  • Preferred over euclidean when squared distances are sufficient

  • Many algorithms (FCM, PCM) use squared distances directly

  • More numerically stable than squaring euclidean distances, and it avoids the square root, whose derivative is infinite at zero distance
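A small sketch of why the square root matters under jax.grad (the two helper functions are hypothetical): the derivative of sqrt is infinite at zero, so the gradient of the plain Euclidean distance between identical points comes out as NaN, while the squared distance gives a well-defined zero gradient.

```python
# Illustrative sketch; euclid and sq_euclid are hypothetical helpers.
import jax
import jax.numpy as jnp


def euclid(x, y):
    return jnp.sqrt(jnp.sum((x - y) ** 2))


def sq_euclid(x, y):
    return jnp.sum((x - y) ** 2)


x = jnp.array([1.0, 2.0])
grad_euclid = jax.grad(euclid)(x, x)   # 0.5 / sqrt(0) = inf, then inf * 2 * (x - y) = NaN
grad_sq = jax.grad(sq_euclid)(x, x)    # 2 * (x - y) = 0, well defined
```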

prosemble.core.distance.manhattan_distance_matrix(X, Y)

Compute pairwise Manhattan (L1) distances.

Formula: \(D_{ij} = \|X_i - Y_j\|_1 = \sum_k |X_{ik} - Y_{jk}|\)

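Parameters:
  • X (array of shape (n, d)) – First set of points.

  • Y (array of shape (m, d)) – Second set of points.

Returns:

D – Array of shape (n, m) where D[i,j] is the Manhattan distance

Return type:

array of shape (n, m)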

Complexity:

Time: O(nmd)
Space: O(nmd) - intermediate broadcasting

Example

>>> X = jnp.array([[0, 0], [1, 1]])
>>> Y = jnp.array([[0, 0], [2, 2]])
>>> D = manhattan_distance_matrix(X, Y)
>>> D[1, 1]  # Manhattan distance from [1,1] to [2,2]
2.0

Implementation:

Uses broadcasting: X[:, None, :] - Y[None, :, :] creates an (n, m, d) array of differences, whose absolute values are then summed along the feature dimension.

Notes

  • Also known as “taxicab” or “city block” distance

  • More robust to outliers than Euclidean distance

  • Natural for sparse/binary features
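The broadcasting approach from the Implementation paragraph, as a short sketch (manhattan_matrix_sketch is a hypothetical name). The (n, m, d) intermediate is what drives the O(nmd) memory cost:

```python
# Illustrative sketch; not prosemble's implementation.
import jax
import jax.numpy as jnp


@jax.jit
def manhattan_matrix_sketch(X, Y):
    diff = X[:, None, :] - Y[None, :, :]    # (n, m, d): every pairwise difference
    return jnp.sum(jnp.abs(diff), axis=-1)  # (n, m): sum of |.| over features


X = jnp.array([[0.0, 0.0], [1.0, 1.0]])
Y = jnp.array([[0.0, 0.0], [2.0, 2.0]])
D = manhattan_matrix_sketch(X, Y)  # [[0, 4], [2, 2]]
```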

prosemble.core.distance.lpnorm_distance_matrix(X, Y, p)

Compute pairwise L-p norm distances.

Formula: \(D_{ij} = \|X_i - Y_j\|_p = (\sum_k |X_{ik} - Y_{jk}|^p)^{1/p}\)

Special Cases:

  • p = 1: Manhattan distance

  • p = 2: Euclidean distance

  • p = \(\infty\): Chebyshev distance (max absolute difference)

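Parameters:
  • X (array of shape (n, d)) – First set of points.

  • Y (array of shape (m, d)) – Second set of points.

  • p (float) – Order of the norm; jnp.inf gives the Chebyshev distance.

Returns:

D – Array of shape (n, m) where D[i,j] is the L-p distance

Return type:

array of shape (n, m)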

Complexity:

Time: O(nmd)
Space: O(nmd)

Example

>>> X = jnp.array([[0, 0], [1, 1]])
>>> Y = jnp.array([[0, 0], [3, 4]])
>>> D = lpnorm_distance_matrix(X, Y, p=2)  # Euclidean
>>> D = lpnorm_distance_matrix(X, Y, p=1)  # Manhattan
>>> D = lpnorm_distance_matrix(X, Y, p=jnp.inf)  # Chebyshev

Notes

  • For p=1, use manhattan_distance_matrix for better performance

  • For p=2, use euclidean_distance_matrix for better performance

  • For p=inf, computes max(|x - y|)
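A sketch of the general case including the p = ∞ branch; lpnorm_matrix_sketch is hypothetical, and p is assumed to be a plain Python number rather than a traced array so the branch is an ordinary if:

```python
# Illustrative sketch; not prosemble's implementation.
import jax.numpy as jnp


def lpnorm_matrix_sketch(X, Y, p):
    diff = jnp.abs(X[:, None, :] - Y[None, :, :])   # (n, m, d)
    if p == jnp.inf:
        return jnp.max(diff, axis=-1)               # Chebyshev: largest coordinate gap
    return jnp.sum(diff**p, axis=-1) ** (1.0 / p)


X = jnp.array([[0.0, 0.0], [1.0, 1.0]])
Y = jnp.array([[0.0, 0.0], [3.0, 4.0]])
D1 = lpnorm_matrix_sketch(X, Y, p=1)            # D1[0, 1]: 3 + 4 = 7
D2 = lpnorm_matrix_sketch(X, Y, p=2)            # D2[0, 1]: sqrt(9 + 16) = 5
Dinf = lpnorm_matrix_sketch(X, Y, p=jnp.inf)    # Dinf[0, 1]: max(3, 4) = 4
```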

prosemble.core.distance.omega_distance_matrix(X, Y, omega)

Compute distances in projected space using projection matrix Omega.

Formula: \(D_{ij} = \|X_i \Omega - Y_j \Omega\|^2\) where \(\Omega\) is a projection matrix that transforms the feature space.

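Parameters:
  • X (array of shape (n, d)) – First set of points.

  • Y (array of shape (m, d)) – Second set of points.

  • omega (array of shape (d, k)) – Projection matrix.

Returns:

D – Array of shape (n, m) with squared distances in projected space

Return type:

array of shape (n, m)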

Complexity:

Time: O(ndk + mdk + nmk) = O((n+m)dk + nmk)
Space: O(nk + mk + nm)

Use Cases:
  • Dimensionality reduction for distance computation

  • Learning relevance of features (omega learned from data)

  • Mahalanobis-like distances (when \(\Omega\Omega^T = \Sigma^{-1}\), e.g. \(\Omega = L\) with \(\Sigma^{-1} = LL^T\))

Example

>>> X = jnp.array([[1, 2, 3], [4, 5, 6]])
>>> Y = jnp.array([[0, 0, 0], [1, 1, 1]])
>>> omega = jnp.array([[1, 0], [0, 1], [0, 0]])  # Project to first 2 dims
>>> D = omega_distance_matrix(X, Y, omega)
>>> D.shape
(2, 2)

Notes

  • When omega is identity, reduces to squared Euclidean distance

  • When omega is learned, enables adaptive distance metrics

  • Used in GMLVQ (Generalized Matrix Learning Vector Quantization)
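A sketch of the projected distance, assuming omega has shape (d, k) as in the example (omega_matrix_sketch is hypothetical). Projecting each set once keeps the cost at O((n+m)dk + nmk), and expanding the norm shows the effective metric is \(\Lambda = \Omega\Omega^T\), since \(\|(x - y)\Omega\|^2 = (x - y)\Omega\Omega^T(x - y)^T\).

```python
# Illustrative sketch; not prosemble's implementation.
import jax.numpy as jnp


def omega_matrix_sketch(X, Y, omega):
    PX = X @ omega                           # (n, k): project data once
    PY = Y @ omega                           # (m, k): project prototypes once
    diff = PX[:, None, :] - PY[None, :, :]   # (n, m, k)
    return jnp.sum(diff**2, axis=-1)         # squared distances in projected space


X = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
Y = jnp.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
omega = jnp.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])  # keep the first two features
D = omega_matrix_sketch(X, Y, omega)                 # [[5, 1], [41, 25]]
D_identity = omega_matrix_sketch(X, Y, jnp.eye(3))   # plain squared Euclidean
```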

prosemble.core.distance.lomega_distance_matrix(X, Y, omegas)

Compute distances using multiple projection matrices (Local Omega).

Formula: \(D_{ij} = \|X_i \Omega_j - Y_j \Omega_j\|^2\) where \(\Omega_j\) is the projection matrix belonging to prototype \(Y_j\) (one matrix per prototype or cluster).

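Parameters:
  • X (array of shape (n, d)) – Data points.

  • Y (array of shape (m, d)) – Prototypes.

  • omegas (array of shape (m, d, k)) – One projection matrix per prototype.

Returns:

D – Array of shape (n, m) with aggregated projected distances

Return type:

array of shape (n, m)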

Complexity:

Time: O(nmdk)
Space: O(nmk)

Use Cases:
  • Local relevance learning (each prototype has its own metric)

  • Adaptive distance metrics in localized GMLVQ (LGMLVQ)

  • Cluster-specific feature weighting

Example

>>> n, m, d, k = 10, 3, 5, 2
>>> X = jax.random.normal(jax.random.PRNGKey(0), (n, d))
>>> Y = jax.random.normal(jax.random.PRNGKey(1), (m, d))
>>> omegas = jax.random.normal(jax.random.PRNGKey(2), (m, d, k))
>>> D = lomega_distance_matrix(X, Y, omegas)
>>> D.shape
(10, 3)

Implementation:

Uses einsum for efficient tensor contraction:

  1. Project X through each omega: X @ omegas[j] for all j

  2. Extract the diagonal for Y projections (each Y[j] uses omegas[j])

  3. Compute squared differences and sum

Notes

  • Generalizes omega_distance to local (per-prototype) metrics

  • More flexible than a single global omega, but more expensive: O(nmdk) time instead of O((n+m)dk + nmk)

  • Enables learning which features matter for each cluster
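The three implementation steps map onto two einsum calls. A sketch, assuming omegas has shape (m, d, k) as in the example; lomega_matrix_sketch is a hypothetical name:

```python
# Illustrative sketch; not prosemble's implementation.
import jax
import jax.numpy as jnp


def lomega_matrix_sketch(X, Y, omegas):
    # 1. Project every x through every prototype's omega: (n, m, k)
    PX = jnp.einsum("nd,mdk->nmk", X, omegas)
    # 2. Project each prototype only through its own omega (the "diagonal"): (m, k)
    PY = jnp.einsum("md,mdk->mk", Y, omegas)
    # 3. Squared differences summed over the projected dimension: (n, m)
    return jnp.sum((PX - PY[None, :, :]) ** 2, axis=-1)


n, m, d, k = 10, 3, 5, 2
X = jax.random.normal(jax.random.PRNGKey(0), (n, d))
Y = jax.random.normal(jax.random.PRNGKey(1), (m, d))
omegas = jax.random.normal(jax.random.PRNGKey(2), (m, d, k))
D = lomega_matrix_sketch(X, Y, omegas)  # shape (10, 3)
```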

prosemble.core.distance.tangent_distance_matrix(X, Y, omegas)

Compute pairwise localized tangent distances.

Each prototype j has an orthogonal subspace basis \(\Omega_j\) of shape (d, s). The tangent distance projects out the subspace directions:

\[d(x, w_j) = \|(I - \Omega_j \Omega_j^T)(x - w_j)\|^2\]
Parameters:
  • X (array of shape (n, d)) – Data points.

  • Y (array of shape (m, d)) – Prototypes.

  • omegas (array of shape (m, d, s)) – Orthogonal subspace bases per prototype, where s is the subspace dimension.

Returns:

D – Squared tangent distances.

Return type:

array of shape (n, m)

Notes

Based on Saralajew, S., & Villmann, T. (2016). Adaptive tangent distances in generalized learning vector quantization.
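A sketch of the projection in the formula, assuming each omegas[j] has orthonormal columns so that \(\Omega_j \Omega_j^T\) is the orthogonal projector onto prototype j's subspace (tangent_matrix_sketch is hypothetical):

```python
# Illustrative sketch; not prosemble's implementation.
import jax.numpy as jnp


def tangent_matrix_sketch(X, Y, omegas):
    diff = X[:, None, :] - Y[None, :, :]                 # (n, m, d): x - w_j
    coeffs = jnp.einsum("nmd,mds->nms", diff, omegas)    # Omega_j^T (x - w_j)
    proj = jnp.einsum("nms,mds->nmd", coeffs, omegas)    # Omega_j Omega_j^T (x - w_j)
    resid = diff - proj                                  # (I - Omega_j Omega_j^T)(x - w_j)
    return jnp.sum(resid**2, axis=-1)                    # (n, m) squared tangent distances


X = jnp.array([[5.0, 1.0, 2.0]])
Y = jnp.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
e1 = jnp.array([[1.0], [0.0], [0.0]])    # 1-D subspace along the first axis
omegas = jnp.stack([e1, e1])             # (m, d, s) = (2, 3, 1)
D = tangent_matrix_sketch(X, Y, omegas)  # [[5, 1]]
```

Displacements along a prototype's subspace cost nothing, which is why the first coordinate does not contribute in this example.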

prosemble.core.distance.wasserstein2_distance_matrix(X, means, log_variances)

Compute pairwise squared 2-Wasserstein distances from points to Gaussian prototypes.

Each prototype is a diagonal Gaussian \(\mathcal{N}(\mu_k, \text{diag}(\sigma_k^2))\). Each input point \(x\) is treated as a Dirac delta distribution \(\delta_x\).

The squared 2-Wasserstein distance from a point to a diagonal Gaussian is:

\[W_2^2(\delta_x, \mathcal{N}(\mu_k, \text{diag}(\sigma_k^2))) = \sum_j (x_j - \mu_{kj})^2 + \sum_j \sigma_{kj}^2\]

This decomposes into the squared Euclidean distance from the point to the mean, plus the total variance (trace of covariance). Prototypes with smaller variance are effectively “more certain” and attract nearby points more strongly.

Parameters:
  • X (array of shape (n, d)) – Data points.

  • means (array of shape (p, d)) – Prototype mean vectors.

  • log_variances (array of shape (p, d)) – Log of prototype variances (ensures positivity via exp).

Returns:

D – Squared 2-Wasserstein distances.

Return type:

array of shape (n, p)
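Because the distance splits into a point-to-mean term plus a per-prototype trace, a sketch needs only one broadcast and two sums (w2_matrix_sketch is hypothetical; variances are recovered with exp, as the parameter description states):

```python
# Illustrative sketch; not prosemble's implementation.
import jax.numpy as jnp


def w2_matrix_sketch(X, means, log_variances):
    variances = jnp.exp(log_variances)                                    # (p, d), always positive
    sq_dist = jnp.sum((X[:, None, :] - means[None, :, :]) ** 2, axis=-1)  # (n, p) point-to-mean term
    total_var = jnp.sum(variances, axis=-1)                               # (p,) trace of each covariance
    return sq_dist + total_var[None, :]


X = jnp.array([[0.0, 0.0]])
means = jnp.array([[0.0, 0.0], [1.0, 0.0]])
log_variances = jnp.log(jnp.array([[1.0, 1.0], [0.25, 0.25]]))
D = w2_matrix_sketch(X, means, log_variances)  # [[0 + 2, 1 + 0.5]] = [[2.0, 1.5]]
```

Here the second prototype is farther from the point but has a smaller variance, and it ends up with the smaller distance.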

prosemble.core.distance.wasserstein2_omega_distance_matrix(X, means, log_variances, omega)

Squared 2-Wasserstein distance with global metric adaptation.

Projects data and means through \(\Omega\) before computing the Euclidean component, while variances contribute directly:

\[W_2^2(x, k) = \|(x - \mu_k)\Omega\|^2 + \sum_j \sigma_{kj}^2\]

Parameters:
  • X (array of shape (n, d)) – Data points.

  • means (array of shape (p, d)) – Prototype mean vectors.

  • log_variances (array of shape (p, d)) – Log of prototype variances.

  • omega (array of shape (d, l)) – Global projection matrix.

Returns:

D – Squared 2-Wasserstein distances in projected space.

Return type:

array of shape (n, p)

prosemble.core.distance.wasserstein2_relevance_distance_matrix(X, means, log_variances, relevances)

Squared 2-Wasserstein distance with feature relevance weighting.

Applies per-feature relevance weights \(\lambda_j\) to the Euclidean component:

\[W_2^2(x, k) = \sum_j \lambda_j (x_j - \mu_{kj})^2 + \sum_j \sigma_{kj}^2\]

where \(\lambda_j = \text{softmax}(r)_j\) ensures non-negative weights that sum to 1.

Parameters:
  • X (array of shape (n, d)) – Data points.

  • means (array of shape (p, d)) – Prototype mean vectors.

  • log_variances (array of shape (p, d)) – Log of prototype variances.

  • relevances (array of shape (d,)) – Raw relevance logits (softmax applied internally).

Returns:

D – Relevance-weighted squared 2-Wasserstein distances.

Return type:

array of shape (n, p)
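The two metric-adapted variants change only the Euclidean term. A combined sketch with hypothetical names, assuming omega has shape (d, l) applied in the row-vector form \((x - \mu_k)\Omega\) used by the other omega distances, and relevances passed as raw logits as stated above:

```python
# Illustrative sketch; not prosemble's implementation.
import jax
import jax.numpy as jnp


def _total_variance(log_variances):
    return jnp.sum(jnp.exp(log_variances), axis=-1)[None, :]   # (1, p) trace term


def w2_omega_sketch(X, means, log_variances, omega):
    diff = (X[:, None, :] - means[None, :, :]) @ omega          # (n, p, l) projected differences
    return jnp.sum(diff**2, axis=-1) + _total_variance(log_variances)


def w2_relevance_sketch(X, means, log_variances, relevances):
    lam = jax.nn.softmax(relevances)                            # (d,) non-negative, sums to 1
    sq = (X[:, None, :] - means[None, :, :]) ** 2               # (n, p, d)
    return jnp.sum(lam * sq, axis=-1) + _total_variance(log_variances)


X = jnp.array([[1.0, 2.0]])
means = jnp.zeros((1, 2))
log_variances = jnp.zeros((1, 2))   # unit variances, trace = 2
D_omega = w2_omega_sketch(X, means, log_variances, jnp.array([[1.0], [0.0]]))  # 1 + 2 = 3
D_rel = w2_relevance_sketch(X, means, log_variances, jnp.zeros(2))             # 0.5 * (1 + 4) + 2 = 4.5
```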

prosemble.core.distance.gaussian_kernel_matrix(X, Y, sigma)

Compute Gaussian (RBF) kernel matrix.

Formula: \(K_{ij} = \exp(-\|X_i - Y_j\|^2 / (2\sigma^2))\)

The Gaussian kernel maps data to infinite-dimensional Hilbert space, enabling non-linear clustering and classification.

Parameters:
  • X (array of shape (n, d))

  • Y (array of shape (m, d))

  • sigma (float) – Bandwidth parameter (\(\sigma > 0\)).

Returns:

K – Array of shape (n, m) where \(K_{ij} \in [0, 1]\): \(K_{ij} = 1\) when \(X_i = Y_j\), and \(K_{ij} \to 0\) as \(\|X_i - Y_j\| \to \infty\).

Complexity:

Time: O(nmd)
Space: O(nm)

Properties:
  • K is positive semi-definite (valid kernel)

  • K is symmetric if X = Y

  • K[i,i] = 1 when X = Y (self-similarity)

Example:
>>> X = jnp.array([[0, 0], [1, 1]])
>>> Y = jnp.array([[0, 0], [2, 2]])
>>> K = gaussian_kernel_matrix(X, Y, sigma=1.0)
>>> K[0, 0]  # Self-similarity
1.0
>>> K[0, 1] < K[0, 0]  # Decreases with distance
True

Kernel Trick:

For a feature map \(\phi\) into an infinite-dimensional Hilbert space, \(K(x, y) = \langle \phi(x), \phi(y) \rangle\). The kernel distance is

\(\|\phi(x) - \phi(y)\|^2 = K(x,x) - 2K(x,y) + K(y,y) = 2 - 2K(x,y)\) for a normalized kernel (\(K(x,x) = 1\)).

Use Cases:
  • Kernel Fuzzy C-Means (KFCM)

  • Kernel Possibilistic C-Means (KPCM)

  • Support Vector Machines (SVM)

  • Gaussian Processes

Hyperparameter Tuning:
  • Small \(\sigma\): Tight clusters, high sensitivity to noise

  • Large \(\sigma\): Smooth clusters, may underfit

  • Rule of thumb: \(\sigma \approx \text{median}(\text{pairwise\_distances}) / \sqrt{2 \cdot n_\text{clusters}}\)

Notes:
  • sigma is bandwidth, NOT variance (variance = \(\sigma^2\))

  • For numerical stability, maximum() clamps squared distances at zero before exponentiation, so rounding error cannot push K above 1

  • JIT-compiled for GPU acceleration

Return type:

Array | ndarray | bool | number
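A sketch combining the formula, the clamp from the Notes, and the kernel-distance identity from the Kernel Trick paragraph (gaussian_kernel_sketch is hypothetical):

```python
# Illustrative sketch; not prosemble's implementation.
import jax
import jax.numpy as jnp


@jax.jit
def gaussian_kernel_sketch(X, Y, sigma):
    d_sq = jnp.sum(X**2, axis=1)[:, None] + jnp.sum(Y**2, axis=1)[None, :] - 2.0 * X @ Y.T
    d_sq = jnp.maximum(d_sq, 0.0)             # clamp rounding noise so K stays <= 1
    return jnp.exp(-d_sq / (2.0 * sigma**2))


X = jnp.array([[0.0, 0.0], [1.0, 1.0]])
Y = jnp.array([[0.0, 0.0], [2.0, 2.0]])
K = gaussian_kernel_sketch(X, Y, 1.0)
# Squared feature-space distance for a normalized kernel (K(x, x) = 1)
D_feat_sq = 2.0 - 2.0 * K
```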

prosemble.core.distance.polynomial_kernel_matrix(X, Y, degree=3, coef0=1.0)

Compute polynomial kernel matrix.

Formula: \(K_{ij} = (X_i^T Y_j + c)^d\) where d is degree and c is coef0.

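Parameters:
  • X (array of shape (n, d))

  • Y (array of shape (m, d))

  • degree (int) – Polynomial degree.

  • coef0 (float) – Constant term c.

Returns:

K – Array of shape (n, m) with polynomial kernel values

Return type:

array of shape (n, m)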

Example

>>> X = jnp.array([[1, 2], [3, 4]])
>>> Y = jnp.array([[1, 0], [0, 1]])
>>> K = polynomial_kernel_matrix(X, Y, degree=2, coef0=1.0)

Notes

  • degree=1, coef0=0: Linear kernel (dot product)

  • Higher degree: More complex decision boundaries

  • coef0: Influences importance of lower vs higher order terms
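A one-line sketch of the formula, with the values for the example inputs worked out by hand (polynomial_kernel_sketch is hypothetical):

```python
# Illustrative sketch; not prosemble's implementation.
import jax.numpy as jnp


def polynomial_kernel_sketch(X, Y, degree=3, coef0=1.0):
    return (X @ Y.T + coef0) ** degree   # K_ij = (X_i . Y_j + c)^degree


X = jnp.array([[1.0, 2.0], [3.0, 4.0]])
Y = jnp.array([[1.0, 0.0], [0.0, 1.0]])
K = polynomial_kernel_sketch(X, Y, degree=2, coef0=1.0)
# X @ Y.T = [[1, 2], [3, 4]], so K = [[4, 9], [16, 25]]
```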

prosemble.core.distance.euclidean_distance(x, y)

Euclidean distance between two vectors.

Parameters:
  • x (array of shape (d,))

  • y (array of shape (d,))

Returns:

Scalar distance

Return type:

Array | ndarray | bool | number

Example

>>> x = jnp.array([0, 0, 0])
>>> y = jnp.array([1, 1, 1])
>>> d = euclidean_distance(x, y)
>>> d
Array(1.732..., dtype=float32)

prosemble.core.distance.squared_euclidean_distance(x, y)

Squared Euclidean distance between two vectors.

Parameters:
  • x (array of shape (d,))

  • y (array of shape (d,))

Returns:

Scalar squared distance

Return type:

Array | ndarray | bool | number

prosemble.core.distance.manhattan_distance(x, y)

Manhattan (L1) distance between two vectors.

Parameters:
  • x (array of shape (d,))

  • y (array of shape (d,))

Returns:

Scalar distance

Return type:

Array | ndarray | bool | number

prosemble.core.distance.lpnorm_distance(x, y, p=2)

Lp-norm distance between two vectors.

Parameters:
  • x (array of shape (d,))

  • y (array of shape (d,))

  • p (float) – Order of the norm.

Returns:

Scalar distance

Return type:

Array | ndarray | bool | number

prosemble.core.distance.omega_distance(x, y, omega)

Omega (projection-based) distance between two vectors.

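Computes \(\|\text{diff} \cdot \Omega\|^2\) where \(\text{diff} = x - y\).

Parameters:
  • x (array of shape (d,))

  • y (array of shape (d,))

  • omega (array of shape (d, k)) – Projection matrix.

Returns: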

Scalar squared distance in projected space

Return type:

Array | ndarray | bool | number

prosemble.core.distance.lomega_distance(X, Y, omegas)

Local omega distance with per-prototype projection matrices.

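Parameters:
  • X (array of shape (n, d)) – Data points.

  • Y (array of shape (m, d)) – Prototypes.

  • omegas (array of shape (m, d, k)) – One projection matrix per prototype.

Returns: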

Distance matrix of shape (n, m)

Return type:

Array | ndarray | bool | number

prosemble.core.distance.estimate_sigma(X, percentile=50.0)

Estimate sigma for Gaussian kernel using pairwise distances.

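Strategy: Use the median (or another percentile) of the pairwise distances.

Parameters:
  • X (array of shape (n, d)) – Data points.

  • percentile (float) – Percentile of the pairwise distances to use; 50 gives the median.

Returns:

sigma – Estimated bandwidth parameter

Return type:

scalar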

Example

>>> X = jax.random.normal(jax.random.PRNGKey(0), (100, 10))
>>> sigma = estimate_sigma(X, percentile=50)

Notes

  • Heuristic: \(\sigma = \text{median\_distance} / \sqrt{2 \cdot n_\text{clusters}}\)

  • For large datasets, subsample X first to avoid the O(n²) pairwise computation
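A sketch of the percentile strategy. It assumes only distinct pairs (i < j) are counted, so the zero self-distances do not pull the percentile down; the library may treat the diagonal differently, and estimate_sigma_sketch is a hypothetical name:

```python
# Illustrative sketch; not prosemble's implementation.
import jax.numpy as jnp


def estimate_sigma_sketch(X, percentile=50.0):
    diff = X[:, None, :] - X[None, :, :]          # (n, n, d), O(n^2 d) memory
    dist = jnp.sqrt(jnp.sum(diff**2, axis=-1))    # (n, n) pairwise Euclidean distances
    i, j = jnp.triu_indices(X.shape[0], k=1)      # distinct pairs only (assumption)
    return jnp.percentile(dist[i, j], percentile)


X = jnp.array([[0.0], [1.0], [3.0]])
sigma = estimate_sigma_sketch(X)   # pair distances 1, 3, 2 -> median 2
```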

prosemble.core.distance.safe_divide(numerator, denominator, epsilon=1e-10)

Safe division avoiding division by zero.

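Parameters:
  • numerator (array)

  • denominator (array)

  • epsilon (float) – Small constant added to the denominator.

Returns: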

numerator / (denominator + epsilon)

Return type:

Array | ndarray | bool | number

Example

>>> x = jnp.array([1.0, 2.0, 3.0])
>>> y = jnp.array([2.0, 0.0, 1.0])
>>> safe_divide(x, y)
Array([5.e-01, 2.e+10, 3.e+00], dtype=float32)

The zero denominator yields a large finite value (2 / 1e-10 = 2e10) instead of inf.

prosemble.core.distance.batch_squared_euclidean(X, Y)

Same docstring as squared_euclidean_distance_matrix(); see that entry for the formula, parameters, complexity, example, and notes.

prosemble.core.distance.batch_euclidean(X, Y)

Same docstring as euclidean_distance_matrix(); see that entry for the formula, parameters, complexity, example, and notes.

Similarities

Similarity functions for prototype-based learning.

Similarities are the dual of distances: higher values indicate closer/more similar points.

prosemble.core.similarities.gaussian_similarity(distances_sq, variance=1.0)

Convert squared distances to Gaussian similarities.

\(s(d) = \exp(-d^2 / (2 \cdot \text{variance}))\), where \(d^2\) is the squared distance given in distances_sq.

Parameters:
  • distances_sq (array) – Squared distances.

  • variance (float) – Variance (sigma^2) of the Gaussian.

Returns:

Similarity values in (0, 1].

Return type:

array

prosemble.core.similarities.cosine_similarity_matrix(X, Y)

Pairwise cosine similarity between rows of X and Y.

\(\cos(x, y) = \frac{x \cdot y}{\|x\| \, \|y\|}\)

Parameters:
  • X (array of shape (n, d))

  • Y (array of shape (m, d))

Returns:

Cosine similarities in [-1, 1].

Return type:

array of shape (n, m)
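A sketch that normalizes the rows once and takes a single matrix product. The eps guard against zero-norm rows is an assumption of this sketch, not documented behavior, and cosine_matrix_sketch is hypothetical:

```python
# Illustrative sketch; not prosemble's implementation.
import jax.numpy as jnp


def cosine_matrix_sketch(X, Y, eps=1e-12):
    Xn = X / jnp.maximum(jnp.linalg.norm(X, axis=1, keepdims=True), eps)  # unit-length rows
    Yn = Y / jnp.maximum(jnp.linalg.norm(Y, axis=1, keepdims=True), eps)
    return Xn @ Yn.T                                                       # (n, m) in [-1, 1]


X = jnp.array([[1.0, 0.0], [1.0, 1.0]])
Y = jnp.array([[1.0, 0.0], [0.0, 2.0], [-1.0, 0.0]])
S = cosine_matrix_sketch(X, Y)
```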

prosemble.core.similarities.euclidean_similarity(X, Y, variance=1.0)

Pairwise Euclidean similarity (Gaussian of Euclidean distance).

Parameters:
  • X (array of shape (n, d))

  • Y (array of shape (m, d))

  • variance (float) – Variance of the Gaussian kernel.

Returns:

Similarity values in (0, 1].

Return type:

array of shape (n, m)
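A sketch of how the two Gaussian similarities compose, assuming euclidean_similarity feeds squared Euclidean distances into the Gaussian (the docs only say "Gaussian of Euclidean distance"); both names are hypothetical:

```python
# Illustrative sketch; not prosemble's implementation.
import jax.numpy as jnp


def gaussian_similarity_sketch(distances_sq, variance=1.0):
    return jnp.exp(-distances_sq / (2.0 * variance))


def euclidean_similarity_sketch(X, Y, variance=1.0):
    d_sq = jnp.sum((X[:, None, :] - Y[None, :, :]) ** 2, axis=-1)   # (n, m) squared distances
    return gaussian_similarity_sketch(d_sq, variance)


X = jnp.array([[0.0, 0.0]])
Y = jnp.array([[0.0, 0.0], [1.0, 1.0]])
S = euclidean_similarity_sketch(X, Y, variance=1.0)   # [[1, exp(-1)]]
```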

prosemble.core.similarities.rank_scaled_gaussian(distances, lambd=1.0)

Rank-scaled Gaussian similarity.

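Combines distance magnitude with rank ordering: each distance is scaled by \(\exp(-r / \lambda)\) before the Gaussian is applied, so the closest prototype (rank 0) enters with its full distance while the distances of farther-ranked prototypes are exponentially down-weighted.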

\[s(d, r) = \exp(-\exp(-r / \lambda) \cdot d)\]

where r is the rank of each distance (0 = closest).

Parameters:
  • distances (array of shape (n, m)) – Distance matrix (non-negative).

  • lambd (float) – Rank decay parameter. Larger values give more uniform weighting across ranks; smaller values concentrate on nearest neighbours.

Returns:

Rank-scaled similarity values in (0, 1].

Return type:

array of shape (n, m)

Notes

Used in Probabilistic LVQ (PLVQ) as a conditional distribution P(x|prototype).
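A sketch of the formula; ranks come from a double argsort, which gives each entry its position within its row (0 = closest). rank_scaled_gaussian_sketch is hypothetical:

```python
# Illustrative sketch; not prosemble's implementation.
import jax.numpy as jnp


def rank_scaled_gaussian_sketch(distances, lambd=1.0):
    order = jnp.argsort(distances, axis=1)   # prototype indices sorted by distance
    ranks = jnp.argsort(order, axis=1)       # rank of each prototype, 0 = closest
    weights = jnp.exp(-ranks / lambd)        # 1 for the closest, decaying with rank
    return jnp.exp(-weights * distances)


distances = jnp.array([[3.0, 1.0, 2.0]])
S = rank_scaled_gaussian_sketch(distances)
# ranks are [[2, 0, 1]], so S = exp(-[e^-2 * 3, 1, e^-1 * 2])
```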

Kernel Functions

JAX-based kernel functions for kernel clustering methods.

This module provides GPU-accelerated kernel computations using JAX.

prosemble.core.kernel.gaussian_kernel(x, y, sigma)

Compute Gaussian (RBF) kernel between two vectors.

\(K(x, y) = \exp(-\|x - y\|^2 / (2\sigma^2))\)

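Parameters:
  • x (array of shape (d,))

  • y (array of shape (d,))

  • sigma (float) – Bandwidth parameter (\(\sigma > 0\)).

Returns: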

Kernel value (scalar)

Return type:

Array | ndarray | bool | number

prosemble.core.kernel.batch_gaussian_kernel(X, Y, sigma)

Compute Gaussian kernel between two sets of vectors.

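Parameters:
  • X (array of shape (n_samples, n_features))

  • Y (array of shape (m_samples, n_features))

  • sigma (float) – Bandwidth parameter.

Returns:

Kernel matrix of shape (n_samples, m_samples), with K[i, j] = K(X[i], Y[j])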

Return type:

Array | ndarray | bool | number

prosemble.core.kernel.kernel_distance_squared(X, Y, sigma)

Compute squared distance in feature space.

For Gaussian kernel: \(\|\phi(x) - \phi(y)\|^2 = K(x,x) + K(y,y) - 2K(x,y) = 2(1 - K(x,y))\), since \(K(x,x) = 1\).

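Parameters:
  • X (array of shape (n_samples, n_features))

  • Y (array of shape (m_samples, n_features))

  • sigma (float) – Bandwidth parameter.

Returns: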

Squared distances in feature space, shape (n_samples, m_samples)

Return type:

Array | ndarray | bool | number

prosemble.core.kernel.kernel_distance(X, Y, sigma)

Compute distance in feature space.

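Parameters:
  • X (array of shape (n_samples, n_features))

  • Y (array of shape (m_samples, n_features))

  • sigma (float) – Bandwidth parameter.

Returns: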

Distances in feature space, shape (n_samples, m_samples)

Return type:

Array | ndarray | bool | number

prosemble.core.kernel.kernel_distance_squared_per_proto(X, W, sigmas)

Squared kernel distance with per-prototype bandwidth.

\[d_\kappa^2(x, w_k) = 2\left(1 - \exp\left(-\frac{\|x - w_k\|^2}{2\sigma_k^2}\right)\right)\]

Each prototype \(w_k\) has its own bandwidth \(\sigma_k\).

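Parameters:
  • X (array of shape (n_samples, n_features)) – Data matrix.

  • W (array of shape (n_prototypes, n_features)) – Prototype matrix.

  • sigmas (array of shape (n_prototypes,)) – Per-prototype bandwidths \(\sigma_k\).

Returns: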

Squared distances in feature space, shape (n_samples, n_prototypes).

Return type:

Array | ndarray | bool | number
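A sketch of the per-prototype formula (hypothetical name). It also shows that the value saturates: far from every prototype the exponential vanishes and the squared distance approaches 2.

```python
# Illustrative sketch; not prosemble's implementation.
import jax.numpy as jnp


def kernel_dist_sq_per_proto_sketch(X, W, sigmas):
    sq = jnp.sum((X[:, None, :] - W[None, :, :]) ** 2, axis=-1)        # (n, p) input-space ||x - w_k||^2
    return 2.0 * (1.0 - jnp.exp(-sq / (2.0 * sigmas[None, :] ** 2)))   # bandwidth sigma_k per column


X = jnp.array([[0.0, 0.0], [100.0, 0.0]])
W = jnp.array([[0.0, 0.0], [1.0, 0.0]])
sigmas = jnp.array([1.0, 2.0])
D = kernel_dist_sq_per_proto_sketch(X, W, sigmas)
```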

prosemble.core.kernel.kernel_distance_squared_relevance(X, W, sigmas, relevances)

Squared kernel distance with relevance weighting and per-prototype bandwidth.

\[d_\kappa^2(x, w_k) = 2\left(1 - \exp\left(-\frac{\sum_j \lambda_j (x_j - w_{kj})^2}{2\sigma_k^2}\right)\right)\]
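
Parameters:
  • X (array of shape (n_samples, n_features)) – Data matrix.

  • W (array of shape (n_prototypes, n_features)) – Prototype matrix.

  • sigmas (array of shape (n_prototypes,)) – Per-prototype bandwidths \(\sigma_k\).

  • relevances (array of shape (n_features,)) – Feature relevances that determine the weights \(\lambda_j\).

Returns: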

Squared distances in feature space, shape (n_samples, n_prototypes).

Return type:

Array | ndarray | bool | number

prosemble.core.kernel.exponential_kernel_distance_squared(X, W, omega_hat)

Squared distance in exponential kernel feature space.

Uses the exponential kernel \(\kappa_{\exp}(v, w, \hat\Lambda) = \exp(v^T \hat\Lambda w)\) where \(\hat\Lambda = \hat\Omega \hat\Omega^T\).

\[d_\kappa^2(x, w) = \exp(x^T \hat\Lambda x) + \exp(w^T \hat\Lambda w) - 2 \exp(x^T \hat\Lambda w)\]

Note: \(\kappa(v, v) \neq 1\) for the exponential kernel, so the full three-term formula is required (not the 2(1-K) simplification).

Parameters:
  • X (Array | ndarray | bool | number) – Data matrix, shape (n_samples, n_features).

  • W (Array | ndarray | bool | number) – Prototype matrix, shape (n_prototypes, n_features).

  • omega_hat (Array | ndarray | bool | number) – Transformation matrix, shape (n_features, latent_dim). The kernel matrix is \(\hat\Lambda = \hat\Omega \hat\Omega^T\).

Returns:

Squared distances in feature space, shape (n_samples, n_prototypes).

Return type:

Array | ndarray | bool | number
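A sketch of the three-term formula (hypothetical name). Every term is an exponential of a quadratic or bilinear form in \(\hat\Lambda\), so in float32 it overflows once an exponent exceeds roughly 88; keeping the inputs and omega_hat small in scale is a sensible precaution.

```python
# Illustrative sketch; not prosemble's implementation.
import jax.numpy as jnp


def exp_kernel_dist_sq_sketch(X, W, omega_hat):
    lam = omega_hat @ omega_hat.T                              # Lambda_hat = Omega_hat Omega_hat^T, (d, d)
    k_xx = jnp.exp(jnp.einsum("nd,de,ne->n", X, lam, X))      # kappa(x, x), (n,)
    k_ww = jnp.exp(jnp.einsum("pd,de,pe->p", W, lam, W))      # kappa(w, w), (p,)
    k_xw = jnp.exp(X @ lam @ W.T)                              # kappa(x, w), (n, p)
    return k_xx[:, None] + k_ww[None, :] - 2.0 * k_xw


X = jnp.array([[0.1, 0.2]])
W = jnp.array([[0.1, 0.2], [0.0, 0.0]])
D = exp_kernel_dist_sq_sketch(X, W, jnp.eye(2))
# D[0, 0] = 0 since x == w; D[0, 1] = exp(0.05) + exp(0) - 2 * exp(0) = exp(0.05) - 1
```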

Loss Functions

Loss functions for prototype-based learning.

All loss functions are differentiable by jax.grad and JIT-compatible. They use jnp.where masking (not boolean indexing) for d+/d- extraction, because boolean indexing produces data-dependent shapes that jax.jit cannot trace.

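prosemble.core.losses.glvq_loss(distances, target_labels, prototype_labels, margin=0.0)

Generalized LVQ loss.

\(\mu_i = \frac{d^+_i - d^-_i}{d^+_i + d^-_i}\)

where \(d^+_i\) is the distance from sample i to the closest prototype with the correct label and \(d^-_i\) the distance to the closest prototype with a wrong label, so \(\mu_i \in [-1, 1]\) and \(\mu_i < 0\) exactly when the closest correct prototype is nearer than the closest wrong one.

Parameters:
  • distances (array of shape (n, p))

  • target_labels (array of shape (n,))

  • prototype_labels (array of shape (p,))

  • margin (float) – Margin added to mu before transfer.

Returns:

Mean loss over samples.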
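A sketch of extracting d+ and d- with jnp.where masking, as the module introduction describes, and of averaging \(\mu_i\) plus the margin. It assumes glvq_loss applies no further transfer function; both names are hypothetical:

```python
# Illustrative sketch; not prosemble's implementation.
import jax
import jax.numpy as jnp


def glvq_mu_sketch(distances, target_labels, prototype_labels):
    same = target_labels[:, None] == prototype_labels[None, :]       # (n, p) correct-class mask
    d_plus = jnp.min(jnp.where(same, distances, jnp.inf), axis=1)    # closest correct prototype
    d_minus = jnp.min(jnp.where(same, jnp.inf, distances), axis=1)   # closest wrong prototype
    return (d_plus - d_minus) / (d_plus + d_minus)                   # in [-1, 1], < 0 if correct


def glvq_loss_sketch(distances, target_labels, prototype_labels, margin=0.0):
    return jnp.mean(glvq_mu_sketch(distances, target_labels, prototype_labels) + margin)


distances = jnp.array([[1.0, 3.0], [4.0, 2.0]])
target_labels = jnp.array([0, 0])
prototype_labels = jnp.array([0, 1])
mu = glvq_mu_sketch(distances, target_labels, prototype_labels)     # [-0.5, 1/3]
loss = glvq_loss_sketch(distances, target_labels, prototype_labels)
# Fixed shapes throughout, so jax.grad (and jax.jit) apply directly
grads = jax.grad(glvq_loss_sketch)(distances, target_labels, prototype_labels)
```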

prosemble.core.losses.glvq_loss_with_transfer(distances, target_labels, prototype_labels, transfer_fn=identity, margin=0.0, beta=10.0)

GLVQ loss with configurable transfer function.

\(\text{loss} = \text{mean}(f(\mu + \text{margin}, \beta))\)

Parameters:
  • distances (array of shape (n, p))

  • target_labels (array of shape (n,))

  • prototype_labels (array of shape (p,))

  • transfer_fn (callable) – Activation function (identity, sigmoid_beta, swish_beta).

  • margin (float)

  • beta (float) – Transfer function parameter.

Return type:

scalar

prosemble.core.losses.lvq1_loss(distances, target_labels, prototype_labels)

LVQ1 loss: d+ when correct, -d- when wrong.

Parameters:
  • distances (array of shape (n, p))

  • target_labels (array of shape (n,))

  • prototype_labels (array of shape (p,))

Return type:

scalar

prosemble.core.losses.lvq21_loss(distances, target_labels, prototype_labels)

LVQ2.1 loss: d+ - d- (unnormalized).

Parameters:
  • distances (array of shape (n, p))

  • target_labels (array of shape (n,))

  • prototype_labels (array of shape (p,))

Return type:

scalar

prosemble.core.losses.nllr_loss(distances, target_labels, prototype_labels, sigma=1.0)

Negative Log-Likelihood Ratio loss (for SLVQ).

\(\text{loss} = -\log(P(\text{correct}) / P(\text{wrong}))\)

Parameters:
  • distances (array of shape (n, p))

  • target_labels (array of shape (n,))

  • prototype_labels (array of shape (p,))

  • sigma (float) – Bandwidth of Gaussian mixture.

Return type:

scalar

prosemble.core.losses.rslvq_loss(distances, target_labels, prototype_labels, sigma=1.0)

Robust Soft LVQ loss (for RSLVQ).

\(\text{loss} = -\log(P(\text{correct}) / P(\text{all}))\)

Parameters:
  • distances (array of shape (n, p))

  • target_labels (array of shape (n,))

  • prototype_labels (array of shape (p,))

  • sigma (float)

Return type:

scalar
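Both probabilistic losses can be computed in log space with logsumexp over masked logits. The sketch assumes Gaussian weights \(\exp(-d_k / 2\sigma^2)\) on squared distances (the form given for ng_rslvq_loss in this module) and a mean over samples; the function names are hypothetical:

```python
# Illustrative sketch; not prosemble's implementation.
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def _log_class_masses(distances, target_labels, prototype_labels, sigma):
    logits = -distances / (2.0 * sigma**2)                        # log of unnormalized Gaussians
    correct = target_labels[:, None] == prototype_labels[None, :]
    log_correct = logsumexp(jnp.where(correct, logits, -jnp.inf), axis=1)
    log_wrong = logsumexp(jnp.where(correct, -jnp.inf, logits), axis=1)
    log_all = logsumexp(logits, axis=1)
    return log_correct, log_wrong, log_all


def nllr_loss_sketch(distances, target_labels, prototype_labels, sigma=1.0):
    lc, lw, _ = _log_class_masses(distances, target_labels, prototype_labels, sigma)
    return jnp.mean(-(lc - lw))   # -log(P(correct) / P(wrong))


def rslvq_loss_sketch(distances, target_labels, prototype_labels, sigma=1.0):
    lc, _, la = _log_class_masses(distances, target_labels, prototype_labels, sigma)
    return jnp.mean(-(lc - la))   # -log(P(correct) / P(all))


distances = jnp.array([[0.0, 2.0]])
target_labels = jnp.array([0])
prototype_labels = jnp.array([0, 1])
nllr = nllr_loss_sketch(distances, target_labels, prototype_labels)     # -(0 - (-1)) = -1
rslvq = rslvq_loss_sketch(distances, target_labels, prototype_labels)   # log(1 + e^-1)
```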

prosemble.core.losses.ng_rslvq_loss(distances, target_labels, prototype_labels, sigma=1.0, gamma=1.0)

RSLVQ loss with Neural Gas rank-based neighborhood cooperation.

Combines Gaussian mixture prototype probabilities with NG rank weights to create a neighborhood-cooperative probabilistic assignment.

Gaussian: \(p(k|x) = \exp(-d_k / 2\sigma^2) / \sum_j \exp(-d_j / 2\sigma^2)\)

NG weights: \(h_k = \exp(-\text{rank}_k / \gamma) / \sum_j \exp(-\text{rank}_j / \gamma)\)

Combined: \(w_k = p(k|x) \cdot h_k / \sum_j p(j|x) \cdot h_j\)

The loss is \(-\log(\sum_{k \in \text{correct}} w_k)\).

Parameters:
  • distances (array of shape (n, p)) – Squared distances from samples to prototypes.

  • target_labels (array of shape (n,)) – True class labels for samples.

  • prototype_labels (array of shape (p,)) – Class labels assigned to prototypes.

  • sigma (float) – Bandwidth of Gaussian mixture.

  • gamma (float) – Neural Gas neighborhood range.

Returns:

Mean negative log-likelihood with NG cooperation.

Return type:

scalar
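A direct transcription of the three formulas into a sketch (hypothetical name; mean over samples assumed). Ranks come from a double argsort, and the normalization of h cancels in \(w_k\):

```python
# Illustrative sketch; not prosemble's implementation.
import jax
import jax.numpy as jnp


def ng_rslvq_loss_sketch(distances, target_labels, prototype_labels, sigma=1.0, gamma=1.0):
    p = jax.nn.softmax(-distances / (2.0 * sigma**2), axis=1)     # Gaussian p(k|x)
    ranks = jnp.argsort(jnp.argsort(distances, axis=1), axis=1)    # 0 = closest prototype
    h = jax.nn.softmax(-ranks / gamma, axis=1)                     # NG weights h_k
    w = p * h / jnp.sum(p * h, axis=1, keepdims=True)              # combined weights w_k
    correct = target_labels[:, None] == prototype_labels[None, :]
    return jnp.mean(-jnp.log(jnp.sum(jnp.where(correct, w, 0.0), axis=1)))


distances = jnp.array([[0.0, 2.0]])
loss = ng_rslvq_loss_sketch(distances, jnp.array([0]), jnp.array([0, 1]))
# p and h are both proportional to [1, e^-1], so w_correct = 1 / (1 + e^-2)
# and the loss is log(1 + e^-2)
```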

prosemble.core.losses.cross_entropy_lvq_loss(distances, target_labels, prototype_labels, n_classes)

Cross-entropy LVQ loss (for CELVQ).

  1. Min distances per class via masking

  2. Negate to get logits (closer = higher)

  3. Cross-entropy against true labels

Parameters:
  • distances (array of shape (n, p))

  • target_labels (array of shape (n,))

  • prototype_labels (array of shape (p,))

  • n_classes (int)

Return type:

scalar
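The three listed steps as a sketch, using masking so shapes stay fixed under jit. It assumes classes are labelled 0 to n_classes - 1 and a mean reduction; celvq_loss_sketch is hypothetical:

```python
# Illustrative sketch; not prosemble's implementation.
import jax
import jax.numpy as jnp


def celvq_loss_sketch(distances, target_labels, prototype_labels, n_classes):
    classes = jnp.arange(n_classes)
    mask = prototype_labels[None, :] == classes[:, None]              # (C, p) prototypes per class
    # 1. Minimum distance per class via masking: (n, C)
    class_min = jnp.min(jnp.where(mask[None], distances[:, None, :], jnp.inf), axis=-1)
    # 2. Negate so that closer classes get higher logits
    logits = -class_min
    # 3. Cross-entropy against the true labels
    log_probs = jax.nn.log_softmax(logits, axis=1)
    return -jnp.mean(jnp.take_along_axis(log_probs, target_labels[:, None], axis=1))


distances = jnp.array([[1.0, 3.0]])
loss = celvq_loss_sketch(distances, jnp.array([0]), jnp.array([0, 1]), n_classes=2)
# logits are [-1, -3], so the loss is log(1 + e^-2)
```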

prosemble.core.losses.margin_loss(y_pred, y_true_one_hot, margin=0.3)

Margin loss for CBC.

\(\text{loss} = \text{ReLU}(\max(\text{wrong}) - \text{correct} + \text{margin})\)

Parameters:
  • y_pred (array of shape (n, n_classes)) – Predicted class probabilities.

  • y_true_one_hot (array of shape (n, n_classes)) – One-hot encoded true labels.

  • margin (float)

Return type:

scalar
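A sketch of the formula for one-hot targets, assuming a mean over samples (margin_loss_sketch is hypothetical):

```python
# Illustrative sketch; not prosemble's implementation.
import jax
import jax.numpy as jnp


def margin_loss_sketch(y_pred, y_true_one_hot, margin=0.3):
    correct = jnp.sum(y_pred * y_true_one_hot, axis=1)                          # true-class probability
    wrong = jnp.max(jnp.where(y_true_one_hot > 0, -jnp.inf, y_pred), axis=1)   # best wrong class
    return jnp.mean(jax.nn.relu(wrong - correct + margin))


y_pred = jnp.array([[0.7, 0.2, 0.1], [0.4, 0.5, 0.1]])
y_true = jax.nn.one_hot(jnp.array([0, 0]), 3)
loss = margin_loss_sketch(y_pred, y_true)
# sample 0: relu(0.2 - 0.7 + 0.3) = 0; sample 1: relu(0.5 - 0.4 + 0.3) = 0.4; mean = 0.2
```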

prosemble.core.losses.neural_gas_energy(distances, lam)

Neural Gas energy function.

\[E = \sum_k h(\text{rank}_k, \lambda) \cdot d(x, w_k), \quad h(\text{rank}, \lambda) = \exp(-\text{rank} / \lambda)\]

Parameters:
  • distances (array of shape (n, p))

  • lam (float) – Neighborhood range parameter.

Return type:

scalar
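A sketch of the energy for a batch of samples. Summing over samples is an assumption (the library may average or rescale instead), and neural_gas_energy_sketch is hypothetical:

```python
# Illustrative sketch; not prosemble's implementation.
import jax.numpy as jnp


def neural_gas_energy_sketch(distances, lam):
    ranks = jnp.argsort(jnp.argsort(distances, axis=1), axis=1)   # rank_k per sample, 0 = closest
    h = jnp.exp(-ranks / lam)                                      # neighborhood weights
    return jnp.sum(h * distances)                                  # summed over prototypes and samples


distances = jnp.array([[1.0, 4.0]])
E = neural_gas_energy_sketch(distances, lam=1.0)   # 1 * 1 + e^-1 * 4
```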

Activations

Transfer/activation functions for prototype-based learning.

These functions are used to shape the GLVQ loss (mu values) before summation, controlling the optimization landscape.

prosemble.core.activations.identity(x, beta=0.0)

Identity activation (passthrough).

Parameters:
  • x (array) – Input values.

  • beta (float) – Ignored. Present for API consistency.

Returns:

Same as input.

Return type:

array

prosemble.core.activations.sigmoid_beta(x, beta=10.0)

Sigmoid activation with steepness parameter.

f(x) = 1 / (1 + exp(-beta * x))

Parameters:
  • x (array) – Input values.

  • beta (float) – Steepness parameter. Higher values give sharper transition.

Returns:

Sigmoid-transformed values in (0, 1).

Return type:

array

prosemble.core.activations.swish_beta(x, beta=10.0)

Swish activation with steepness parameter.

f(x) = x * sigmoid(beta * x)

Parameters:
  • x (array) – Input values.

  • beta (float) – Steepness parameter.

Returns:

Swish-transformed values.

Return type:

array

prosemble.core.activations.get_activation(name)

Get activation function by name.

Parameters:

name (str or callable) – Name of activation (‘identity’, ‘sigmoid_beta’, ‘swish_beta’) or a callable.

Returns:

The activation function.

Return type:

callable
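A compact sketch of the three transfer functions and a name-based lookup with the same interface as get_activation. The dictionary and the ValueError for unknown names are assumptions of this sketch, and all names are hypothetical:

```python
# Illustrative sketch; not prosemble's implementation.
import jax
import jax.numpy as jnp


def identity_sketch(x, beta=0.0):
    return x                                 # beta accepted only for a uniform signature


def sigmoid_beta_sketch(x, beta=10.0):
    return jax.nn.sigmoid(beta * x)          # 1 / (1 + exp(-beta * x))


def swish_beta_sketch(x, beta=10.0):
    return x * jax.nn.sigmoid(beta * x)


_ACTIVATIONS = {
    "identity": identity_sketch,
    "sigmoid_beta": sigmoid_beta_sketch,
    "swish_beta": swish_beta_sketch,
}


def get_activation_sketch(name):
    if callable(name):
        return name                          # pass user-supplied functions through
    if name not in _ACTIVATIONS:
        raise ValueError(f"Unknown activation: {name!r}")
    return _ACTIVATIONS[name]


mu = jnp.array([-0.5, 0.0, 0.5])
shaped = get_activation_sketch("sigmoid_beta")(mu, beta=10.0)   # near-step around mu = 0
```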

Competitions

Competition mechanisms for prototype-based classification.

These functions determine class predictions from distance matrices and prototype labels using different strategies.

prosemble.core.competitions.wtac(distances, prototype_labels)

Winner-Takes-All Competition.

Assigns each sample the label of the closest prototype.

Parameters:
  • distances (array of shape (n_samples, n_prototypes)) – Distance matrix.

  • prototype_labels (array of shape (n_prototypes,)) – Class label for each prototype.

Returns:

Predicted labels.

Return type:

array of shape (n_samples,)

prosemble.core.competitions.knnc(distances, prototype_labels, k=1, n_classes=None)

K-Nearest Neighbors Competition.

Assigns each sample the majority label among k closest prototypes.

Parameters:
  • distances (array of shape (n_samples, n_prototypes)) – Distance matrix.

  • prototype_labels (array of shape (n_prototypes,)) – Class label for each prototype.

  • k (int) – Number of neighbors.

  • n_classes (int or None) – Number of classes. If None, inferred from prototype_labels.

Returns:

Predicted labels.

Return type:

array of shape (n_samples,)
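Sketches of the two label-based competitions (hypothetical names). The k-NN version assumes labels 0 to n_classes - 1 and breaks ties toward the lower class index through argmax; the library's tie-breaking is not documented here:

```python
# Illustrative sketch; not prosemble's implementation.
import jax
import jax.numpy as jnp


def wtac_sketch(distances, prototype_labels):
    return prototype_labels[jnp.argmin(distances, axis=1)]   # label of the closest prototype


def knnc_sketch(distances, prototype_labels, k=1, n_classes=None):
    if n_classes is None:
        n_classes = int(jnp.max(prototype_labels)) + 1        # infer from the labels
    nearest = jnp.argsort(distances, axis=1)[:, :k]           # (n, k) closest prototype indices
    votes = jax.nn.one_hot(prototype_labels[nearest], n_classes).sum(axis=1)   # (n, C) vote counts
    return jnp.argmax(votes, axis=1)


distances = jnp.array([[0.1, 0.2, 0.3], [0.5, 0.1, 0.2]])
prototype_labels = jnp.array([0, 1, 1])
```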

prosemble.core.competitions.cbcc(detections, reasonings)

Classification-By-Components Competition.

Computes class probability distributions using component detections and reasoning matrices.

Parameters:
  • detections (array of shape (n_samples, n_components)) – Similarity/detection scores for each component.

  • reasonings (array of shape (n_components, n_classes, 2)) – Reasoning matrices. Last dim: [positive, negative_raw].

Returns:

Class probability distributions.

Return type:

array of shape (n_samples, n_classes)

Initializers

Prototype and parameter initializers for prototype-based learning.

These functions generate initial prototypes, labels, and transformation matrices for supervised and unsupervised models.

prosemble.core.initializers.stratified_selection_init(X, y, n_per_class, key)

Initialize prototypes by randomly selecting samples per class.

Parameters:
  • X (array of shape (n_samples, n_features)) – Training data.

  • y (array of shape (n_samples,)) – Training labels.

  • n_per_class (int, list, or dict) – Number of prototypes per class:

      • int: same count for all classes.

      • list: index i gives the count for class i, e.g. [2, 2, 1].

      • dict: maps class label to count, e.g. {0: 2, 1: 2, 2: 1}.

  • key (jax.random.PRNGKey) – Random key.

Returns:

  • prototypes (array of shape (n_prototypes, n_features))

  • prototype_labels (array of shape (n_prototypes,))
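A sketch of stratified selection that accepts the three documented forms of n_per_class. It runs eagerly rather than under jit because the per-class counts determine the output shape; the function name is hypothetical and its sampling details may differ from the library's:

```python
# Illustrative sketch; not prosemble's implementation.
import jax
import jax.numpy as jnp
import numpy as np


def stratified_selection_sketch(X, y, n_per_class, key):
    classes = np.unique(np.asarray(y))
    if isinstance(n_per_class, int):
        counts = {int(c): n_per_class for c in classes}
    elif isinstance(n_per_class, dict):
        counts = {int(c): n for c, n in n_per_class.items()}
    else:                                      # list: index i gives the count for class i
        counts = dict(enumerate(n_per_class))
    protos, labels = [], []
    for c in classes:
        key, sub = jax.random.split(key)
        idx = jnp.asarray(np.flatnonzero(np.asarray(y) == c))              # rows of class c
        n_c = counts[int(c)]
        chosen = jax.random.choice(sub, idx, shape=(n_c,), replace=False)  # no repeats within a class
        protos.append(X[chosen])
        labels.append(jnp.full((n_c,), int(c)))
    return jnp.concatenate(protos), jnp.concatenate(labels)


X = jnp.arange(12.0).reshape(6, 2)
y = jnp.array([0, 0, 0, 1, 1, 1])
prototypes, prototype_labels = stratified_selection_sketch(X, y, [2, 1], jax.random.PRNGKey(0))
```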

prosemble.core.initializers.stratified_mean_init(X, y)

Initialize prototypes at the mean of each class.

Parameters:
  • X (array of shape (n_samples, n_features)) – Training data.

  • y (array of shape (n_samples,)) – Training labels.

Returns:

  • prototypes (array of shape (n_classes, n_features))

  • prototype_labels (array of shape (n_classes,))

prosemble.core.initializers.random_normal_init(n_prototypes, n_features, key, mean=0.0, std=1.0)

Initialize prototypes from a normal distribution.

Parameters:
  • n_prototypes (int)

  • n_features (int)

  • key (jax.random.PRNGKey)

  • mean (float)

  • std (float)

Return type:

array of shape (n_prototypes, n_features)

prosemble.core.initializers.identity_omega_init(n_features, n_dims=None)

Initialize omega as an identity (or truncated identity) matrix.

Parameters:
  • n_features (int) – Input dimensionality.

  • n_dims (int, optional) – Projection dimensionality. Defaults to n_features (square).

Return type:

array of shape (n_features, n_dims)

prosemble.core.initializers.random_omega_init(n_features, n_dims, key)

Initialize omega as a random matrix with orthonormal columns, obtained via QR decomposition.

Parameters:
  • n_features (int) – Input dimensionality.

  • n_dims (int) – Projection dimensionality.

  • key (jax.random.PRNGKey)

Return type:

array of shape (n_features, n_dims)
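Sketches of both omega initializers (hypothetical names). The QR route yields orthonormal columns only when n_dims <= n_features, which this sketch assumes:

```python
# Illustrative sketch; not prosemble's implementation.
import jax
import jax.numpy as jnp


def identity_omega_sketch(n_features, n_dims=None):
    n_dims = n_features if n_dims is None else n_dims
    return jnp.eye(n_features, n_dims)        # truncated identity when n_dims < n_features


def random_omega_sketch(n_features, n_dims, key):
    A = jax.random.normal(key, (n_features, n_dims))
    Q, _ = jnp.linalg.qr(A)                   # reduced QR: Q has shape (n_features, n_dims)
    return Q                                  # orthonormal columns when n_dims <= n_features


omega_id = identity_omega_sketch(4, 2)
omega_rand = random_omega_sketch(4, 2, jax.random.PRNGKey(0))
```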

prosemble.core.initializers.uniform_init(n_prototypes, n_features, key, low=0.0, high=1.0)

Initialize prototypes from a uniform distribution.

Parameters:
  • n_prototypes (int)

  • n_features (int)

  • key (jax.random.PRNGKey)

  • low (float)

  • high (float)

Return type:

array of shape (n_prototypes, n_features)

prosemble.core.initializers.zeros_init(n_prototypes, n_features)

Initialize prototypes as zeros.

Useful for checkpoint loading where shapes must be pre-allocated before restoring saved values.

Parameters:
  • n_prototypes (int)

  • n_features (int)

Return type:

array of shape (n_prototypes, n_features)

prosemble.core.initializers.ones_init(n_prototypes, n_features)

Initialize prototypes as ones.

Parameters:
  • n_prototypes (int)

  • n_features (int)

Return type:

array of shape (n_prototypes, n_features)

prosemble.core.initializers.fill_value_init(n_prototypes, n_features, value=0.0)[source]

Initialize prototypes filled with a constant value.

Parameters:
  • n_prototypes (int)

  • n_features (int)

  • value (float) – Fill value.

Return type:

array of shape (n_prototypes, n_features)

prosemble.core.initializers.selection_init(X, n_prototypes, key)[source]

Initialize prototypes by uniformly sampling from data (classless).

Suitable for unsupervised models like Neural Gas, SOM, and fuzzy clustering where no labels are available.

Parameters:
  • X (array of shape (n_samples, n_features)) – Training data.

  • n_prototypes (int) – Number of prototypes to select.

  • key (jax.random.PRNGKey)

Return type:

array of shape (n_prototypes, n_features)
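A minimal sketch of classless selection follows. It assumes sampling without replacement, because the docstring only says "uniformly". select_prototypes is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def select_prototypes(X, n_prototypes, key):
    # Draw distinct row indices uniformly at random and gather those rows.
    idx = jax.random.choice(key, X.shape[0], shape=(n_prototypes,), replace=False)
    return X[idx]


X = jnp.arange(20.0).reshape(10, 2)
protos = select_prototypes(X, 3, jax.random.PRNGKey(1))
```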

prosemble.core.initializers.mean_init(X, n_prototypes)[source]

Initialize all prototypes at the data mean (classless).

Suitable for unsupervised models. All prototypes start at the global mean and diverge during training.

Parameters:
  • X (array of shape (n_samples, n_features)) – Training data.

  • n_prototypes (int) – Number of prototypes.

Return type:

array of shape (n_prototypes, n_features)

prosemble.core.initializers.literal_init(values)[source]

Initialize prototypes from literal values.

Used for warm-starting from another model’s prototypes or from user-provided values.

Parameters:

values (array-like of shape (n_prototypes, n_features)) – Literal prototype values.

Return type:

array of shape (n_prototypes, n_features)

prosemble.core.initializers.stratified_noise_init(X, y, n_per_class, key, noise_std=0.1)[source]

Initialize prototypes by selecting samples per class and adding noise.

Combines stratified selection with Gaussian noise injection for diverse initial prototypes.

Parameters:
  • X (array of shape (n_samples, n_features))

  • y (array of shape (n_samples,))

  • n_per_class (int, list, or dict) – Number of prototypes per class (same formats as stratified_selection_init).

  • key (jax.random.PRNGKey)

  • noise_std (float) – Standard deviation of Gaussian noise to add.

Returns:

  • prototypes (array of shape (n_prototypes, n_features))

  • prototype_labels (array of shape (n_prototypes,))
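For an integer n_per_class, the select-then-perturb idea can be sketched eagerly as shown here. This is a hypothetical helper, not prosemble's implementation. It assumes every class has at least n_per_class samples and does not handle the list or dict formats.

```python
import jax
import jax.numpy as jnp


def stratified_noise(X, y, n_per_class, key, noise_std=0.1):
    # Eager (non-jitted) sketch for an integer n_per_class: sample rows of
    # each class without replacement, then add Gaussian noise.
    protos, labels = [], []
    for c in jnp.unique(y):
        key, k_sel, k_noise = jax.random.split(key, 3)
        members = X[y == c]  # boolean masking is fine outside jit
        idx = jax.random.choice(k_sel, members.shape[0], shape=(n_per_class,), replace=False)
        noise = noise_std * jax.random.normal(k_noise, (n_per_class, X.shape[1]))
        protos.append(members[idx] + noise)
        labels.append(jnp.full((n_per_class,), c))
    return jnp.concatenate(protos), jnp.concatenate(labels)


X = jnp.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1], [5.0, 5.0], [5.1, 5.0], [5.0, 5.1]])
y = jnp.array([0, 0, 0, 1, 1, 1])
protos, proto_labels = stratified_noise(X, y, 2, jax.random.PRNGKey(0), noise_std=0.01)
```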

prosemble.core.initializers.pca_omega_init(X, n_dims)[source]

Initialize omega using PCA directions from training data.

The top-n_dims principal components become the columns of omega, providing a data-driven initialization for metric learning models.

Parameters:
  • X (array of shape (n_samples, n_features)) – Training data.

  • n_dims (int) – Number of principal components (projection dimensionality).

Returns:

omega

Return type:

array of shape (n_features, n_dims)
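A PCA-based omega can be obtained from the SVD of the centered data. The following minimal sketch assumes the columns of omega are the leading right singular vectors. pca_omega is a hypothetical name.

```python
import jax.numpy as jnp


def pca_omega(X, n_dims):
    # Center the data; the rows of Vt are the principal directions, sorted by
    # decreasing singular value, so the first n_dims rows are kept.
    Xc = X - X.mean(axis=0)
    _, _, Vt = jnp.linalg.svd(Xc, full_matrices=False)
    return Vt[:n_dims].T  # shape (n_features, n_dims)


X = jnp.array([[-2.0, 0.1, 0.0], [-1.0, -0.1, 0.2], [1.0, 0.1, -0.2], [2.0, -0.1, 0.0]])
omega = pca_omega(X, 2)
```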

prosemble.core.initializers.class_conditional_mean_init(X, y, n_per_class)[source]

Initialize prototypes at class means, replicated per n_per_class.

When n_per_class > 1, each class gets multiple prototypes all initialized at the class mean (they will diverge during training).

Parameters:
  • X (array of shape (n_samples, n_features))

  • y (array of shape (n_samples,))

  • n_per_class (int, list, or dict) – Number of prototypes per class.

Returns:

  • prototypes (array of shape (n_prototypes, n_features))

  • prototype_labels (array of shape (n_prototypes,))
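Class-mean initialization with replication reduces to a one-hot average followed by jnp.repeat. The sketch assumes n_per_class is a single integer and every class has at least one sample. class_means_replicated is a hypothetical name.

```python
import jax.numpy as jnp


def class_means_replicated(X, y, n_per_class, n_classes):
    # One-hot class membership, shape (n_samples, n_classes).
    onehot = (y[:, None] == jnp.arange(n_classes)[None, :]).astype(X.dtype)
    means = (onehot.T @ X) / onehot.sum(axis=0)[:, None]  # (n_classes, n_features)
    # Repeat every class mean n_per_class times (integer n_per_class only).
    prototypes = jnp.repeat(means, n_per_class, axis=0)
    labels = jnp.repeat(jnp.arange(n_classes), n_per_class)
    return prototypes, labels


X = jnp.array([[0.0, 0.0], [2.0, 2.0], [10.0, 0.0], [12.0, 0.0]])
y = jnp.array([0, 0, 1, 1])
protos, proto_labels = class_means_replicated(X, y, 2, 2)
```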

prosemble.core.initializers.random_reasonings_init(n_components, n_classes, key)[source]

Initialize CBC reasoning matrices randomly.

Parameters:
  • n_components (int)

  • n_classes (int)

  • key (jax.random.PRNGKey)

Returns:

reasonings

Return type:

array of shape (n_components, n_classes, 2)

prosemble.core.initializers.pure_positive_reasonings_init(n_components, n_classes, key=None)[source]

Initialize CBC reasoning matrices with pure positive evidence.

Each component maps to exactly one class with high positive evidence and low negative evidence.

Parameters:
  • n_components (int)

  • n_classes (int)

  • key (optional) – Ignored; included for API compatibility.

Returns:

reasonings

Return type:

array of shape (n_components, n_classes, 2)

Pooling

Stratified pooling operations for prototype-based learning.

These functions aggregate per-prototype distances into per-class distances, grouping by prototype label. Essential for GLVQ and CELVQ.

prosemble.core.pooling.stratified_min_pooling(distances, prototype_labels, n_classes)[source]

Per-class minimum distance pooling.

For each sample and each class, returns the minimum distance to any prototype of that class.

Parameters:
  • distances (array of shape (n_samples, n_prototypes)) – Distance matrix.

  • prototype_labels (array of shape (n_prototypes,)) – Class label for each prototype.

  • n_classes (int) – Number of classes.

Returns:

Minimum distance to each class for each sample.

Return type:

array of shape (n_samples, n_classes)

prosemble.core.pooling.stratified_sum_pooling(distances, prototype_labels, n_classes)[source]

Per-class sum distance pooling.

Parameters:
  • distances (array of shape (n_samples, n_prototypes))

  • prototype_labels (array of shape (n_prototypes,))

  • n_classes (int)

Returns:

Sum of distances to each class for each sample.

Return type:

array of shape (n_samples, n_classes)

prosemble.core.pooling.stratified_max_pooling(distances, prototype_labels, n_classes)[source]

Per-class maximum distance pooling.

Parameters:
  • distances (array of shape (n_samples, n_prototypes))

  • prototype_labels (array of shape (n_prototypes,))

  • n_classes (int)

Returns:

Maximum distance to each class for each sample.

Return type:

array of shape (n_samples, n_classes)
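The min, sum, and max poolings can all be written with a prototype-to-class mask, without a Python loop over classes. This is a sketch, not the library's code. stratified_pool is a hypothetical name.

```python
import jax.numpy as jnp


def stratified_pool(distances, prototype_labels, n_classes):
    # mask[j, c] is True when prototype j carries label c.
    mask = prototype_labels[:, None] == jnp.arange(n_classes)[None, :]
    d = distances[:, :, None]   # (n_samples, n_prototypes, 1)
    m = mask[None, :, :]        # (1, n_prototypes, n_classes)
    d_min = jnp.where(m, d, jnp.inf).min(axis=1)    # (n_samples, n_classes)
    d_max = jnp.where(m, d, -jnp.inf).max(axis=1)
    d_sum = distances @ mask.astype(distances.dtype)
    return d_min, d_sum, d_max


D = jnp.array([[1.0, 4.0, 2.0], [3.0, 0.5, 5.0]])
labels = jnp.array([0, 1, 0])
d_min, d_sum, d_max = stratified_pool(D, labels, 2)
# d_min -> [[1, 4], [3, 0.5]]; d_sum -> [[3, 4], [8, 0.5]]; d_max -> [[2, 4], [5, 0.5]]
```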

prosemble.core.pooling.stratified_prod_pooling(distances, prototype_labels, n_classes)[source]

Per-class product distance pooling.

Uses log-sum-exp for numerical stability.

Parameters:
  • distances (array of shape (n_samples, n_prototypes))

  • prototype_labels (array of shape (n_prototypes,))

  • n_classes (int)

Returns:

Product of distances to each class for each sample.

Return type:

array of shape (n_samples, n_classes)
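One way to compute the per-class product stably is to sum log-distances and exponentiate once. This sketch assumes strictly positive distances and is not necessarily how prosemble implements it. stratified_prod_logspace is a hypothetical name.

```python
import jax.numpy as jnp


def stratified_prod_logspace(distances, prototype_labels, n_classes):
    # prod_j d_j = exp(sum_j log d_j): summing logs avoids overflow/underflow
    # of long products. Assumes strictly positive distances.
    mask = prototype_labels[:, None] == jnp.arange(n_classes)[None, :]
    log_d = jnp.log(distances)[:, :, None]   # (n_samples, n_prototypes, 1)
    log_prod = jnp.where(mask[None], log_d, 0.0).sum(axis=1)
    return jnp.exp(log_prod)                  # (n_samples, n_classes)


D = jnp.array([[1.0, 4.0, 2.0], [3.0, 0.5, 5.0]])
P = stratified_prod_logspace(D, jnp.array([0, 1, 0]), 2)
# P -> [[2, 4], [15, 0.5]]
```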

Quantization

Shared mixins for model base classes.

class prosemble.core.quantization.MetadataCollectorMixin[source]

Mixin that auto-collects _hyperparams and _fitted_array_names from MRO.

Any base class that declares per-class _hyperparams and _fitted_array_names tuples can inherit this mixin to get _all_hyperparams and _all_fitted_array_names aggregated automatically across the entire class hierarchy.

get_params(deep=True)[source]

Get parameters for this estimator.

Follows the sklearn estimator protocol by inspecting __init__ signatures across the MRO.

Parameters:

deep (bool, default=True) – Ignored (present for sklearn compatibility).

Returns:

Parameter names mapped to their values.

Return type:

dict

set_params(**params)[source]

Set parameters on this estimator.

Parameters:

**params – Estimator parameters to set.

Return type:

self
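The MRO-based aggregation that MetadataCollectorMixin describes can be illustrated in plain Python. CollectMeta and all_hyperparams are hypothetical stand-ins, not the mixin's actual attributes. The de-duplication and base-first ordering are also assumptions.

```python
class CollectMeta:
    @classmethod
    def all_hyperparams(cls):
        # Walk the MRO from the most basic class to the most derived one and
        # concatenate each class's own _hyperparams, skipping duplicates.
        seen = []
        for klass in reversed(cls.__mro__):
            for name in vars(klass).get("_hyperparams", ()):
                if name not in seen:
                    seen.append(name)
        return tuple(seen)


class Base(CollectMeta):
    _hyperparams = ("max_iter", "lr")


class Child(Base):
    _hyperparams = ("n_prototypes_per_class",)


print(Child.all_hyperparams())  # ('max_iter', 'lr', 'n_prototypes_per_class')
```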

class prosemble.core.quantization.QuantizationMixin[source]

Mixin for quantizing/dequantizing fitted model parameters.

Supports float16, bfloat16, and int8 (with per-tensor scale factors).

Subclasses override _get_quantizable_attrs to declare which fitted attributes are eligible for quantization.

quantize(dtype='float16')[source]

Quantize model parameters to lower precision.

Post-training quantization for smaller model size and faster inference.

Parameters:

dtype (str) – Target precision: ‘float16’, ‘bfloat16’, or ‘int8’.

Return type:

self

dequantize()[source]

Restore model parameters to float32.

Return type:

self

property is_quantized: bool

Whether model parameters are currently quantized.

property quantized_dtype: str | None

Current quantization dtype, or None if not quantized.
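For int8, a common per-tensor scheme stores one float scale per array. The sketch assumes a symmetric scheme that maps max|x| to 127. The helper names are hypothetical, and prosemble's exact rounding and clipping may differ.

```python
import jax.numpy as jnp


def quantize_int8(x):
    # Symmetric per-tensor scheme: a single scale maps max|x| to 127.
    scale = jnp.maximum(jnp.max(jnp.abs(x)), 1e-12) / 127.0
    q = jnp.clip(jnp.round(x / scale), -127, 127).astype(jnp.int8)
    return q, scale


def dequantize_int8(q, scale):
    # Recover an approximate float32 tensor from the int8 codes.
    return q.astype(jnp.float32) * scale


w = jnp.array([[0.5, -1.0], [0.25, 0.0]])
q, scale = quantize_int8(w)
w_hat = dequantize_int8(q, scale)
w_half = w.astype(jnp.float16)  # float16 path: a plain dtype cast
```

The reconstruction error of each entry is at most half a quantization step (scale / 2).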

Utilities

JAX utility functions for data preprocessing and manipulation.

Replaces common sklearn/numpy operations with JAX-native implementations.

prosemble.core.utils.train_test_split_jax(X, y=None, test_size=0.2, random_seed=42)[source]

JAX-native train/test split.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Feature matrix

  • y (array-like of shape (n_samples,), optional) – Labels

  • test_size (float, default=0.2) – Proportion of dataset for test set (0.0 to 1.0)

  • random_seed (int, default=42) – Random seed for reproducibility

Returns:

Split arrays. If y is provided, returns 4 arrays, else 2.

Return type:

X_train, X_test[, y_train, y_test]
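The split amounts to one random permutation cut into two parts. The sketch uses a hypothetical split_indices helper. The docstring does not specify how prosemble rounds n_samples * test_size, so the round() call is an assumption.

```python
import jax
import jax.numpy as jnp


def split_indices(n_samples, test_size=0.2, random_seed=42):
    # Shuffle all indices once, then cut the permutation into test and train parts.
    key = jax.random.PRNGKey(random_seed)
    perm = jax.random.permutation(key, n_samples)
    n_test = int(round(n_samples * test_size))  # rounding rule is an assumption
    return perm[n_test:], perm[:n_test]         # train_idx, test_idx


X = jnp.arange(20.0).reshape(10, 2)
train_idx, test_idx = split_indices(X.shape[0], test_size=0.2)
X_train, X_test = X[train_idx], X[test_idx]
```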

Examples

>>> X = jnp.array([[1, 2], [3, 4], [5, 6], [7, 8]])
>>> y = jnp.array([0, 1, 0, 1])
>>> X_train, X_test, y_train, y_test = train_test_split_jax(X, y, test_size=0.5)

prosemble.core.utils.standardize(X, mean=None, std=None)[source]

Standardize features (zero mean, unit variance).

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Data to standardize

  • mean (array-like of shape (n_features,), optional) – Pre-computed mean (for test data)

  • std (array-like of shape (n_features,), optional) – Pre-computed std (for test data)

Returns:

  • X_scaled (array-like) – Standardized data

  • mean (array-like) – Mean used for scaling

  • std (array-like) – Std used for scaling

Return type:

tuple[Array | ndarray | bool | number, Array | ndarray | bool | number, Array | ndarray | bool | number]

Examples

>>> X_train = jnp.array([[1, 2], [3, 4], [5, 6]])
>>> X_scaled, mean, std = standardize(X_train)
>>> # For test data
>>> X_test_scaled, _, _ = standardize(X_test, mean=mean, std=std)

prosemble.core.utils.min_max_scale(X, min_val=None, max_val=None)[source]

Scale features to [0, 1] range.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Data to scale

  • min_val (array-like of shape (n_features,), optional) – Pre-computed min (for test data)

  • max_val (array-like of shape (n_features,), optional) – Pre-computed max (for test data)

Returns:

  • X_scaled (array-like) – Scaled data

  • min_val (array-like) – Min values used

  • max_val (array-like) – Max values used

Return type:

tuple[Array | ndarray | bool | number, Array | ndarray | bool | number, Array | ndarray | bool | number]
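Both scalers follow the same pattern: compute statistics on training data, then reuse them for test data. The helpers below are hypothetical sketches. They assume population std (ddof=0) and do not guard against zero variance or a zero range.

```python
import jax.numpy as jnp


def standardize_sketch(X, mean=None, std=None):
    # Statistics are computed only when not supplied (i.e. on training data).
    # Population std (ddof=0) and no zero-variance guard: both are assumptions.
    mean = X.mean(axis=0) if mean is None else mean
    std = X.std(axis=0) if std is None else std
    return (X - mean) / std, mean, std


def min_max_sketch(X, min_val=None, max_val=None):
    min_val = X.min(axis=0) if min_val is None else min_val
    max_val = X.max(axis=0) if max_val is None else max_val
    return (X - min_val) / (max_val - min_val), min_val, max_val


X_train = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
X_test = jnp.array([[7.0, 8.0]])
X_scaled, mean, std = standardize_sketch(X_train)
X_test_scaled, _, _ = standardize_sketch(X_test, mean=mean, std=std)
X_mm, lo, hi = min_max_sketch(X_train)
X_test_mm, _, _ = min_max_sketch(X_test, lo, hi)  # -> [[1.5, 1.5]], outside [0, 1]
```

Test data scaled with training statistics can fall outside [0, 1], as the last line shows.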

prosemble.core.utils.pca_jax(X, n_components=2)[source]

Principal Component Analysis using JAX.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Data matrix

  • n_components (int, default=2) – Number of principal components

Returns:

  • X_transformed (array-like of shape (n_samples, n_components)) – Transformed data

  • components (array-like of shape (n_components, n_features)) – Principal components

Return type:

tuple[Array | ndarray | bool | number, Array | ndarray | bool | number]

Examples

>>> X = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> X_pca, components = pca_jax(X, n_components=2)

prosemble.core.utils.accuracy_score_jax(y_true, y_pred)[source]

Compute classification accuracy.

Parameters:
  • y_true (array-like of shape (n_samples,)) – True labels

  • y_pred (array-like of shape (n_samples,)) – Predicted labels

Returns:

accuracy – Accuracy score

Return type:

float

Examples

>>> y_true = jnp.array([0, 1, 1, 0])
>>> y_pred = jnp.array([0, 1, 0, 0])
>>> accuracy_score_jax(y_true, y_pred)
0.75

prosemble.core.utils.confusion_matrix_jax(y_true, y_pred, n_classes)[source]

Compute confusion matrix.

Parameters:
  • y_true (array-like of shape (n_samples,)) – True labels

  • y_pred (array-like of shape (n_samples,)) – Predicted labels

  • n_classes (int) – Number of classes

Returns:

conf_matrix – Confusion matrix

Return type:

array-like of shape (n_classes, n_classes)
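A confusion matrix can be built with a single scatter-add. The sketch assumes rows are true labels and columns are predictions. This is the usual sklearn convention, but this docstring does not confirm it. confusion_matrix_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def confusion_matrix_sketch(y_true, y_pred, n_classes):
    # Rows index the true class, columns the predicted class (assumed
    # convention). Scatter-add accumulates repeated (true, pred) pairs.
    cm = jnp.zeros((n_classes, n_classes), dtype=jnp.int32)
    return cm.at[y_true, y_pred].add(1)


y_true = jnp.array([0, 1, 2, 0, 1, 2])
y_pred = jnp.array([0, 2, 2, 0, 0, 1])
cm = confusion_matrix_sketch(y_true, y_pred, 3)
# cm -> [[2, 0, 0], [1, 0, 1], [0, 1, 1]]
accuracy = jnp.trace(cm) / cm.sum()  # 3 correct out of 6 -> 0.5
```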

Examples

>>> y_true = jnp.array([0, 1, 2, 0, 1, 2])
>>> y_pred = jnp.array([0, 2, 2, 0, 0, 1])
>>> confusion_matrix_jax(y_true, y_pred, n_classes=3)

prosemble.core.utils.shuffle_jax(X, y=None, random_seed=42)[source]

Shuffle arrays in unison.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Feature matrix

  • y (array-like of shape (n_samples,), optional) – Labels

  • random_seed (int, default=42) – Random seed

Returns:

Shuffled arrays

Return type:

X_shuffled[, y_shuffled]

Examples

>>> X = jnp.array([[1, 2], [3, 4], [5, 6]])
>>> y = jnp.array([0, 1, 0])
>>> X_shuf, y_shuf = shuffle_jax(X, y, random_seed=42)

prosemble.core.utils.k_fold_split_jax(n_samples, n_folds=5, random_seed=42)[source]

Generate K-fold cross-validation indices.

Parameters:
  • n_samples (int) – Number of samples

  • n_folds (int, default=5) – Number of folds

  • random_seed (int, default=42) – Random seed

Yields:

train_indices, test_indices – Indices for each fold
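K-fold indices can be generated by shuffling once and splitting the permutation. The sketch uses a hypothetical k_fold_indices generator. Fold sizes follow jnp.array_split, which may differ from prosemble's choice.

```python
import jax
import jax.numpy as jnp


def k_fold_indices(n_samples, n_folds=5, random_seed=42):
    # Shuffle once, split into n_folds nearly equal chunks, and use each chunk
    # as the test set exactly once; the remaining chunks form the training set.
    perm = jax.random.permutation(jax.random.PRNGKey(random_seed), n_samples)
    folds = jnp.array_split(perm, n_folds)
    for i, test_idx in enumerate(folds):
        train_idx = jnp.concatenate([f for j, f in enumerate(folds) if j != i])
        yield train_idx, test_idx


splits = list(k_fold_indices(10, n_folds=3))
```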

Examples

>>> for train_idx, test_idx in k_fold_split_jax(100, n_folds=5):
...     X_train, X_test = X[train_idx], X[test_idx]

prosemble.core.utils.orthogonalize(matrix)[source]

Orthogonalize a matrix via polar decomposition (SVD).

Given a matrix A of shape (d, s), computes the closest orthogonal matrix Q such that Q^T Q = I, using the polar factor:

\(A = U S V^T\) (thin SVD), \(Q = U V^T\)

Supports batched input via jax.vmap.

Parameters:

matrix (array of shape (d, s) or (n, d, s)) – Matrix or batch of matrices to orthogonalize. For batched input, use jax.vmap(orthogonalize).

Returns:

Q – Orthogonal matrix (columns are orthonormal).

Return type:

array of same shape as input
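The polar factor takes two lines with jnp.linalg.svd, and jax.vmap lifts it to batches. polar_orthogonalize is a hypothetical name for this sketch.

```python
import jax
import jax.numpy as jnp


def polar_orthogonalize(A):
    # Thin SVD A = U S V^T; the polar factor Q = U V^T is the matrix with
    # orthonormal columns closest to A in Frobenius norm.
    U, _, Vt = jnp.linalg.svd(A, full_matrices=False)
    return U @ Vt


A = jax.random.normal(jax.random.PRNGKey(0), (4, 2))
Q = polar_orthogonalize(A)
batch = jax.random.normal(jax.random.PRNGKey(1), (3, 4, 2))
Q_batch = jax.vmap(polar_orthogonalize)(batch)  # shape (3, 4, 2)
```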

prosemble.core.utils.class_priors(labels, n_classes=None)[source]

Compute class prior probabilities from labels.

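\(P(\text{class}=k) = n_k / n\), where \(n_k\) is the number of samples with label k and n is the total number of samples.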

Parameters:
  • labels (array of shape (n_samples,)) – Integer class labels.

  • n_classes (int, optional) – Number of classes. Inferred from labels if not provided.

Returns:

priors – Prior probability for each class, sums to 1.

Return type:

array of shape (n_classes,)

prosemble.core.utils.prototype_priors(prototype_labels, n_classes=None)[source]

Compute class priors from prototype label distribution.

Used in probabilistic LVQ models where the prior over prototypes is uniform (1/n_prototypes) and the class prior is the fraction of prototypes assigned to each class.

\(P(\text{class}=k) = |\{j : \text{label}_j = k\}| / n_{\text{prototypes}}\)

Parameters:
  • prototype_labels (array of shape (n_prototypes,)) – Class label for each prototype.

  • n_classes (int, optional) – Number of classes. Inferred if not provided.

Returns:

priors – Class prior probabilities, sums to 1.

Return type:

array of shape (n_classes,)
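Both prior functions reduce to normalized label counts. The sketch uses a hypothetical label_fractions helper. Inferring n_classes as max(label) + 1 is an assumption and requires concrete (non-traced) labels.

```python
import jax.numpy as jnp


def label_fractions(labels, n_classes=None):
    # Fraction of entries carrying each label; works for sample labels
    # (class priors) and prototype labels (prototype-based class priors).
    # Inferring n_classes as max + 1 is an assumption and needs concrete data.
    if n_classes is None:
        n_classes = int(labels.max()) + 1
    counts = jnp.bincount(labels, length=n_classes)
    return counts / labels.shape[0]


class_prior = label_fractions(jnp.array([0, 0, 1, 2]))            # -> [0.5, 0.25, 0.25]
proto_prior = label_fractions(jnp.array([0, 0, 1]), n_classes=3)  # -> [2/3, 1/3, 0]
uniform = jnp.full((4,), 1.0 / 4)  # uniform prior over 4 prototypes
```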

prosemble.core.utils.uniform_prototype_prior(n_prototypes)[source]

Uniform prior over prototypes: \(P(\text{prototype}_j) = 1/n_{\text{prototypes}}\).

This is the standard prior used in probabilistic LVQ (SLVQ/RSLVQ).

Parameters:

n_prototypes (int) – Number of prototypes.

Returns:

prior – Uniform probability vector, sums to 1.

Return type:

array of shape (n_prototypes,)

Pipeline

class prosemble.core.pipeline.NotFittedError[source]

Raised when calling transform/predict on an unfitted estimator.

class prosemble.core.pipeline.TransformerMixin[source]

Mixin providing fit_transform.

fit_transform(X, y=None)[source]

Fit and transform in one step.

Return type:

Array | ndarray | bool | number

class prosemble.core.pipeline.StandardScaler[source]

Standardize features to zero mean and unit variance.

Wraps prosemble.core.utils.standardize().

Examples

>>> scaler = StandardScaler()
>>> X_scaled = scaler.fit_transform(X_train)
>>> X_test_scaled = scaler.transform(X_test)

fit(X, y=None)[source]

Compute mean and std from training data.

Return type:

Self

transform(X)[source]

Standardize using stored mean and std.

Return type:

Array | ndarray | bool | number

get_params(deep=True)[source]

Get parameters for this estimator.

set_params(**params)[source]

Set parameters on this estimator.

fit_transform(X, y=None)

Fit and transform in one step.

Return type:

Array | ndarray | bool | number

class prosemble.core.pipeline.MinMaxScaler[source]

Scale features to [0, 1] range.

Wraps prosemble.core.utils.min_max_scale().

Examples

>>> scaler = MinMaxScaler()
>>> X_scaled = scaler.fit_transform(X_train)

fit(X, y=None)[source]

Compute min and max from training data.

Return type:

Self

transform(X)[source]

Scale using stored min and max.

Return type:

Array | ndarray | bool | number

get_params(deep=True)[source]

Get parameters for this estimator.

set_params(**params)[source]

Set parameters on this estimator.

fit_transform(X, y=None)

Fit and transform in one step.

Return type:

Array | ndarray | bool | number

class prosemble.core.pipeline.PCA(n_components=2)[source]

Principal Component Analysis.

Wraps prosemble.core.utils.pca_jax().

Parameters:

n_components (int, default=2) – Number of principal components to keep.

Examples

>>> pca = PCA(n_components=2)
>>> X_reduced = pca.fit_transform(X)

fit(X, y=None)[source]

Compute principal components from training data.

Return type:

Self

transform(X)[source]

Project data onto principal components.

Return type:

Array | ndarray | bool | number

get_params(deep=True)[source]

Get parameters for this estimator.

set_params(**params)[source]

Set parameters on this estimator.

fit_transform(X, y=None)

Fit and transform in one step.

Return type:

Array | ndarray | bool | number

class prosemble.core.pipeline.Pipeline(steps)[source]

Chain transformers with a final estimator.

All operations use pure JAX arrays, with no NumPy round-trips.

Parameters:

steps (list of (name, estimator) tuples) – Sequence of transforms with a final estimator. All but the last must implement fit() and transform(). The last step can be any estimator or transformer.

Examples

>>> pipe = Pipeline([
...     ('scaler', StandardScaler()),
...     ('pca', PCA(n_components=2)),
...     ('model', GLVQ(n_prototypes_per_class=1, max_iter=50)),
... ])
>>> pipe.fit(X_train, y_train)
>>> preds = pipe.predict(X_test)

fit(X, y=None, **fit_params)[source]

Fit all steps.

Calls fit_transform on intermediate steps, fit on the final step.

Parameters:
  • X (array of shape (n_samples, n_features))

  • y (array of shape (n_samples,), optional)

  • **fit_params – Forwarded to the final estimator's fit().

Return type:

Self

predict(X)[source]

Transform X through pipeline, then predict.

predict_proba(X)[source]

Transform X through pipeline, then predict_proba.

transform(X)[source]

Transform X through all steps (including last if it has transform).

fit_transform(X, y=None, **fit_params)[source]

Fit and transform.

get_params(deep=True)[source]

Get pipeline parameters.

If deep=True, includes nested estimator params as name__param.

set_params(**params)[source]

Set pipeline parameters.

Supports nested params: model__lr=0.05 sets lr on step ‘model’.
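The name__param convention can be illustrated by splitting each key on the first double underscore and forwarding the value to the named step. Step and route_nested_params are hypothetical stand-ins, not prosemble classes.

```python
class Step:
    # Hypothetical estimator exposing the get_params/set_params protocol.
    def __init__(self, lr=0.01):
        self.lr = lr

    def get_params(self, deep=True):
        return {"lr": self.lr}

    def set_params(self, **params):
        for name, value in params.items():
            setattr(self, name, value)
        return self


def route_nested_params(steps, **params):
    # Split "step__param" keys on the first "__" and forward each value
    # to the named step's set_params().
    named = dict(steps)
    for key, value in params.items():
        step_name, param = key.split("__", 1)
        named[step_name].set_params(**{param: value})


steps = [("model", Step())]
route_nested_params(steps, model__lr=0.05)
print(steps[0][1].lr)  # 0.05
```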

Model Selection

prosemble.core.model_selection.clone(estimator, **override_params)[source]

Create a fresh (unfitted) copy of an estimator.

Parameters:
  • estimator (object) – Estimator with get_params() protocol.

  • **override_params – Parameters to override in the clone.

Returns:

new_estimator – Fresh, unfitted instance.

Return type:

same type as estimator

Examples

>>> model = GLVQ(n_prototypes_per_class=2, lr=0.01)
>>> model2 = clone(model, lr=0.05)

prosemble.core.model_selection.cross_val_score(estimator, X, y=None, cv=5, scoring='accuracy', random_seed=42)[source]

Evaluate estimator with cross-validation.

Parameters:
  • estimator (object) – Estimator with fit/predict and get_params.

  • X (array of shape (n_samples, n_features))

  • y (array of shape (n_samples,), optional) – Required for supervised estimators.

  • cv (int, default=5) – Number of cross-validation folds.

  • scoring (str or callable, default='accuracy') – ‘accuracy’ or callable scorer(y_true, y_pred) -> float. For unsupervised estimators called without y, the callable is scorer(estimator, X_test) -> float.

  • random_seed (int, default=42)

Returns:

scores – Score for each fold.

Return type:

jnp.ndarray of shape (cv,)

Examples

>>> scores = cross_val_score(GLVQ(max_iter=30), X, y, cv=5)
>>> print(f"Mean: {scores.mean():.3f} +/- {scores.std():.3f}")

class prosemble.core.model_selection.GridSearchCV(estimator, param_grid, cv=5, scoring='accuracy', random_seed=42, refit=True, verbose=0)[source]

Exhaustive search over a parameter grid with cross-validation.

Parameters:
  • estimator (object) – Base estimator with fit/predict and get_params/set_params. Can be a Pipeline or any prosemble model.

  • param_grid (dict) – Maps parameter names to lists of values to try. For Pipeline steps, use step_name__param notation.

  • cv (int, default=5) – Number of cross-validation folds.

  • scoring (str or callable, default='accuracy') – ‘accuracy’ or callable scorer(y_true, y_pred) -> float.

  • random_seed (int, default=42)

  • refit (bool, default=True) – If True, refit the best model on the full dataset after search.

  • verbose (int, default=0) – 0=silent, 1=per-combo summary, 2=per-fold detail.

best_params_

Parameters of the best model.

Type:

dict

best_score_

Mean CV score of the best model.

Type:

float

best_estimator_

Fitted estimator with best params (only if refit=True).

Type:

object

cv_results_

Cross-validation results with keys ‘params’, ‘mean_score’, ‘std_score’, ‘fold_scores’, and ‘rank’.

Type:

dict

Examples

>>> gs = GridSearchCV(
...     GLVQ(max_iter=30),
...     {'n_prototypes_per_class': [1, 2], 'lr': [0.01, 0.05]},
...     cv=3,
... )
>>> gs.fit(X, y)
>>> print(gs.best_params_, gs.best_score_)

fit(X, y=None)[source]

Run grid search with cross-validation.

Parameters:
  • X (array of shape (n_samples, n_features))

  • y (array of shape (n_samples,), optional)

Return type:

self

predict(X)[source]

Predict using best estimator.

predict_proba(X)[source]

Predict probabilities using best estimator.
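The exhaustive search enumerates the Cartesian product of the value lists. The sketch below uses a hypothetical expand_grid helper.

```python
import itertools


def expand_grid(param_grid):
    # Cartesian product of the value lists; each combination becomes one dict.
    names = list(param_grid)
    return [dict(zip(names, values))
            for values in itertools.product(*(param_grid[n] for n in names))]


combos = expand_grid({"n_prototypes_per_class": [1, 2], "lr": [0.01, 0.05]})
print(len(combos))  # 4
```

The example grid with cv=3 gives 4 combinations and therefore 12 cross-validation fits, plus one refit on the full data when refit=True.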

Datasets

JAX-compatible dataset module.

Provides dataset loaders that return JAX arrays instead of NumPy arrays, optimized for use with JAX-based clustering models.

class prosemble.datasets.dataset.DATASET_JAX(input_data, labels)[source]

JAX-compatible dataset container.

Parameters:
input_data

Feature matrix as JAX array

Type:

jax.numpy.ndarray

labels

Labels as JAX array

Type:

jax.numpy.ndarray

input_data: Array
labels: Array
to_numpy()[source]

Convert JAX arrays back to NumPy arrays.

prosemble.datasets.dataset.iris_dataset_jax(dtype=None)[source]

Load Iris dataset as JAX arrays.

Parameters:

dtype (jax.numpy.dtype, default=jnp.float32) – Data type for features

Returns:

Dataset with JAX arrays (150 samples, 4 features, 3 classes)

Return type:

DATASET_JAX

Examples

>>> from prosemble.datasets import iris_dataset_jax
>>> dataset = iris_dataset_jax()
>>> dataset.input_data.shape
(150, 4)

prosemble.datasets.dataset.breast_cancer_dataset_jax(dtype=None)[source]

Load Wisconsin Breast Cancer dataset as JAX arrays.

Parameters:

dtype (jax.numpy.dtype, default=jnp.float32) – Data type for features

Returns:

Dataset with JAX arrays (569 samples, 30 features, 2 classes)

Return type:

DATASET_JAX

Examples

>>> from prosemble.datasets import breast_cancer_dataset_jax
>>> dataset = breast_cancer_dataset_jax()
>>> dataset.input_data.shape
(569, 30)

prosemble.datasets.dataset.moons_dataset_jax(n_samples=150, noise=None, random_state=None, dtype=None)[source]

Generate two interleaving half circles (moons) as JAX arrays.

Parameters:
  • n_samples (int, default=150) – Total number of points generated

  • noise (float, optional) – Standard deviation of Gaussian noise added to data

  • random_state (int, optional) – Random seed for reproducibility

  • dtype (jax.numpy.dtype, default=jnp.float32) – Data type for features

Returns:

Dataset with JAX arrays (n_samples, 2 features, 2 classes)

Return type:

DATASET_JAX

Examples

>>> from prosemble.datasets import moons_dataset_jax
>>> dataset = moons_dataset_jax(n_samples=200, noise=0.1, random_state=42)
>>> dataset.input_data.shape
(200, 2)

prosemble.datasets.dataset.blobs_dataset_jax(n_samples=None, centers=None, cluster_std=None, random_state=None, dtype=None)[source]

Generate isotropic Gaussian blobs as JAX arrays.

Parameters:
  • n_samples (list, default=[120, 80]) – Number of samples per cluster

  • centers (list, default=[[0.0, 0.0], [2.0, 2.0]]) – Centers of the clusters

  • cluster_std (list, default=[1.2, 0.5]) – Standard deviation of clusters

  • random_state (int, optional) – Random seed for reproducibility

  • dtype (jax.numpy.dtype, default=jnp.float32) – Data type for features

Returns:

Dataset with JAX arrays (sum(n_samples), 2 features)

Return type:

DATASET_JAX

Examples

>>> from prosemble.datasets import blobs_dataset_jax
>>> dataset = blobs_dataset_jax(random_state=42)
>>> dataset.input_data.shape
(200, 2)

class prosemble.datasets.dataset.DATA_JAX(random=4, sample_size=1000)[source]

JAX dataset collection with convenient property access.

Parameters:
  • random (int, default=4) – Random seed for dataset generation

  • sample_size (int, default=1000) – Default sample size (not currently used)

  • dtype (jax.numpy.dtype, default=jnp.float32) – Data type for features

Examples

>>> from prosemble.datasets import DATA_JAX
>>> data = DATA_JAX(random=42)
>>> moons = data.S_1  # Moons dataset
>>> blobs = data.S_2  # Blobs dataset
>>> cancer = data.breast_cancer  # Breast cancer dataset

random: int = 4
sample_size: int = 1000
dtype

alias of float32

property S_1: DATASET_JAX

Moons dataset.

property S_2: DATASET_JAX

Blobs dataset.

property iris: DATASET_JAX

Iris dataset.

property breast_cancer: DATASET_JAX

Breast cancer dataset.

property moons: DATASET_JAX

Alias for S_1 (moons dataset).

property blobs: DATASET_JAX

Alias for S_2 (blobs dataset).

prosemble.datasets.dataset.load_iris_jax(dtype=None)[source]

Quick loader for iris dataset.

Return type:

DATASET_JAX

prosemble.datasets.dataset.load_breast_cancer_jax(dtype=None)[source]

Quick loader for breast cancer dataset.

Return type:

DATASET_JAX

prosemble.datasets.dataset.load_moons_jax(n_samples=150, noise=None, random_state=None, dtype=None)[source]

Quick loader for moons dataset.

Parameters:
  • n_samples (int)

  • noise (float | None)

  • random_state (int | None)

Return type:

DATASET_JAX

prosemble.datasets.dataset.load_blobs_jax(random_state=None, dtype=None)[source]

Quick loader for blobs dataset.

Parameters:

random_state (int | None)

Return type:

DATASET_JAX

prosemble.datasets.dataset.DATA

alias of DATA_JAX

ONNX Export

ONNX export for prosemble prototype-based models.

Converts a fitted model’s predict function into an ONNX graph. Only supports models whose distance function can be expressed with standard ONNX operators. Unsupported models raise NotImplementedError with a clear message.

Supported distance functions:

  • squared_euclidean_distance_matrix

  • euclidean_distance_matrix

  • manhattan_distance_matrix

  • omega_distance_matrix (global projection matrix)

  • lomega_distance_matrix (per-prototype local matrices)

  • tangent_distance_matrix (per-prototype tangent subspace)

  • relevance_weighted (per-feature relevance weighting)

Supported decision patterns:

  • WTAC (supervised classification)

  • ArgMin (unsupervised clustering)

  • One-class hard nearest (OCGLVQ family)

  • One-class Gaussian soft (OCRSLVQ family)

  • One-class Gaussian+NG soft (OCRSLVQ_NG family)

  • SVQ-OCC response model (SVQOCC family)

  • CBC reasoning (CBC, ImageCBC)

  • PLVQ Gaussian mixture soft assignment

Supported encoder models:

  • MLP encoder (SiameseGLVQ, SiameseGMLVQ, SiameseGTLVQ, LVQMLN, PLVQ)

  • CNN encoder (ImageGLVQ, ImageGMLVQ, ImageGTLVQ, ImageCBC)

Not supported:

  • gaussian_kernel_matrix, polynomial_kernel_matrix

  • Riemannian manifold distances (logm, expm have no ONNX equivalent)

prosemble.core.onnx_export.export_onnx(model, batch_size=1, opset_version=17, path=None, reject_threshold=None)[source]

Export a fitted model’s predict function to ONNX format.

Builds an ONNX graph that reproduces the model’s predict() output. Supports supervised (WTAC), unsupervised (ArgMin), one-class (threshold-based), SVQ-OCC (response model), CBC (reasoning matrices), PLVQ (Gaussian mixture), encoder models (MLP/CNN backbone), Differentiating Kernel models (Gaussian, relevance-weighted, and exponential kernels), and Riemannian models on SO(n)/Grassmannian manifolds (88 of 102 models total).

Parameters:
  • model (SupervisedPrototypeModel or UnsupervisedPrototypeModel) – A fitted prosemble model.

  • batch_size (int) – Fixed batch dimension for the input. Use -1 for dynamic batch size (ONNX symbolic dimension).

  • opset_version (int) – ONNX opset version. Default: 17.

  • path (str, optional) – If provided, save the ONNX model to this file path.

  • reject_threshold (float, optional) – If provided, enables reject option for supervised models. Samples with confidence below this threshold are assigned prediction -1 (rejected). The exported model produces two outputs: predictions (INT64, with -1 for rejected) and confidence (FLOAT, in [-1, 1]). Only supported for supervised models (WTAC decision).

Returns:

The exported ONNX model.

Return type:

onnx.ModelProto

Raises:
  • NotImplementedError – If the model’s distance function or decision pattern is not supported.

  • ImportError – If the onnx package is not installed.

  • ValueError – If reject_threshold is used with a non-supervised model.

Manifolds

Riemannian manifold primitives for prototype learning.

Provides geodesic distance, exponential map, and logarithmic map for manifolds commonly arising in machine learning:

  • SO(n): Special orthogonal group (rotation matrices)

  • SPD(n): Symmetric positive definite matrices

  • Grassmannian Gr(n, k): k-dimensional subspaces of R^n

  • HyperbolicPoincare(d): Poincaré ball model of hyperbolic space

All operations are JIT-compilable and vmap-compatible.

References

Schwarz, Psenickova, Villmann, Röhrbein (2026). Topology-Preserving Prototype Learning on Riemannian Manifolds. ESANN 2026.

Ganea, O., Bécigneul, G., & Hofmann, T. (2018). Hyperbolic Neural Networks. NeurIPS 2018.

prosemble.core.manifolds.logm_safe(A)[source]

Matrix logarithm via funm, with numerical safety.

Uses JAX’s funm to compute logm. For complex eigenvalues (e.g. rotation matrices), operates in complex domain and returns the real part.

Parameters:

A (array of shape (..., n, n))

Returns:

logA

Return type:

array of shape (…, n, n)

prosemble.core.manifolds.sqrt_spd(A)[source]

Matrix square root for symmetric positive definite matrices.

Uses eigendecomposition: \(A^{1/2} = V \operatorname{diag}(\sqrt{\lambda}) V^T\).

Parameters:

A (array of shape (n, n), symmetric positive definite)

Returns:

A_sqrt

Return type:

array of shape (n, n)

prosemble.core.manifolds.inv_sqrt_spd(A)[source]

Inverse matrix square root for symmetric positive definite matrices.

Uses eigendecomposition: \(A^{-1/2} = V \operatorname{diag}(1/\sqrt{\lambda}) V^T\).

Parameters:

A (array of shape (n, n), symmetric positive definite)

Returns:

A_inv_sqrt

Return type:

array of shape (n, n)

class prosemble.core.manifolds.SO(n)[source]

Special orthogonal group SO(n): rotation matrices.

Points are \(n \times n\) orthogonal matrices with det = +1. Geodesic distance uses the bi-invariant metric.

Parameters:

n (int) – Dimension of the rotation group.

distance(R, S)[source]

Geodesic distance: d(R, S) = ||logm(R^T S)||_F / sqrt(2).

distance_squared(R, S)[source]

Squared geodesic distance.

log_map(R, S)[source]

Logarithmic map: Log_R(S) = R @ logm(R^T @ S).

Maps point S on the manifold to a tangent vector at R.

exp_map(R, V)[source]

Exponential map: Exp_R(V) = R @ expm(R^T @ V).

Maps tangent vector V at R back to the manifold.

random_point(key)[source]

Generate a random rotation matrix via QR decomposition.

Flips only the last column to ensure det = +1, preserving the Haar-uniform distribution on SO(n).

belongs(R)[source]

Check if R is in SO(n): \(R^T R \approx I\) and \(\det(R) \approx +1\).

project(R)[source]

Project to nearest rotation matrix via polar decomposition.

injectivity_radius(R)[source]

Injectivity radius of SO(n) is \(\pi\).
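For n = 3, the geodesic distance above equals the rotation angle of R^T S. That angle can be computed from the trace without a matrix logarithm. The sketch also shows a QR-based sampler that applies the usual sign correction before the det flip. The function names are hypothetical, and this is not prosemble's code.

```python
import jax
import jax.numpy as jnp


def rot_z(theta):
    # Rotation by theta about the z-axis.
    c, s = jnp.cos(theta), jnp.sin(theta)
    return jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def so3_distance(R, S):
    # For n = 3, ||logm(R^T S)||_F / sqrt(2) equals the rotation angle of R^T S,
    # which can be read off the trace: cos(angle) = (tr(R^T S) - 1) / 2.
    cos_angle = (jnp.trace(R.T @ S) - 1.0) / 2.0
    return jnp.arccos(jnp.clip(cos_angle, -1.0, 1.0))


def random_rotation(key, n):
    # QR of a Gaussian matrix with column signs fixed by diag(R) (Haar on O(n)),
    # then the last column flipped when needed so that det = +1.
    Q, R = jnp.linalg.qr(jax.random.normal(key, (n, n)))
    Q = Q * jnp.sign(jnp.diag(R))
    return Q.at[:, -1].multiply(jnp.sign(jnp.linalg.det(Q)))
```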

class prosemble.core.manifolds.SPD(n)[source]

Manifold of \(n \times n\) symmetric positive definite matrices.

Uses the affine-invariant Riemannian metric.

Parameters:

n (int) – Matrix dimension.

distance(A, B)[source]

Geodesic distance: d(A, B) = ||logm(A^{-1/2} B A^{-1/2})||_F.
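Because \(M = A^{-1/2} B A^{-1/2}\) is itself SPD, \(\operatorname{logm}(M)\) has eigenvalues \(\log \mu_i\), and the Frobenius norm reduces to \(\sqrt{\sum_i \log^2 \mu_i}\). A minimal sketch under that observation, with hypothetical helper names:

```python
import jax.numpy as jnp

def _spd_fn(A, fn):
    # Apply a scalar function to the eigenvalues of a symmetric matrix
    lam, V = jnp.linalg.eigh(A)
    return (V * fn(lam)) @ V.T

def spd_distance_sketch(A, B):
    A_inv_sqrt = _spd_fn(A, lambda lam: 1.0 / jnp.sqrt(lam))
    M = A_inv_sqrt @ B @ A_inv_sqrt
    # ||logm(M)||_F from the eigenvalues of M
    mu = jnp.linalg.eigvalsh(M)
    return jnp.sqrt(jnp.sum(jnp.log(mu) ** 2))

d = spd_distance_sketch(jnp.eye(2), jnp.diag(jnp.array([jnp.e ** 2, 1.0])))  # = 2
```

The affine-invariant distance is symmetric, so swapping A and B gives the same value.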

distance_squared(A, B)[source]

Squared geodesic distance.

log_map(A, B)[source]

Log map: Log_A(B) = A^{1/2} logm(A^{-1/2} B A^{-1/2}) A^{1/2}.

exp_map(A, V)[source]

Exp map: Exp_A(V) = A^{1/2} expm(A^{-1/2} V A^{-1/2}) A^{1/2}.

random_point(key)[source]

Generate random SPD matrix: \(A = L L^T + \epsilon I\).

belongs(A)[source]

Check if A is SPD: symmetric and all eigenvalues > 0.

project(A)[source]

Project to nearest SPD: symmetrize and clamp eigenvalues.
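The two steps can be sketched as follows. Taking the symmetric part and then clamping eigenvalues from below gives the Frobenius-nearest symmetric matrix whose eigenvalues are at least the floor. The floor value `min_eig` is a hypothetical parameter; prosemble's actual clamp value is not stated here.

```python
import jax.numpy as jnp

def spd_project_sketch(A, min_eig=1e-6):
    # 1) Symmetrize: discard the antisymmetric part of A
    S = 0.5 * (A + A.T)
    # 2) Clamp eigenvalues from below so the result is positive definite
    lam, V = jnp.linalg.eigh(S)
    return (V * jnp.maximum(lam, min_eig)) @ V.T

A = jnp.array([[1.0, 2.0], [0.0, -1.0]])   # neither symmetric nor positive definite
P = spd_project_sketch(A, min_eig=1e-3)
```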

injectivity_radius(A)[source]

SPD manifold has infinite injectivity radius.

class prosemble.core.manifolds.Grassmannian(n, k)[source]

Grassmannian manifold Gr(n, k): k-dimensional subspaces of R^n.

Points are represented as orthonormal bases Q of shape (n, k) with Q^T Q = I_k.

Parameters:
  • n (int) – Ambient dimension.

  • k (int) – Subspace dimension.

distance(Q1, Q2)[source]

Geodesic distance via principal angles.

d(Q1, Q2) = ||theta||_2, where theta = arccos(sigma) and sigma are the singular values of Q1^T Q2 (the cosines of the principal angles).
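A minimal sketch of the principal-angle distance (hypothetical function name). The example compares the x-axis in R^2 with a line rotated by 0.4 rad, so the single principal angle, and hence the distance, is 0.4.

```python
import jax.numpy as jnp

def grassmann_distance_sketch(Q1, Q2):
    # Singular values of Q1^T Q2 are the cosines of the principal angles
    cosines = jnp.linalg.svd(Q1.T @ Q2, compute_uv=False)
    theta = jnp.arccos(jnp.clip(cosines, -1.0, 1.0))
    return jnp.linalg.norm(theta)

t = 0.4
Q1 = jnp.array([[1.0], [0.0]])                  # the x-axis in R^2
Q2 = jnp.array([[jnp.cos(t)], [jnp.sin(t)]])    # a line rotated by t
d = grassmann_distance_sketch(Q1, Q2)
```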

distance_squared(Q1, Q2)[source]

Squared geodesic distance.

log_map(Q1, Q2)[source]

Logarithmic map on the Grassmannian.

Computes the tangent vector at Q1 pointing toward Q2 using aligned principal angle decomposition (Edelman et al. 1998).

The key insight is to align both subspaces via the SVD of Q1^T Q2, ensuring the principal angles and directions correspond.

exp_map(Q, V)[source]

Exponential map on the Grassmannian.

Maps tangent vector V at Q back to the manifold.
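The classical log/exp pair from Edelman et al. (1998) is sketched below. For the log map, the component of Q2 orthogonal to span(Q1), normalized by (Q1^T Q2)^{-1}, has singular values tan(theta_i). For the exp map, the geodesic is Y V cos(S) V^T + U sin(S) V^T. prosemble's "aligned" variant may differ in detail. The function names are hypothetical, and the log map assumes Q1^T Q2 is invertible (no principal angle equal to pi/2).

```python
import jax.numpy as jnp

def grassmann_log_sketch(Y0, Y1):
    n = Y0.shape[0]
    # Horizontal part of Y1 relative to Y0, normalized by (Y0^T Y1)^{-1}
    A = (jnp.eye(n) - Y0 @ Y0.T) @ Y1 @ jnp.linalg.inv(Y0.T @ Y1)
    U, s, Vt = jnp.linalg.svd(A, full_matrices=False)
    # Singular values are tan(theta_i); arctan recovers the principal angles
    return U @ jnp.diag(jnp.arctan(s)) @ Vt

def grassmann_exp_sketch(Y, H):
    U, s, Vt = jnp.linalg.svd(H, full_matrices=False)
    # Point reached at t = 1 along the geodesic with initial velocity H
    return Y @ Vt.T @ jnp.diag(jnp.cos(s)) @ Vt + U @ jnp.diag(jnp.sin(s)) @ Vt

a, b = 0.3, 0.7
Y0 = jnp.eye(4)[:, :2]                                   # span{e1, e2}
Y1 = jnp.stack([jnp.array([jnp.cos(a), 0.0, jnp.sin(a), 0.0]),
                jnp.array([0.0, jnp.cos(b), 0.0, jnp.sin(b)])], axis=1)
H = grassmann_log_sketch(Y0, Y1)       # tangent vector at Y0
Y1_rec = grassmann_exp_sketch(Y0, H)   # spans the same subspace as Y1
```

Comparing projectors Y Y^T rather than the bases themselves is the basis-independent way to check that two representatives describe the same subspace.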

random_point(key)[source]

Generate a random point on Gr(n, k) via QR decomposition.

belongs(Q)[source]

Check if Q represents a point on Gr(n, k): \(Q^T Q \approx I_k\).

project(Q)[source]

Project to nearest orthonormal basis via QR.

injectivity_radius(Q)[source]

Injectivity radius of Gr(n,k) is \(\pi/2\).

class prosemble.core.manifolds.HyperbolicPoincare(d, eps=1e-05)[source]

Poincaré ball model of hyperbolic space \(\mathbb{H}^d\).

Points are vectors in \(\mathbb{B}^d = \{x \in \mathbb{R}^d : \|x\| < 1\}\). The Riemannian metric is \(g_x = \lambda_x^2 g_E\) where \(\lambda_x = 2/(1-\|x\|^2)\) is the conformal factor.

This manifold is the natural geometry for hierarchical and tree-structured data — it can embed trees with arbitrarily low distortion (Sarkar 2011).

Parameters:
  • d (int) – Dimension of the hyperbolic space.

  • eps (float) – Numerical safety margin for boundary clamping.

distance(x, y)[source]

Geodesic distance in the Poincaré ball.

\[d(x, y) = \operatorname{arcosh}\!\left(1 + \frac{2\|x - y\|^2} {(1 - \|x\|^2)(1 - \|y\|^2)}\right)\]
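A direct transcription of the formula (hypothetical function name). From the origin it reduces to d(0, x) = 2 artanh(||x||), which gives ln 3 for ||x|| = 0.5.

```python
import jax.numpy as jnp

def poincare_distance_sketch(x, y):
    sq_diff = jnp.sum((x - y) ** 2)
    denom = (1.0 - jnp.sum(x ** 2)) * (1.0 - jnp.sum(y ** 2))
    return jnp.arccosh(1.0 + 2.0 * sq_diff / denom)

origin = jnp.zeros(2)
d_mid = poincare_distance_sketch(origin, jnp.array([0.5, 0.0]))   # ln 3
d_edge = poincare_distance_sketch(origin, jnp.array([0.99, 0.0])) # grows near the boundary
```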
distance_squared(x, y)[source]

Squared geodesic distance.

log_map(x, y)[source]

Logarithmic map: tangent vector at x pointing toward y.

\[\text{Log}_x(y) = \frac{2}{\lambda_x} \operatorname{arctanh}(\|-x \oplus y\|) \frac{-x \oplus y}{\|-x \oplus y\|}\]
exp_map(x, v)[source]

Exponential map: move from x along tangent vector v.

\[\text{Exp}_x(v) = x \oplus \left( \tanh\!\left(\frac{\lambda_x \|v\|}{2}\right) \frac{v}{\|v\|}\right)\]
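Both maps rely on Möbius addition, which for curvature -1 is x ⊕ y = ((1 + 2⟨x, y⟩ + ||y||²) x + (1 - ||x||²) y) / (1 + 2⟨x, y⟩ + ||x||² ||y||²). The sketch below (hypothetical names) implements it together with the log and exp formulas above. It checks the round trip Exp_x(Log_x(y)) = y and that the Riemannian length λ_x ||Log_x(y)|| equals the geodesic distance. Both maps divide by a norm, so x = y needs special handling.

```python
import jax.numpy as jnp

def mobius_add(x, y):
    # Möbius addition in the Poincaré ball (curvature -1)
    xy, x2, y2 = jnp.dot(x, y), jnp.dot(x, x), jnp.dot(y, y)
    num = (1 + 2 * xy + y2) * x + (1 - x2) * y
    return num / (1 + 2 * xy + x2 * y2)

def conformal_factor(x):
    return 2.0 / (1.0 - jnp.dot(x, x))

def poincare_log_sketch(x, y):
    u = mobius_add(-x, y)
    norm_u = jnp.linalg.norm(u)
    return (2.0 / conformal_factor(x)) * jnp.arctanh(norm_u) * u / norm_u

def poincare_exp_sketch(x, v):
    norm_v = jnp.linalg.norm(v)
    step = jnp.tanh(conformal_factor(x) * norm_v / 2.0) * v / norm_v
    return mobius_add(x, step)

x = jnp.array([0.1, 0.3])
y = jnp.array([-0.4, 0.2])
v = poincare_log_sketch(x, y)
y_rec = poincare_exp_sketch(x, v)   # recovers y
```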
random_point(key)[source]

Generate a random point uniformly in the Poincaré ball.

Uses the radial transform: sample a direction uniformly on S^{d-1}, then a radius r ~ Beta(d, 1) (equivalently r = u^{1/d} with u ~ U(0, 1)), which gives a uniform distribution in B^d; the result is scaled to stay safely inside the ball.
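A sketch of the radial transform. A uniform point in the d-ball has radial density proportional to r^{d-1}, i.e. Beta(d, 1), which is sampled here as u^{1/d}. The shrink factor `max_norm` is a hypothetical stand-in for "scaled to stay safely inside the ball".

```python
import jax
import jax.numpy as jnp

def random_poincare_point_sketch(key, d, max_norm=1.0 - 1e-5):
    k_dir, k_rad = jax.random.split(key)
    # Uniform direction on the unit sphere S^{d-1}
    direction = jax.random.normal(k_dir, (d,))
    direction = direction / jnp.linalg.norm(direction)
    # r = u^(1/d) has density d * r^(d-1) on [0, 1], i.e. Beta(d, 1)
    r = jax.random.uniform(k_rad) ** (1.0 / d)
    return max_norm * r * direction

keys = jax.random.split(jax.random.PRNGKey(0), 1000)
pts = jax.vmap(lambda k: random_poincare_point_sketch(k, 3))(keys)
```

For d = 3 the mean radius of Beta(3, 1) is 0.75, so most samples lie in the outer shell, as expected for a uniform ball.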

belongs(x)[source]

Check if x is in the Poincaré ball: \(\|x\| < 1\).

project(x)[source]

Project to the interior of the Poincaré ball.

Clamps the norm to be strictly less than 1.

injectivity_radius(x)[source]

Hyperbolic space has infinite injectivity radius.

Protocols

Structural typing protocols for prosemble interfaces.

Defines typing.Protocol contracts for duck-typed interfaces used across the library. These enable static type checking (mypy / pyright) and IDE auto-completion without requiring inheritance.

Note

Runtime-checkable protocols (Manifold, CallbackLike) support isinstance() checks. Type aliases (DistanceMatrixFn, etc.) are for annotation only.

class prosemble.core.protocols.Manifold(*args, **kwargs)[source]

Protocol for Riemannian manifold implementations.

Any object exposing the methods below can be used wherever a manifold is expected (e.g. RiemannianNeuralGas, RiemannianSRNG).

The concrete implementations SO, SPD, and Grassmannian all satisfy this protocol structurally — no explicit subclassing is required.
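Structural typing works with a runtime-checkable protocol as sketched below. `ManifoldLike` is a hypothetical, trimmed-down stand-in for prosemble's `Manifold`, and `Euclidean` is a hypothetical flat manifold. The class never subclasses the protocol, yet `isinstance` accepts it. The runtime check only verifies that the members exist, not their signatures.

```python
from typing import Protocol, Tuple, runtime_checkable

import jax.numpy as jnp

@runtime_checkable
class ManifoldLike(Protocol):
    # Hypothetical subset of the Manifold protocol's members
    @property
    def point_shape(self) -> Tuple[int, ...]: ...
    def distance(self, p, q): ...
    def exp_map(self, base, tangent): ...
    def log_map(self, base, target): ...

class Euclidean:
    """Flat R^d; does not inherit from ManifoldLike."""

    def __init__(self, d: int):
        self.d = d

    @property
    def point_shape(self) -> Tuple[int, ...]:
        return (self.d,)

    def distance(self, p, q):
        return jnp.linalg.norm(p - q)

    def exp_map(self, base, tangent):
        return base + tangent

    def log_map(self, base, target):
        return target - base

ok = isinstance(Euclidean(3), ManifoldLike)   # True: structural check only
```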

property point_shape: Tuple[int, ...]

Shape of a single point on the manifold.

distance(p, q)[source]

Geodesic distance between two points.

Parameters:
  • p (arrays of shape point_shape)

  • q (arrays of shape point_shape)

Return type:

scalar

distance_squared(p, q)[source]

Squared geodesic distance between two points.

Parameters:
  • p (arrays of shape point_shape)

  • q (arrays of shape point_shape)

Return type:

scalar

log_map(base, target)[source]

Logarithmic map: tangent vector at base pointing toward target.

Parameters:
  • base (array of shape point_shape) – Base point on the manifold.

  • target (array of shape point_shape) – Target point on the manifold.

Returns:

tangent – Tangent vector in \(T_{\text{base}} M\).

Return type:

array of shape point_shape

exp_map(base, tangent)[source]

Exponential map: move along tangent from base back to the manifold.

Parameters:
  • base (array of shape point_shape)

  • tangent (array of shape point_shape)

Returns:

point

Return type:

array of shape point_shape

random_point(key)[source]

Sample a random point on the manifold.

Parameters:

key (JAX PRNG key)

Returns:

point

Return type:

array of shape point_shape

belongs(point)[source]

Check whether point lies on the manifold.

Parameters:

point (array of shape point_shape)

Return type:

bool or bool-valued array

project(point)[source]

Project an off-manifold point to the nearest point on the manifold.

Parameters:

point (array of shape point_shape)

Returns:

projected

Return type:

array of shape point_shape

injectivity_radius(point)[source]

Injectivity radius at point.

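The largest radius r such that the exponential map at point is injective on the tangent-space ball of radius r; within this geodesic distance the logarithmic map is uniquely defined.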

Parameters:

point (array of shape point_shape)

Returns:

radius

Return type:

float or scalar array

class prosemble.core.protocols.CallbackLike(*args, **kwargs)[source]

Protocol for training callbacks.

Any object with the three hook methods below can be passed in the callbacks list of a model’s constructor. The existing Callback base class already satisfies this protocol.
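A minimal duck-typed callback is sketched below. `IterationCounter` is hypothetical, and `info` is treated as opaque because its contents are model-specific. The hooks are called by hand here to mimic a training loop. With a real model it would be passed as `callbacks=[IterationCounter()]` in the constructor.

```python
class IterationCounter:
    """Hypothetical callback satisfying CallbackLike without subclassing."""

    def __init__(self):
        self.n_iterations = 0
        self.finished = False

    def on_fit_start(self, model, X):
        self.n_iterations = 0

    def on_iteration_end(self, model, info):
        # `info` is not inspected; only the call count is recorded
        self.n_iterations += 1

    def on_fit_end(self, model, info):
        self.finished = True

cb = IterationCounter()
cb.on_fit_start(model=None, X=None)
for _ in range(5):
    cb.on_iteration_end(model=None, info={})
cb.on_fit_end(model=None, info={})
```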

on_fit_start(model, X)[source]

Called once before training begins.

Parameters:
Return type:

None

on_iteration_end(model, info)[source]

Called after each training iteration / epoch.

Parameters:
Return type:

None

on_fit_end(model, info)[source]

Called once after training ends.

Parameters:
Return type:

None

prosemble.core.protocols.DistanceMatrixFn

(X, Y) -> distances. X has shape (n_samples, n_features), Y has shape (n_prototypes, n_features), result has shape (n_samples, n_prototypes).

Type:

Distance-matrix function

alias of Callable[[Array, Array], Array]

prosemble.core.protocols.DistancePairwiseFn

(x, y) -> scalar.

Type:

Pairwise distance function

alias of Callable[[Array, Array], Array]
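The two aliases are related. Any `DistancePairwiseFn` can be lifted to a `DistanceMatrixFn` with two nested `jax.vmap` calls: the inner one maps over prototypes and the outer one over samples. A sketch with hypothetical helper names:

```python
import jax
import jax.numpy as jnp

def sq_euclidean_pairwise(x, y):
    # DistancePairwiseFn: (n_features,), (n_features,) -> scalar
    return jnp.sum((x - y) ** 2)

def pairwise_to_matrix(fn):
    # Inner vmap: one sample vs. all prototypes; outer vmap: all samples
    return jax.vmap(jax.vmap(fn, in_axes=(None, 0)), in_axes=(0, None))

dist_matrix = pairwise_to_matrix(sq_euclidean_pairwise)   # a DistanceMatrixFn
X = jnp.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
W = jnp.array([[0.0, 0.0], [3.0, 3.0]])
D = dist_matrix(X, W)   # shape (n_samples, n_prototypes) = (3, 2)
```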

prosemble.core.protocols.SupervisedInitFn

Supervised prototype initializer: (X, y, n_per_class, key) -> (prototypes, prototype_labels).

alias of Callable[[Array, Array, int, Array], Tuple[Array, Array]]

prosemble.core.protocols.UnsupervisedInitFn

Unsupervised prototype initializer: (X, n_prototypes, key) -> prototypes.

alias of Callable[[Array, int, Array], Array]
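Two functions matching these signatures are sketched below. Both are hypothetical and simply sample data points as initial prototypes. The supervised one loops over classes in Python, so it is meant to run outside `jax.jit`.

```python
import jax
import jax.numpy as jnp

def random_sample_init(X, n_prototypes, key):
    # UnsupervisedInitFn: pick n_prototypes distinct rows of X
    idx = jax.random.choice(key, X.shape[0], shape=(n_prototypes,), replace=False)
    return X[idx]

def class_sample_init(X, y, n_per_class, key):
    # SupervisedInitFn: n_per_class random rows from every class
    protos, labels = [], []
    for c in jnp.unique(y):
        key, sub = jax.random.split(key)
        X_c = X[y == c]
        idx = jax.random.choice(sub, X_c.shape[0], shape=(n_per_class,), replace=False)
        protos.append(X_c[idx])
        labels.append(jnp.full((n_per_class,), c))
    return jnp.concatenate(protos), jnp.concatenate(labels)

X = jnp.arange(20.0).reshape(10, 2)
y = jnp.array([0, 0, 0, 0, 1, 1, 1, 2, 2, 2])
W = random_sample_init(X, 3, jax.random.PRNGKey(0))
W_sup, W_labels = class_sample_init(X, y, 2, jax.random.PRNGKey(1))
```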

Distributed Training

Distributed training utilities for multi-device data parallelism.

Provides functions for sharding data across devices and replicating model parameters, enabling data-parallel training on multi-GPU/TPU setups.

prosemble.core.distributed.create_mesh(devices=None)[source]

Create a 1D device mesh for data parallelism.

Parameters:

devices (list of jax.Device or None) – Devices to use. If None, returns None (single-device mode).

Return type:

Mesh or None

prosemble.core.distributed.shard_data(X, y, mesh)[source]

Shard data arrays along the batch dimension across devices.

Parameters:
  • X (jnp.ndarray of shape (n_samples, ...)) – Input data.

  • y (jnp.ndarray of shape (n_samples,)) – Labels.

  • mesh (Mesh) – Device mesh from create_mesh().

Returns:

X_sharded, y_sharded

Return type:

tuple of sharded arrays

prosemble.core.distributed.replicate_params(params, mesh)[source]

Replicate params across all devices (no partitioning).

Parameters:
  • params (dict (pytree)) – Model parameters.

  • mesh (Mesh) – Device mesh from create_mesh().

Return type:

params replicated across mesh

prosemble.core.distributed.replicate_opt_state(opt_state, mesh)[source]

Replicate optimizer state across all devices.

Parameters:
  • opt_state (pytree) – Optax optimizer state.

  • mesh (Mesh) – Device mesh from create_mesh().

Return type:

opt_state replicated across mesh

prosemble.core.distributed.unshard_params(params)[source]

Bring params back to a single device.

Used after training to store results as plain arrays for predict/export operations.

Parameters:

params (pytree) – Potentially sharded parameters.

Return type:

params on default device
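The same pattern can be expressed directly with `jax.sharding`: split the batch axis across a 1D mesh, replicate parameters with an empty `PartitionSpec`, and move results back to one device afterwards. This is a sketch of what these helpers plausibly do, not their implementation. The mesh axis name `"data"` is an assumption. It also runs on a single CPU device.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1D mesh over all devices; the axis name "data" is hypothetical
mesh = Mesh(np.array(jax.devices()), ("data",))
batch_sharding = NamedSharding(mesh, P("data"))   # split the leading axis
replicated = NamedSharding(mesh, P())             # full copy on every device

n_dev = len(jax.devices())
X = jnp.ones((8 * n_dev, 3))                      # n_samples divisible by n_dev
y = jnp.zeros((8 * n_dev,), dtype=jnp.int32)
X_sh = jax.device_put(X, batch_sharding)
y_sh = jax.device_put(y, batch_sharding)

params = {"prototypes": jnp.zeros((4, 3)), "omega": jnp.eye(3)}
params_rep = jax.tree_util.tree_map(lambda a: jax.device_put(a, replicated), params)

# After training: gather onto a single device for predict/export
params_local = jax.tree_util.tree_map(
    lambda a: jax.device_put(a, jax.devices()[0]), params_rep)
```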

Data Utilities

Data loading and batching utilities for prosemble.

Provides composable primitives for mini-batch iteration, suitable for custom training loops. The built-in fit() methods already handle batching internally — these utilities are for advanced users who need explicit control over data iteration.

prosemble.core.data.shuffle_arrays(key, *arrays)[source]

Shuffle multiple arrays with the same random permutation.

Parameters:
  • key (JAX PRNG key) – Random key for generating the permutation.

  • *arrays (jnp.ndarray) – Arrays to shuffle. All must have the same length along axis 0.

Returns:

Shuffled arrays in the same order as the inputs.

Return type:

tuple of jnp.ndarray

Examples

>>> key = jax.random.PRNGKey(0)
>>> X = jnp.arange(12).reshape(4, 3)
>>> y = jnp.array([0, 1, 0, 1])
>>> X_s, y_s = shuffle_arrays(key, X, y)
prosemble.core.data.padded_batches(X, y=None, batch_size=32, key=None)[source]

Split data into static-shaped batches, padding the last batch.

Returns arrays of shape (n_batches, batch_size, ...) suitable for jax.lax.scan. If the data length is not divisible by batch_size, the last batch is padded by repeating initial samples.

Parameters:
  • X (jnp.ndarray of shape (n_samples, ...)) – Feature array.

  • y (jnp.ndarray of shape (n_samples,), optional) – Label array.

  • batch_size (int) – Number of samples per batch.

  • key (JAX PRNG key, optional) – If provided, data is shuffled before batching.

Returns:

  • X_batches (jnp.ndarray of shape (n_batches, batch_size, …))

  • y_batches (jnp.ndarray of shape (n_batches, batch_size) or None)

Return type:

tuple[Array, Array | None]

Examples

>>> X = jnp.ones((10, 3))
>>> y = jnp.arange(10)
>>> X_b, y_b = padded_batches(X, y, batch_size=4)
>>> X_b.shape  # (3, 4, 3) — 10 samples padded to 12
(3, 4, 3)
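Static batch shapes are what make `jax.lax.scan` applicable. The sketch below reimplements the padding idea (hypothetical name `padded_batches_sketch`; indices wrap around so the first samples are repeated) and scans over the batches.

```python
import jax
import jax.numpy as jnp

def padded_batches_sketch(X, batch_size):
    n = X.shape[0]
    n_batches = -(-n // batch_size)   # ceiling division
    # 0..n-1, then wrap around: the first samples are repeated as padding
    idx = jnp.arange(n_batches * batch_size) % n
    return X[idx].reshape((n_batches, batch_size) + X.shape[1:])

X = jnp.arange(10.0).reshape(10, 1)
X_b = padded_batches_sketch(X, batch_size=4)   # last batch holds 8, 9, 0, 1

def step(total, batch):
    # Every batch has the same static shape, so one compiled body serves all
    return total + jnp.sum(batch), jnp.mean(batch)

total, batch_means = jax.lax.scan(step, jnp.zeros((), X_b.dtype), X_b)
```

The padded samples (0 and 1 here) contribute twice to `total`. Mask them if exact per-epoch sums matter.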
prosemble.core.data.batched_iterator(X, y=None, batch_size=32, shuffle=True, key=None, drop_last=False)[source]

Yield mini-batches from data arrays.

For use in custom Python training loops.

Parameters:
  • X (jnp.ndarray of shape (n_samples, ...)) – Feature array.

  • y (jnp.ndarray of shape (n_samples,), optional) – Label array.

  • batch_size (int) – Number of samples per batch.

  • shuffle (bool) – If True and key is provided, shuffle before iterating.

  • key (JAX PRNG key, optional) – Required if shuffle=True.

  • drop_last (bool) – If True, drop the last batch when it is smaller than batch_size. If False (default), the last batch may be smaller.

Yields:
  • X_batch (jnp.ndarray of shape (batch_size, …) or smaller)

  • y_batch (jnp.ndarray of shape (batch_size,) or None)

Examples

>>> key = jax.random.PRNGKey(0)
>>> X = jnp.ones((10, 3))
>>> for X_b, y_b in batched_iterator(X, batch_size=4, key=key):
...     print(X_b.shape)
(4, 3)
(4, 3)
(2, 3)

Custom Optimizers

Custom optax-compatible optimizers for prototype-based learning.

Provides specialized gradient transformations designed for the geometry and parameter structure of LVQ models:

  • per_group_clip: Per-parameter-group gradient norm clipping.

  • hypergradient_descent: Adaptive per-parameter learning rates via gradient correlation (Baydin et al. 2017).

  • riemannian_nesterov: Nesterov accelerated gradient for prototype-based models, including Riemannian models whose manifold constraint is restored by the model’s projection step.

All transformations follow the optax GradientTransformation interface and can be composed via optax.chain() or passed directly to any model’s optimizer parameter.

class prosemble.core.optimizers.PerGroupClipState[source]

State for per-group gradient clipping (stateless).

prosemble.core.optimizers.per_group_clip(max_norms)[source]

Clip gradient norms independently per parameter group.

Different parameter types (prototypes, omega matrices, relevances, sigmas) have different natural scales. A single global clip either under-constrains large parameters or over-constrains small ones. This transformation clips each group independently.

Parameters:

max_norms (dict) – Mapping from parameter key name to maximum gradient norm. Keys not present in this dict are left unclipped. Example: {'prototypes': 1.0, 'omega': 0.5, 'sigmas': 0.1}

Returns:

Composable gradient transformation.

Return type:

optax.GradientTransformation

Examples

>>> import optax
>>> from prosemble.core.optimizers import per_group_clip
>>> optimizer = optax.chain(
...     per_group_clip({'prototypes': 1.0, 'omega': 0.5, 'sigmas': 0.1}),
...     optax.adam(0.01),
... )
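The core of the transformation can be sketched without the optax wrapper. In an optax `GradientTransformation`, this logic would sit inside the update function. The sketch assumes grads is a flat dict keyed by the same names as `max_norms`, and `clip_per_group` is a hypothetical name.

```python
import jax.numpy as jnp

def clip_per_group(grads, max_norms, eps=1e-12):
    clipped = {}
    for name, g in grads.items():
        if name not in max_norms:
            clipped[name] = g   # groups without a limit pass through unchanged
            continue
        norm = jnp.linalg.norm(g)
        # Rescale only when this group's norm exceeds its own limit
        scale = jnp.minimum(1.0, max_norms[name] / (norm + eps))
        clipped[name] = g * scale
    return clipped

grads = {"prototypes": jnp.array([3.0, 4.0]),            # norm 5 -> clipped to 1
         "omega": jnp.array([[0.1, 0.0], [0.0, 0.1]]),   # norm ~0.14 < 0.5 -> kept
         "relevances": jnp.array([10.0, 0.0])}           # no limit -> kept
out = clip_per_group(grads, {"prototypes": 1.0, "omega": 0.5})
```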
class prosemble.core.optimizers.HypergradientState(learning_rates, prev_grads, base_opt_state)[source]

State for hypergradient descent optimizer.

Parameters:
learning_rates: dict

Alias for field number 0

prev_grads: dict

Alias for field number 1

base_opt_state: object

Alias for field number 2

prosemble.core.optimizers.hypergradient_descent(init_lr=0.01, hyper_lr=0.0001, inner_optimizer='sgd', min_lr=1e-06, max_lr=1.0)[source]

Adaptive per-parameter learning rates via hypergradient descent.

If consecutive gradients point in the same direction (positive dot product), the learning rate is increased; if they oscillate (negative dot product), it is decreased. This lets each parameter group adapt its own step size.

The update rule for the learning rate \(\eta_k\) of parameter group \(k\) at step \(t\), with meta-learning rate \(\beta\) = hyper_lr and clipping to [min_lr, max_lr]:

\[\eta_k^{t+1} = \operatorname{clip}_{[\eta_{\min}, \eta_{\max}]}\left( \eta_k^t + \beta \cdot \langle g_k^t, g_k^{t-1} \rangle \right)\]

Parameters:
  • init_lr (float) – Initial learning rate for all parameter groups. Default: 0.01.

  • hyper_lr (float) – Learning rate for the learning rate update (meta-learning rate). Default: 1e-4.

  • inner_optimizer (str) – Base optimizer to use (‘sgd’ applies raw scaled gradients). Default: ‘sgd’.

  • min_lr (float) – Minimum allowed learning rate. Default: 1e-6.

  • max_lr (float) – Maximum allowed learning rate. Default: 1.0.

Return type:

optax.GradientTransformation

Examples

>>> from prosemble.core.optimizers import hypergradient_descent
>>> optimizer = hypergradient_descent(init_lr=0.01, hyper_lr=1e-4)
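The rule can be sketched as a plain SGD step per parameter group. This is in the spirit of `inner_optimizer='sgd'` but is not prosemble's implementation. The helper name and the zero-initialized previous gradients are assumptions. On a quadratic, consecutive gradients stay aligned, so the learning rate grows from its initial value.

```python
import jax
import jax.numpy as jnp

def hypergradient_sgd_step(params, lrs, prev_grads, grads,
                           hyper_lr=1e-4, min_lr=1e-6, max_lr=1.0):
    new_params, new_lrs = {}, {}
    for k in params:
        # <g_t, g_{t-1}> > 0: same direction -> larger step; < 0 -> smaller step
        corr = jnp.vdot(grads[k], prev_grads[k])
        new_lrs[k] = jnp.clip(lrs[k] + hyper_lr * corr, min_lr, max_lr)
        new_params[k] = params[k] - new_lrs[k] * grads[k]
    return new_params, new_lrs

loss = lambda p: 0.5 * jnp.sum(p["w"] ** 2)
params = {"w": jnp.array([3.0, -2.0])}
lrs = {"w": jnp.array(0.01)}
prev = jax.tree_util.tree_map(jnp.zeros_like, params)
for _ in range(20):
    g = jax.grad(loss)(params)
    params, lrs = hypergradient_sgd_step(params, lrs, prev, g, hyper_lr=1e-3)
    prev = g
```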
class prosemble.core.optimizers.RiemannianNesterovState(velocity, step)[source]

State for Riemannian Nesterov momentum.

Parameters:
velocity: dict

Alias for field number 0

step: Array

Alias for field number 1

prosemble.core.optimizers.riemannian_nesterov(learning_rate=0.01, momentum=0.9)[source]

Nesterov accelerated gradient adapted for prototype-based models.

Implements Nesterov momentum in Euclidean parameter space. For Riemannian models, the prototypes are stored in flattened form and the manifold projection is handled by _post_update().

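On smooth convex objectives, Nesterov’s method (with a suitable step size and momentum schedule) attains an O(1/t^2) convergence rate, versus O(1/t) for vanilla gradient descent; with a fixed momentum coefficient this rate is not guaranteed.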

Update rule:

\[\begin{aligned} v_{t+1} &= \mu \cdot v_t + g_t \\ \theta_{t+1} &= \theta_t - \eta \cdot (\mu \cdot v_{t+1} + g_t) \end{aligned}\]

This is the Nesterov variant where the lookahead is incorporated into the update (Sutskever et al. 2013 reformulation).

Parameters:
  • learning_rate (float) – Step size. Default: 0.01.

  • momentum (float) – Momentum coefficient (0 < mu < 1). Higher values give more momentum. Default: 0.9.

Return type:

optax.GradientTransformation

Notes

For Riemannian models where prototypes live on manifolds, the manifold retraction (projection back to manifold) is handled by the model’s _post_update() method. This optimizer provides the accelerated gradient direction; the model ensures the result stays on the manifold.

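The momentum buffer is not parallel-transported between tangent spaces. Pairing this optimizer with a Riemannian model whose _post_update() projects back onto the manifold therefore gives a projection-based approximation of Riemannian Nesterov, not the exact method.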

Examples

>>> from prosemble.core.optimizers import riemannian_nesterov
>>> optimizer = riemannian_nesterov(learning_rate=0.01, momentum=0.9)
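The update rule above can be traced by hand on f(θ) = θ²/2 with a minimal sketch (hypothetical helper name, Euclidean case only). From θ = 1 and v = 0, the first step gives v = 1 and θ = 1 - 0.1 · (0.9 · 1 + 1) = 0.81.

```python
import jax
import jax.numpy as jnp

def nesterov_step(theta, v, g, lr=0.1, mu=0.9):
    v_new = mu * v + g                          # v_{t+1} = mu * v_t + g_t
    theta_new = theta - lr * (mu * v_new + g)   # lookahead folded into the step
    return theta_new, v_new

f = lambda th: 0.5 * th ** 2
theta, v = jnp.array(1.0), jnp.array(0.0)
theta1, _ = nesterov_step(theta, v, jax.grad(f)(theta))   # 0.81
for _ in range(300):
    theta, v = nesterov_step(theta, v, jax.grad(f)(theta))
```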

Regularization

Regularization techniques for prototype-based models.

Provides reusable loss-level regularization terms that can be composed with any model’s _compute_loss method:

  • prototype_diversity_loss: DPP-inspired repulsion between same-class prototypes.

  • sparse_relevance_proximal: L1 proximal step (soft-thresholding) for relevance vectors.

  • elastic_net_proximal: Elastic net (L1 + L2) proximal step for relevance vectors.

prosemble.core.regularization.prototype_diversity_loss(prototypes, proto_labels, sigma_div=1.0)[source]

Compute DPP-inspired diversity regularization for same-class prototypes.

Encourages multiple prototypes per class to spread out rather than collapsing to the same point. Uses the log-determinant of the RBF kernel matrix between same-class prototypes.

\[L_{\text{diversity}} = -\sum_c \log\det(K_c + \epsilon I)\]

where \(K_c[i,j] = \exp(-\|w_i - w_j\|^2 / (2\sigma^2))\).

The determinant is largest when prototypes are spread out (DPP theory). When two prototypes collapse, det(K_c) -> 0, so -log det(K_c + epsilon I) becomes large, bounded only by the epsilon I term.

Parameters:
  • prototypes (array of shape (n_prototypes, n_features)) – Prototype positions.

  • proto_labels (array of shape (n_prototypes,)) – Class label for each prototype.

  • sigma_div (float) – Bandwidth for the RBF kernel. Controls the scale at which diversity is measured. Default: 1.0.

Returns:

loss – Diversity penalty (lower is more diverse).

Return type:

scalar

Notes

Only active when n_prototypes_per_class > 1. For a single-prototype class, K_c is the 1x1 matrix [1], so its log-determinant is log(1 + epsilon) ≈ 0: a constant with zero gradient.
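A loop-based sketch of the formula (hypothetical name; the jitter `eps` is an assumed value). Two nearly collapsed same-class prototypes give a much larger penalty than two well-separated ones.

```python
import jax.numpy as jnp

def diversity_loss_sketch(prototypes, proto_labels, sigma_div=1.0, eps=1e-6):
    loss = 0.0
    for c in jnp.unique(proto_labels):
        W = prototypes[proto_labels == c]
        sq = jnp.sum((W[:, None, :] - W[None, :, :]) ** 2, axis=-1)
        K = jnp.exp(-sq / (2.0 * sigma_div ** 2))   # RBF kernel K_c
        _, logdet = jnp.linalg.slogdet(K + eps * jnp.eye(W.shape[0]))
        loss = loss - logdet                        # -log det(K_c + eps I)
    return loss

labels = jnp.array([0, 0, 1])
close = jnp.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0]])    # class 0 nearly collapsed
spread = jnp.array([[0.0, 0.0], [3.0, 0.0], [5.0, 5.0]])   # class 0 well separated
l_close = diversity_loss_sketch(close, labels)
l_spread = diversity_loss_sketch(spread, labels)
```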

prosemble.core.regularization.prototype_diversity_loss_vectorized(prototypes, proto_labels, sigma_div=1.0, n_classes=None, max_protos_per_class=None)[source]

JIT-friendly vectorized diversity loss without Python loops.

For use within lax.scan training (use_scan=True). Requires static shape information.

Parameters:
  • prototypes (array of shape (n_prototypes, n_features))

  • proto_labels (array of shape (n_prototypes,))

  • sigma_div (float) – RBF bandwidth. Default: 1.0.

  • n_classes (int) – Number of classes (must be known at compile time).

  • max_protos_per_class (int) – Maximum prototypes in any class (must be known at compile time).

Returns:

loss

Return type:

scalar
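One way to avoid Python loops is sketched below. For each class, keep the kernel on that class's block and put the identity elsewhere. The masked matrix is then block-diagonal up to permutation, so its log-determinant equals that of the class block. `vmap` over class ids keeps every shape static. This masking approach and its signature (an explicit `class_ids` array) are assumptions and may differ from prosemble's padding-based implementation. It costs one n x n determinant per class.

```python
import jax
import jax.numpy as jnp

@jax.jit
def diversity_loss_masked(prototypes, proto_labels, class_ids, sigma_div=1.0, eps=1e-6):
    n = prototypes.shape[0]
    sq = jnp.sum((prototypes[:, None, :] - prototypes[None, :, :]) ** 2, axis=-1)
    K = jnp.exp(-sq / (2.0 * sigma_div ** 2))
    eye = jnp.eye(n)

    def class_logdet(c):
        m = (proto_labels == c).astype(K.dtype)
        both = m[:, None] * m[None, :]
        # Class block of K, identity elsewhere; jitter only on the class block
        K_masked = both * K + (1.0 - both) * eye + eps * jnp.diag(m)
        return jnp.linalg.slogdet(K_masked)[1]

    return -jnp.sum(jax.vmap(class_logdet)(class_ids))

W = jnp.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0]])
lab = jnp.array([0, 0, 1])
val = diversity_loss_masked(W, lab, jnp.array([0, 1]))
```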

prosemble.core.regularization.sparse_relevance_proximal(relevances, l1_weight, lr=1.0)[source]

Apply proximal operator for L1 regularization (soft-thresholding).

Enforces true sparsity in relevance vectors by applying the proximal mapping of the L1 norm after the gradient step:

\[\text{prox}_{\alpha\|\cdot\|_1}(x)_j = \text{sign}(x_j) \cdot \max(|x_j| - \alpha, 0)\]

This yields exact zeros, so the relevances perform genuine feature selection in the spirit of the LASSO.

Parameters:
  • relevances (array of shape (n_features,) or (n_prototypes, n_features)) – Relevance logits (before softmax normalization).

  • l1_weight (float) – L1 penalty strength. Larger values give sparser solutions.

  • lr (float) – Learning rate (used to scale the threshold: threshold = l1_weight * lr). Default: 1.0.

Returns:

sparse_relevances – Thresholded relevances with exact zeros.

Return type:

array, same shape as input
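Soft-thresholding in a few lines (hypothetical name). Entries with magnitude below l1_weight * lr become exactly zero, and the rest shrink toward zero by that amount.

```python
import jax.numpy as jnp

def soft_threshold(relevances, l1_weight, lr=1.0):
    alpha = l1_weight * lr   # threshold = l1_weight * lr
    return jnp.sign(relevances) * jnp.maximum(jnp.abs(relevances) - alpha, 0.0)

r = jnp.array([0.5, -0.05, -2.0, 0.08])
r_sparse = soft_threshold(r, l1_weight=0.1)   # [0.4, 0.0, -1.9, 0.0]
```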

prosemble.core.regularization.elastic_net_proximal(relevances, l1_weight, l2_weight, lr=1.0)[source]

Apply elastic net proximal operator (L1 + L2 regularization).

Combines soft-thresholding (L1 sparsity) with L2 shrinkage:

\[\text{prox}(x)_j = \frac{\text{sign}(x_j) \max(|x_j| - \alpha, 0)} {1 + \beta}\]

where alpha = l1_weight * lr, beta = l2_weight * lr.

Parameters:
  • relevances (array) – Relevance logits.

  • l1_weight (float) – L1 penalty strength (sparsity).

  • l2_weight (float) – L2 penalty strength (shrinkage).

  • lr (float) – Learning rate scaling. Default: 1.0.

Returns:

regularized_relevances – Thresholded and shrunk relevances.

Return type:

array
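The formula is the proximal map of alpha ||x||_1 + (beta / 2) ||x||_2^2: soft-threshold first, then shrink uniformly by 1 / (1 + beta). A sketch with a hypothetical name:

```python
import jax.numpy as jnp

def elastic_net_prox(relevances, l1_weight, l2_weight, lr=1.0):
    alpha, beta = l1_weight * lr, l2_weight * lr
    soft = jnp.sign(relevances) * jnp.maximum(jnp.abs(relevances) - alpha, 0.0)
    return soft / (1.0 + beta)   # L2 part: uniform shrinkage toward zero

r = jnp.array([0.5, -0.05, -2.0])
out = elastic_net_prox(r, l1_weight=0.1, l2_weight=1.0)   # [0.2, 0.0, -0.95]
```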

Curriculum Learning

Curriculum learning (self-paced learning) for prototype-based models.

Implements the self-paced learning framework (Kumar et al. 2010) adapted for LVQ models. Samples are presented in order of difficulty — easy samples first, hard samples later.

For LVQ, sample difficulty is measured by the absolute value of the mu-ratio, which lies in [-1, 1]: |mu(x)| close to 0 means the sample is near the decision boundary (hard), while |mu(x)| close to 1 means it is far from the boundary (easy).

Usage

Curriculum learning is applied via sample masking in the loss function. At each training step, a difficulty threshold lambda_t determines which samples participate. The threshold increases over training to gradually include harder samples.

prosemble.core.curriculum.curriculum_weights(per_sample_loss, threshold, mode='hard')[source]

Compute per-sample curriculum weights based on loss magnitude.

Parameters:
  • per_sample_loss (array of shape (n_samples,)) – Per-sample loss values (e.g., from GLVQ mu-ratio).

  • threshold (float) – Difficulty threshold lambda_t. Samples with loss below this threshold are included. Increases over training.

  • mode (str, {'hard', 'soft', 'linear'}) –

    Weight assignment mode:

      - ‘hard’: binary weights (0 or 1); a sample is either in or out.

      - ‘soft’: exponential down-weighting of hard samples. w_i = exp(-loss_i / threshold) if loss_i > threshold, else 1.

      - ‘linear’: linear ramp-down. w_i = max(0, 1 - (loss_i - threshold) / threshold) if loss_i > threshold, else 1.

    Default: ‘hard’.

Returns:

weights – Per-sample weights in [0, 1]. The sum of the weights is used to normalize the weighted loss.

Return type:

array of shape (n_samples,)
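The three modes, transcribed from the definitions above into a sketch (hypothetical name). The boundary case loss_i = threshold is treated as "easy" here, which is an assumption. The weighted mean at the end shows the normalization by the weight sum.

```python
import jax.numpy as jnp

def curriculum_weights_sketch(per_sample_loss, threshold, mode="hard"):
    loss = per_sample_loss
    easy = loss <= threshold
    if mode == "hard":
        return easy.astype(loss.dtype)
    if mode == "soft":
        return jnp.where(easy, 1.0, jnp.exp(-loss / threshold))
    if mode == "linear":
        return jnp.where(easy, 1.0, jnp.maximum(0.0, 1.0 - (loss - threshold) / threshold))
    raise ValueError(f"unknown mode: {mode}")

losses = jnp.array([0.1, 0.5, 1.5, 3.0])
w_hard = curriculum_weights_sketch(losses, 1.0, "hard")     # [1, 1, 0, 0]
w_lin = curriculum_weights_sketch(losses, 1.0, "linear")    # [1, 1, 0.5, 0]
w_soft = curriculum_weights_sketch(losses, 1.0, "soft")     # [1, 1, e^-1.5, e^-3]
weighted_mean = jnp.sum(w_soft * losses) / jnp.sum(w_soft)
```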

prosemble.core.curriculum.curriculum_threshold(iteration, max_iter, init_threshold=0.3, final_threshold=None, schedule='linear')[source]

Compute the difficulty threshold at a given training iteration.

The threshold increases over training: initially only easy samples are included, and harder samples are gradually added.

Parameters:
  • iteration (int or array) – Current training iteration.

  • max_iter (int) – Maximum number of iterations.

  • init_threshold (float) – Initial threshold (fraction of max loss to include). At the start, only samples with loss < init_threshold * max_loss are included. Default: 0.3.

  • final_threshold (float, optional) – Final threshold. Default: None (uses a large value to include all samples by end of training).

  • schedule (str, {'linear', 'exponential', 'cosine'}) – How the threshold grows over training:

      - ‘linear’: lambda_t = init + (final - init) * t / T

      - ‘exponential’: lambda_t = init * (final / init) ^ (t / T)

      - ‘cosine’: cosine annealing from init to final.

    Default: ‘linear’.

Returns:

threshold – Difficulty threshold at the current iteration.

Return type:

float
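The three schedules are sketched for Python scalars (hypothetical name). The cosine form used here, final + (init - final) · (1 + cos(pi t / T)) / 2, is the standard cosine-annealing curve and is an assumption about what the library uses. All three start at init and end at final.

```python
import math

def curriculum_threshold_sketch(t, T, init, final, schedule="linear"):
    frac = min(max(t / T, 0.0), 1.0)
    if schedule == "linear":
        return init + (final - init) * frac
    if schedule == "exponential":
        return init * (final / init) ** frac   # requires init > 0
    if schedule == "cosine":
        # Starts at init (frac = 0) and ends at final (frac = 1)
        return final + (init - final) * 0.5 * (1.0 + math.cos(math.pi * frac))
    raise ValueError(f"unknown schedule: {schedule}")

cosine_values = [curriculum_threshold_sketch(t, 100, 0.3, 3.0, "cosine") for t in (0, 50, 100)]
```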

prosemble.core.curriculum.apply_curriculum_to_loss(per_sample_losses, iteration, max_iter, init_percentile=0.3, schedule='linear', mode='soft')[source]

Full curriculum pipeline: compute threshold and weighted loss.

Convenience function that combines threshold scheduling and weight computation. Use inside _compute_loss to add curriculum learning to any model.

Parameters:
  • per_sample_losses (array of shape (n_samples,)) – Per-sample loss values.

  • iteration (int or array) – Current training step.

  • max_iter (int) – Total training steps.

  • init_percentile (float) – Initial fraction of samples to include (by loss quantile). Default: 0.3 (include the easiest 30% of samples initially).

  • schedule (str) – Threshold growth schedule. Default: ‘linear’.

  • mode (str) – Weighting mode. Default: ‘soft’.

Returns:

weighted_loss – Curriculum-weighted mean loss.

Return type:

scalar
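A quantile-based version of the pipeline is sketched below, with hard weights for readability (the documented default mode is ‘soft’). The included fraction grows linearly from init_percentile to 1, the threshold is the matching loss quantile, and the result is the weight-normalized mean. This is an assumed reading of "by loss quantile", and `apply_curriculum_sketch` is a hypothetical name.

```python
import jax.numpy as jnp

def apply_curriculum_sketch(per_sample_losses, iteration, max_iter, init_percentile=0.3):
    # Fraction of samples to include: init_percentile at the start, 1.0 at the end
    frac = init_percentile + (1.0 - init_percentile) * jnp.clip(iteration / max_iter, 0.0, 1.0)
    threshold = jnp.quantile(per_sample_losses, frac)
    w = (per_sample_losses <= threshold).astype(per_sample_losses.dtype)
    return jnp.sum(w * per_sample_losses) / jnp.maximum(jnp.sum(w), 1.0)

losses = jnp.arange(1.0, 11.0)                    # 1, 2, ..., 10
early = apply_curriculum_sketch(losses, 0, 100)   # easiest ~30% only: mean of 1, 2, 3
late = apply_curriculum_sketch(losses, 100, 100)  # all samples: plain mean 5.5
```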

Reject Option

Reject option with calibrated uncertainty for prototype classifiers.

Implements Chow’s optimal reject rule (1970) adapted for LVQ models. When the confidence in a prediction is below a threshold, the model abstains from classifying the sample (returns -1 or a reject label).

For LVQ, the confidence measure is the negative mu-ratio:

\[\text{confidence}(x) = \frac{d^-(x) - d^+(x)}{d^-(x) + d^+(x)}\]

This is the natural confidence measure from the GLVQ loss. Values near 0 mean the sample is on the decision boundary (uncertain), values near 1 mean high confidence.

class prosemble.core.reject.RejectOptionMixin[source]

Mixin providing reject option for any prototype-based classifier.

Adds predict_with_rejection() and confidence() methods to any supervised model that has fitted prototypes and prototype labels.

The reject decision is based on the relative margin:

\[\text{confidence}(x) = \frac{d^-(x) - d^+(x)}{d^-(x) + d^+(x)}\]

If confidence < threshold, the sample is rejected (label = -1).

confidence(X) → array[source]

Compute confidence scores for each sample.

predict_with_rejection(X, threshold) → array[source]

Predict with rejection. Returns -1 for rejected samples.

rejection_rate(X, threshold) → float[source]

Compute the fraction of samples that would be rejected.

optimal_threshold(X, y, cost_reject, cost_error) → float[source]

Find the optimal rejection threshold minimizing total risk.

confidence(X)[source]

Compute confidence scores based on the relative margin.

The confidence is the negative of the GLVQ mu-ratio: confidence(x) = (d_minus - d_plus) / (d_minus + d_plus)

Values range from -1 (maximally wrong) to +1 (maximally confident). Values near 0 indicate the sample is on the decision boundary.

Parameters:

X (array-like of shape (n_samples, n_features)) – Input data.

Returns:

scores – Confidence scores in [-1, 1].

Return type:

array of shape (n_samples,)

predict_with_rejection(X, threshold=0.0)[source]

Predict class labels with rejection option.

Samples with confidence below the threshold are assigned the reject label (-1) instead of a class prediction.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Input data.

  • threshold (float) – Rejection threshold. Samples with confidence < threshold are rejected. Default: 0.0 (reject uncertain samples). Reasonable range: [0.0, 0.5].

Returns:

labels – Predicted labels. Rejected samples have label -1.

Return type:

array of shape (n_samples,)
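Confidence and rejection together, for squared Euclidean distances, are sketched below. Because no true labels are available at prediction time, the sketch assumes d+ is the distance to the winning prototype and d- the distance to the nearest prototype of a different class. Under that assumption the scores are non-negative. It also needs at least two classes among the prototypes. The function name is hypothetical.

```python
import jax.numpy as jnp

def confidence_and_reject(X, prototypes, proto_labels, threshold=0.0):
    # Squared Euclidean distances, shape (n_samples, n_prototypes)
    D = jnp.sum((X[:, None, :] - prototypes[None, :, :]) ** 2, axis=-1)
    winner = jnp.argmin(D, axis=1)
    pred = proto_labels[winner]
    d_plus = jnp.min(D, axis=1)
    # Nearest prototype whose label differs from the predicted label
    other = proto_labels[None, :] != pred[:, None]
    d_minus = jnp.min(jnp.where(other, D, jnp.inf), axis=1)
    conf = (d_minus - d_plus) / (d_minus + d_plus)
    labels = jnp.where(conf < threshold, -1, pred)   # -1 marks a rejection
    return conf, labels

W = jnp.array([[0.0, 0.0], [4.0, 0.0]])
W_labels = jnp.array([0, 1])
X = jnp.array([[0.5, 0.0], [2.1, 0.0], [3.9, 0.0]])   # the middle point sits near the boundary
conf, labels = confidence_and_reject(X, W, W_labels, threshold=0.2)   # labels: [0, -1, 1]
```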

rejection_rate(X, threshold=0.0)[source]

Compute the fraction of samples that would be rejected.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Input data.

  • threshold (float) – Rejection threshold. Default: 0.0.

Returns:

rate – Fraction of rejected samples in [0, 1].

Return type:

float

accuracy_coverage_curve(X, y, n_thresholds=50)[source]

Compute accuracy-coverage curve (reject curve).

Shows the trade-off between classification accuracy and coverage (1 - rejection_rate) as the threshold varies.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Input data.

  • y (array-like of shape (n_samples,)) – True labels.

  • n_thresholds (int) – Number of threshold values to evaluate. Default: 50.

Returns:

  • thresholds (array of shape (n_thresholds,)) – Threshold values evaluated.

  • accuracies (array of shape (n_thresholds,)) – Accuracy on non-rejected samples at each threshold.

  • coverages (array of shape (n_thresholds,)) – Fraction of non-rejected samples at each threshold.

optimal_threshold(X, y, cost_reject=0.5, cost_error=1.0, n_thresholds=100)[source]

Find the optimal rejection threshold minimizing total risk.

Minimizes Chow’s risk:

Risk = cost_error * P(error | accepted) * coverage + cost_reject * (1 - coverage)

Parameters:
  • X (array-like of shape (n_samples, n_features))

  • y (array-like of shape (n_samples,)) – True labels.

  • cost_reject (float) – Cost of rejecting a sample (relative to cost of error). Default: 0.5 (rejection costs half of a misclassification).

  • cost_error (float) – Cost of a misclassification. Default: 1.0.

  • n_thresholds (int) – Number of thresholds to evaluate. Default: 100.

Returns:

optimal_threshold – The threshold that minimizes total risk.

Return type:

float
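The risk minimization can be sketched as a grid search over precomputed confidence scores and a correctness mask. Note that P(error | accepted) · coverage is simply the fraction of samples that are both accepted and wrong. In the toy data, rejecting the two low-confidence errors costs 0.5 · 2/8 = 0.125, which beats every other choice. The function name and inputs are hypothetical.

```python
import jax.numpy as jnp

def optimal_threshold_sketch(conf, correct, cost_reject=0.5, cost_error=1.0, n_thresholds=100):
    thresholds = jnp.linspace(jnp.min(conf), jnp.max(conf), n_thresholds)
    accepted = conf[None, :] >= thresholds[:, None]      # (n_thresholds, n_samples)
    coverage = jnp.mean(accepted, axis=1)
    # Fraction of samples that are accepted AND misclassified
    err_accepted = jnp.mean(accepted & ~correct[None, :], axis=1)
    risk = cost_error * err_accepted + cost_reject * (1.0 - coverage)
    return thresholds[jnp.argmin(risk)]

conf = jnp.array([0.05, 0.10, 0.15, 0.60, 0.70, 0.80, 0.90, 0.95])
correct = jnp.array([False, False, True, True, True, True, True, True])
t_star = optimal_threshold_sketch(conf, correct)   # lands in (0.10, 0.15]
```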

Geodesic Interpolation

Geodesic interpolation and boundary visualization for Riemannian models.

Computes geodesic paths between prototypes on Riemannian manifolds for decision boundary visualization and model interpretation.

On curved manifolds, the decision boundary (the locus where d(x, w_i) = d(x, w_j)) is not a hyperplane but a curved surface. Geodesic interpolation allows us to:

  1. Visualize the path between prototypes along the manifold.

  2. Find the approximate decision boundary point along geodesics.

  3. Compute prototype midpoints that respect the manifold geometry.

prosemble.core.geodesic.geodesic_interpolation(manifold, point_a, point_b, n_points=50)[source]

Compute a geodesic path between two points on a manifold.

Uses the exponential map to trace the shortest path (geodesic) between two manifold points:

\[\gamma(t) = \text{Exp}_{w_a}(t \cdot \text{Log}_{w_a}(w_b)), \quad t \in [0, 1]\]
Parameters:
  • manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance with exp_map and log_map.

  • point_a (array) – Starting point on the manifold.

  • point_b (array) – End point on the manifold.

  • n_points (int) – Number of interpolation points along the geodesic. Default: 50.

Returns:

path – Points along the geodesic from point_a to point_b. Each point lies on the manifold.

Return type:

array of shape (n_points, …)
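The interpolation formula needs only an exp map and a log map, so it can be written generically and vmapped over t. The sketch below pairs it with the SPD maps listed earlier. All names are hypothetical. From A = I to B = diag(4, 1) the geodesic is B^t, so the midpoint is diag(2, 1).

```python
import jax
import jax.numpy as jnp

def _sym_fn(S, fn):
    # Apply a scalar function to the eigenvalues of a symmetric matrix
    lam, V = jnp.linalg.eigh(S)
    return (V * fn(lam)) @ V.T

def spd_log(A, B):
    A_h, A_ih = _sym_fn(A, jnp.sqrt), _sym_fn(A, lambda l: 1.0 / jnp.sqrt(l))
    return A_h @ _sym_fn(A_ih @ B @ A_ih, jnp.log) @ A_h

def spd_exp(A, V):
    A_h, A_ih = _sym_fn(A, jnp.sqrt), _sym_fn(A, lambda l: 1.0 / jnp.sqrt(l))
    return A_h @ _sym_fn(A_ih @ V @ A_ih, jnp.exp) @ A_h

def geodesic_path(exp_map, log_map, a, b, n_points=50):
    v = log_map(a, b)                                  # tangent at a toward b
    ts = jnp.linspace(0.0, 1.0, n_points)
    return jax.vmap(lambda t: exp_map(a, t * v))(ts)   # gamma(t) = Exp_a(t Log_a(b))

A = jnp.eye(2)
B = jnp.diag(jnp.array([4.0, 1.0]))
path = geodesic_path(spd_exp, spd_log, A, B, n_points=5)   # t = 0, 0.25, ..., 1
```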

prosemble.core.geodesic.geodesic_midpoint(manifold, point_a, point_b)[source]

Compute the geodesic midpoint between two manifold points.

The midpoint is at t = 0.5 along the geodesic:

\[m = \text{Exp}_{w_a}(0.5 \cdot \text{Log}_{w_a}(w_b))\]
Parameters:
  • manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance.

  • point_a (array) – First manifold point.

  • point_b (array) – Second manifold point.

Returns:

midpoint – Geodesic midpoint, guaranteed to lie on the manifold.

Return type:

array

prosemble.core.geodesic.decision_boundary_point(manifold, proto_a, proto_b, n_search=100)[source]

Find the decision boundary point along the geodesic between prototypes.

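This function searches along the geodesic for the point where the distances to both prototypes are equal (d(x, w_a) = d(x, w_b)). On a minimizing geodesic, d(gamma(t), w_a) = t * d(w_a, w_b), so in exact arithmetic this point is the geodesic midpoint t = 0.5, although its position generally differs from the Euclidean midpoint of the two prototypes. The numerical search can deviate from t = 0.5 through grid resolution or when the log map does not return the minimizing geodesic.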

Parameters:
  • manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance.

  • proto_a (array) – First prototype (on manifold).

  • proto_b (array) – Second prototype (on manifold).

  • n_search (int) – Number of candidate points to evaluate. Default: 100.

Returns:

  • boundary_point (array) – Point on the geodesic where d(point, proto_a) = d(point, proto_b).

  • t_boundary (float) – Parameter value t in [0, 1] where the boundary lies; t ≈ 0.5 when the path is a minimizing geodesic, up to the resolution set by n_search.
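The search step can be sketched generically: evaluate both distances along a sampled path and take the point where they differ least. The example uses the Poincaré ball, where the geodesic from the origin to b is a Euclidean ray reparametrized so that hyperbolic arc length grows linearly in t. The search lands at t = 0.5, while the point's Euclidean norm is tanh(artanh(0.9) / 2) ≈ 0.63 rather than the Euclidean midpoint 0.45. All names are hypothetical.

```python
import jax
import jax.numpy as jnp

def boundary_on_path(path, ts, dist, proto_a, proto_b):
    # Path point where |d(p, a) - d(p, b)| is smallest
    d_a = jax.vmap(lambda p: dist(p, proto_a))(path)
    d_b = jax.vmap(lambda p: dist(p, proto_b))(path)
    i = jnp.argmin(jnp.abs(d_a - d_b))
    return path[i], ts[i]

def poincare_dist(x, y):
    num = 2.0 * jnp.sum((x - y) ** 2)
    den = (1.0 - jnp.sum(x ** 2)) * (1.0 - jnp.sum(y ** 2))
    return jnp.arccosh(1.0 + num / den)

proto_a = jnp.zeros(2)
proto_b = jnp.array([0.9, 0.0])
ts = jnp.linspace(0.0, 1.0, 101)
r_b = jnp.arctanh(jnp.linalg.norm(proto_b))
# Geodesic from the origin: gamma(t) = tanh(t * artanh(||b||)) * b / ||b||
path = jnp.tanh(ts * r_b)[:, None] * (proto_b / jnp.linalg.norm(proto_b))
point, t_star = boundary_on_path(path, ts, poincare_dist, proto_a, proto_b)
```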

prosemble.core.geodesic.prototype_geodesic_distances(manifold, prototypes, proto_labels)[source]

Compute pairwise geodesic distances between prototypes.

Parameters:
  • manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance.

  • prototypes (array of shape (n_prototypes, ...)) – Prototype points on the manifold (flattened).

  • proto_labels (array of shape (n_prototypes,)) – Class labels for prototypes.

Returns:

distances – Pairwise geodesic distance matrix.

Return type:

array of shape (n_prototypes, n_prototypes)
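A pairwise geodesic matrix can be built from any point-to-point distance with nested `vmap`. The sketch uses the Poincaré distance and ignores the labels, which only matter for selecting pairs. Names are hypothetical. The `maximum(., 1)` guard keeps `arccosh` defined when round-off pushes its argument just below 1 on the diagonal.

```python
import jax
import jax.numpy as jnp

def poincare_dist(x, y):
    num = 2.0 * jnp.sum((x - y) ** 2)
    den = (1.0 - jnp.sum(x ** 2)) * (1.0 - jnp.sum(y ** 2))
    return jnp.arccosh(jnp.maximum(1.0 + num / den, 1.0))

def pairwise_geodesic(dist, prototypes):
    # Entry (i, j) = dist(prototypes[i], prototypes[j])
    return jax.vmap(lambda p: jax.vmap(lambda q: dist(p, q))(prototypes))(prototypes)

W = jnp.array([[0.0, 0.0], [0.5, 0.0], [0.0, -0.3]])
D = pairwise_geodesic(poincare_dist, W)
```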

prosemble.core.geodesic.inter_class_geodesics(manifold, prototypes, proto_labels, n_points=50)[source]

Compute geodesic paths between all inter-class prototype pairs.

Useful for visualizing decision boundaries: the boundary between two classes lies somewhere along the geodesic connecting their closest prototypes.

Parameters:
  • manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance.

  • prototypes (array of shape (n_prototypes, ...)) – Prototype positions on manifold.

  • proto_labels (array of shape (n_prototypes,)) – Class labels.

  • n_points (int) – Points per geodesic path. Default: 50.

Returns:

geodesics – Each dict contains:

  • ‘path’: array of shape (n_points, …), the geodesic path

  • ‘proto_a_idx’: int, index of the first prototype

  • ‘proto_b_idx’: int, index of the second prototype

  • ‘class_a’: int, class of the first prototype

  • ‘class_b’: int, class of the second prototype

  • ‘boundary_t’: float, approximate boundary location

Return type:

list of dict