Core API Reference
Core modules provide the building blocks used by all models: distance functions, loss functions, initializers, and utilities.
Distance Functions
JAX-based distance functions for Prosemble.
This module provides GPU-accelerated, vectorized distance computations using JAX. All functions are JIT-compiled: the first call with a given input shape and dtype triggers compilation, and later calls with the same shapes reuse the compiled code.
Mathematical Background
Distance metrics are fundamental to prototype-based learning algorithms. This implementation focuses on:
1. Batch/matrix operations (no Python loops)
2. GPU compatibility
3. JIT compilation for speed
4. Numerical stability
Author: Prosemble Contributors
License: MIT
- prosemble.core.distance.euclidean_distance_matrix(X, Y)
Compute pairwise Euclidean distances between rows of X and Y.
Uses the expansion trick: \(\|x - y\|^2 = \|x\|^2 + \|y\|^2 - 2 x^T y\).
Formula: \(D_{ij} = \|X_i - Y_j\| = \sqrt{\sum_k (X_{ik} - Y_{jk})^2}\)
- Parameters:
X (array of shape (n, d)) – First set of points, one per row.
Y (array of shape (m, d)) – Second set of points, one per row.
- Returns:
D – Array of shape (n, m) where \(D_{ij} = \|X_i - Y_j\|\)
- Return type:
array of shape (n, m)
- Complexity:
Time: O(nmd), dominated by a single matrix multiplication
Space: O(nm) for the output matrix
Example
>>> import jax.numpy as jnp
>>> from prosemble.core.distance import euclidean_distance_matrix
>>> X = jnp.array([[0, 0], [1, 1], [2, 2]])
>>> Y = jnp.array([[0, 0], [3, 3]])
>>> D = euclidean_distance_matrix(X, Y)
>>> D.shape
(3, 2)
>>> D[0, 0]  # Distance from X[0] to Y[0]
0.0
>>> D[2, 1]  # Distance from X[2] to Y[1]
1.414...
Notes
Numerically stable: Uses maximum(D_sq, 0) to avoid sqrt of negatives
GPU-compatible: All operations are JAX primitives
JIT-compiled: First call compiles, subsequent calls are fast
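To make the expansion trick and the clamping note concrete, here is a minimal NumPy sketch (jax.numpy offers the same calls). The helpers squared_euclidean_matrix and euclidean_matrix are hypothetical illustrations, not Prosemble functions, and the library's JIT-compiled code may differ in detail.

```python
import numpy as np

def squared_euclidean_matrix(X, Y):
    """Pairwise squared Euclidean distances via the expansion trick."""
    X = np.asarray(X, dtype=np.float64)
    Y = np.asarray(Y, dtype=np.float64)
    x_sq = np.sum(X ** 2, axis=1)[:, None]   # (n, 1): ||x_i||^2
    y_sq = np.sum(Y ** 2, axis=1)[None, :]   # (1, m): ||y_j||^2
    cross = X @ Y.T                          # (n, m): x_i . y_j, one matrix multiplication
    # Round-off can push tiny squared distances below zero; clamp before any sqrt.
    return np.maximum(x_sq + y_sq - 2.0 * cross, 0.0)


def euclidean_matrix(X, Y):
    """Pairwise Euclidean distances: sqrt of the clamped squared distances."""
    return np.sqrt(squared_euclidean_matrix(X, Y))


X = np.array([[0, 0], [1, 1], [2, 2]])
Y = np.array([[0, 0], [3, 3]])
D = euclidean_matrix(X, Y)
# Cross-check against the direct (n, m, d) broadcasting definition.
D_direct = np.sqrt(np.sum((X[:, None, :] - Y[None, :, :]) ** 2, axis=-1))
print(D.shape, np.allclose(D, D_direct))  # (3, 2) True
```

The expansion form needs only one (n, d) x (d, m) product and never materializes the (n, m, d) difference tensor, which is why its space cost stays at O(nm).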
- prosemble.core.distance.squared_euclidean_distance_matrix(X, Y)
Compute pairwise squared Euclidean distances.
More efficient than euclidean_distance_matrix(X, Y)**2 because it avoids the sqrt operation entirely.
Formula: \(D^2_{ij} = \|X_i - Y_j\|^2 = \sum_k (X_{ik} - Y_{jk})^2\)
- Parameters:
X (array of shape (n, d)) – First set of points, one per row.
Y (array of shape (m, d)) – Second set of points, one per row.
- Returns:
D – Array of shape (n, m) where \(D^2_{ij} = \|X_i - Y_j\|^2\)
- Return type:
array of shape (n, m)
- Complexity:
Time: O(nmd)
Space: O(nm)
Example
>>> import jax.numpy as jnp
>>> from prosemble.core.distance import squared_euclidean_distance_matrix
>>> X = jnp.array([[0, 0], [1, 1]])
>>> Y = jnp.array([[0, 0], [2, 2]])
>>> D_sq = squared_euclidean_distance_matrix(X, Y)
>>> D_sq[1, 1]  # Squared distance from [1,1] to [2,2]
2.0
Notes
Preferred over euclidean when squared distances are sufficient
Many algorithms (FCM, PCM) use squared distances directly
More numerically stable than squaring euclidean distances
- prosemble.core.distance.manhattan_distance_matrix(X, Y)
Compute pairwise Manhattan (L1) distances.
Formula: \(D_{ij} = \|X_i - Y_j\|_1 = \sum_k |X_{ik} - Y_{jk}|\)
- Parameters:
X (array of shape (n, d)) – First set of points, one per row.
Y (array of shape (m, d)) – Second set of points, one per row.
- Returns:
D – Array of shape (n, m) where D[i,j] is the Manhattan distance between X[i] and Y[j]
- Return type:
array of shape (n, m)
- Complexity:
Time: O(nmd)
Space: O(nmd) for the intermediate broadcast array
Example
>>> import jax.numpy as jnp
>>> from prosemble.core.distance import manhattan_distance_matrix
>>> X = jnp.array([[0, 0], [1, 1]])
>>> Y = jnp.array([[0, 0], [2, 2]])
>>> D = manhattan_distance_matrix(X, Y)
>>> D[1, 1]  # Manhattan distance from [1,1] to [2,2]
2.0
- Implementation:
Uses broadcasting: X[:, None, :] - Y[None, :, :] creates an (n, m, d) array of differences, and the absolute differences are then summed along the feature dimension.
Notes
Also known as “taxicab” or “city block” distance
More robust to outliers than Euclidean distance
Natural for sparse/binary features
- prosemble.core.distance.lpnorm_distance_matrix(X, Y, p)
Compute pairwise L-p norm distances.
Formula: \(D_{ij} = \|X_i - Y_j\|_p = (\sum_k |X_{ik} - Y_{jk}|^p)^{1/p}\)
- Special Cases:
p = 1: Manhattan distance
p = 2: Euclidean distance
p = \(\infty\): Chebyshev distance (maximum absolute difference)
- Parameters:
X (array of shape (n, d)) – First set of points, one per row.
Y (array of shape (m, d)) – Second set of points, one per row.
p – Order of the norm; pass jnp.inf for the Chebyshev distance.
- Returns:
D – Array of shape (n, m) where D[i,j] is the L-p distance between X[i] and Y[j]
- Return type:
array of shape (n, m)
- Complexity:
Time: O(nmd)
Space: O(nmd)
Example
>>> import jax.numpy as jnp
>>> from prosemble.core.distance import lpnorm_distance_matrix
>>> X = jnp.array([[0, 0], [1, 1]])
>>> Y = jnp.array([[0, 0], [3, 4]])
>>> D = lpnorm_distance_matrix(X, Y, p=2)  # Euclidean
>>> D = lpnorm_distance_matrix(X, Y, p=1)  # Manhattan
>>> D = lpnorm_distance_matrix(X, Y, p=jnp.inf)  # Chebyshev
Notes
For p=1, use manhattan_distance_matrix for better performance
For p=2, use euclidean_distance_matrix for better performance
For p=inf, computes \(\max_k |X_{ik} - Y_{jk}|\)
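The broadcasting scheme described for manhattan_distance_matrix generalizes to any p. The NumPy sketch below illustrates that under the assumption that the library follows the same pattern; lp_distance_matrix is a hypothetical name, and its p = ∞ branch takes the maximum absolute difference, as the Chebyshev case requires.

```python
import numpy as np

def lp_distance_matrix(X, Y, p):
    """Pairwise L-p distances by broadcasting; p may be np.inf (Chebyshev)."""
    X = np.asarray(X, dtype=np.float64)
    Y = np.asarray(Y, dtype=np.float64)
    diff = np.abs(X[:, None, :] - Y[None, :, :])   # (n, m, d): the O(nmd) intermediate
    if np.isinf(p):
        return np.max(diff, axis=-1)                # largest coordinate gap
    return np.sum(diff ** p, axis=-1) ** (1.0 / p)


X = np.array([[0, 0], [1, 1]])
Y = np.array([[0, 0], [3, 4]])
for p in (1, 2, np.inf):
    # Distance from [0, 0] to [3, 4]: 7.0 (Manhattan), 5.0 (Euclidean), 4.0 (Chebyshev)
    print(p, lp_distance_matrix(X, Y, p)[0, 1])
```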
- prosemble.core.distance.omega_distance_matrix(X, Y, omega)
Compute distances in projected space using projection matrix Omega.
Formula: \(D_{ij} = \|X_i \Omega - Y_j \Omega\|^2\) where \(\Omega\) is a projection matrix that transforms the feature space.
- Parameters:
X (array of shape (n, d)) – Data points.
Y (array of shape (m, d)) – Prototypes/centroids.
omega (array of shape (d, k)) – Projection matrix mapping the d input features to k dimensions.
- Returns:
D – Array of shape (n, m) with squared distances in projected space
- Return type:
array of shape (n, m)
- Complexity:
Time: O(ndk + mdk + nmk) = O((n+m)dk + nmk)
Space: O(nk + mk + nm)
- Use Cases:
Dimensionality reduction for distance computation
Learning relevance of features (omega learned from data)
Mahalanobis-like distances: when \(\Omega \Omega^T = \Sigma^{-1}\) (for example \(\Omega = L\) with \(\Sigma^{-1} = LL^T\)), \(D_{ij} = (X_i - Y_j)\,\Sigma^{-1}(X_i - Y_j)^T\) is the squared Mahalanobis distance
Example
>>> import jax.numpy as jnp
>>> from prosemble.core.distance import omega_distance_matrix
>>> X = jnp.array([[1, 2, 3], [4, 5, 6]])
>>> Y = jnp.array([[0, 0, 0], [1, 1, 1]])
>>> omega = jnp.array([[1, 0], [0, 1], [0, 0]])  # Project to first 2 dims
>>> D = omega_distance_matrix(X, Y, omega)
>>> D.shape
(2, 2)
Notes
When omega is the (d, d) identity, reduces to squared Euclidean distance
When omega is learned, enables adaptive distance metrics
Used in GMLVQ (Generalized Matrix Learning Vector Quantization)
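The following NumPy sketch computes the omega-distance formula directly and checks two claims from this entry: an identity omega gives squared Euclidean distance, and an omega with ΩΩ^T = Σ^{-1} gives the squared Mahalanobis distance. omega_distance_matrix_sketch and the covariance Sigma are made up for illustration and are not part of Prosemble.

```python
import numpy as np

def omega_distance_matrix_sketch(X, Y, omega):
    """Squared distances between rows of X and Y after projecting both by omega (d, k)."""
    XP = np.asarray(X, dtype=np.float64) @ omega   # (n, k)
    YP = np.asarray(Y, dtype=np.float64) @ omega   # (m, k)
    diff = XP[:, None, :] - YP[None, :, :]          # (n, m, k)
    return np.sum(diff ** 2, axis=-1)


X = np.array([[1, 2, 3], [4, 5, 6]])
Y = np.array([[0, 0, 0], [1, 1, 1]])
omega = np.array([[1, 0], [0, 1], [0, 0]])          # keep the first two features
D = omega_distance_matrix_sketch(X, Y, omega)        # [[5, 1], [41, 25]]

# Identity omega: plain squared Euclidean distance.
D_sq = np.sum((X[:, None, :] - Y[None, :, :]) ** 2, axis=-1)
print(np.allclose(omega_distance_matrix_sketch(X, Y, np.eye(3)), D_sq))   # True

# Mahalanobis: choose omega with omega @ omega.T == inv(Sigma).
Sigma = np.array([[2.0, 0.5, 0.0], [0.5, 1.0, 0.2], [0.0, 0.2, 1.5]])
L = np.linalg.cholesky(np.linalg.inv(Sigma))         # L @ L.T == inv(Sigma)
diff = X[0] - Y[1]
print(np.isclose(omega_distance_matrix_sketch(X, Y, L)[0, 1],
                 diff @ np.linalg.inv(Sigma) @ diff))  # True
```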
- prosemble.core.distance.lomega_distance_matrix(X, Y, omegas)
Compute distances using multiple projection matrices (Local Omega).
Formula: \(D_{ij} = \|X_i \Omega_j - Y_j \Omega_j\|^2\) where \(\Omega_j\) is the projection matrix belonging to prototype \(Y_j\) (one matrix per prototype or cluster).
- Parameters:
X (Array | ndarray | bool | number) – Array of shape (n, d) - data points
Y (Array | ndarray | bool | number) – Array of shape (m, d) - prototypes/centroids
omegas (Array | ndarray | bool | number) – Array of shape (m, d, k) - m projection matrices of size (d, k). Each Y[j] has its own projection matrix omegas[j]
- Returns:
D – Array of shape (n, m); D[i, j] is the squared distance between X[i] and Y[j] in the projected space of omegas[j]
- Return type:
array of shape (n, m)
- Complexity:
Time: O(nmdk)
Space: O(nmk)
- Use Cases:
Local relevance learning (each prototype has its own metric)
Localized adaptive distance metrics, as in localized GMLVQ (LGMLVQ)
Cluster-specific feature weighting
Example
>>> import jax
>>> from prosemble.core.distance import lomega_distance_matrix
>>> n, m, d, k = 10, 3, 5, 2
>>> X = jax.random.normal(jax.random.PRNGKey(0), (n, d))
>>> Y = jax.random.normal(jax.random.PRNGKey(1), (m, d))
>>> omegas = jax.random.normal(jax.random.PRNGKey(2), (m, d, k))
>>> D = lomega_distance_matrix(X, Y, omegas)
>>> D.shape
(10, 3)
- Implementation:
Uses einsum for efficient tensor contraction:
1. Project X through each omega: X @ omegas[j] for all j
2. Extract the diagonal for the Y projections (each Y[j] uses omegas[j])
3. Compute squared differences and sum over the k projected dimensions
Notes
Generalizes omega_distance to local (per-prototype) metrics
More flexible but computationally expensive
Enables learning which features matter for each cluster
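The three einsum steps listed under Implementation can be sketched in NumPy as follows (jnp.einsum accepts the same subscripts). lomega_sketch is a hypothetical helper; the exact subscripts Prosemble uses are not documented here.

```python
import numpy as np

def lomega_sketch(X, Y, omegas):
    """D[i, j] = ||X[i] @ omegas[j] - Y[j] @ omegas[j]||^2, computed with einsum."""
    XP = np.einsum('nd,mdk->nmk', X, omegas)   # step 1: X through every prototype's omega
    YP = np.einsum('md,mdk->mk', Y, omegas)    # step 2: "diagonal", Y[j] through omegas[j] only
    diff = XP - YP[None, :, :]                 # (n, m, k)
    return np.sum(diff ** 2, axis=-1)          # step 3: squared differences summed over k


rng = np.random.default_rng(0)
n, m, d, k = 10, 3, 5, 2
X = rng.normal(size=(n, d))
Y = rng.normal(size=(m, d))
omegas = rng.normal(size=(m, d, k))
D = lomega_sketch(X, Y, omegas)
print(D.shape)  # (10, 3)

# Same value from an explicit computation for one (i, j) pair.
i, j = 4, 2
print(np.isclose(D[i, j], np.sum(((X[i] - Y[j]) @ omegas[j]) ** 2)))  # True
```

The (n, m, k) intermediate XP is what gives the O(nmk) space cost listed above.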
- prosemble.core.distance.tangent_distance_matrix(X, Y, omegas)
Compute pairwise localized tangent distances.
Each prototype j has a subspace basis \(\Omega_j\) of shape (d, s) with orthonormal columns. The tangent distance projects out the subspace directions:
\[d(x, w_j) = \|(I - \Omega_j \Omega_j^T)(x - w_j)\|^2\]
- Parameters:
X (array of shape (n, d)) – Data points.
Y (array of shape (m, d)) – Prototypes.
omegas (array of shape (m, d, s)) – Orthogonal subspace bases per prototype, where s is the subspace dimension.
- Returns:
D – Squared tangent distances.
- Return type:
array of shape (n, m)
Notes
Based on Saralajew, S., & Villmann, T. (2016). Adaptive tangent distances in generalized learning vector quantization.
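A NumPy sketch of the tangent distance, assuming each omegas[j] has orthonormal columns (built here from QR factorizations). tangent_distance_sketch is a hypothetical name; the check at the end shows that moving a point away from a prototype only along that prototype's subspace leaves the distance at zero.

```python
import numpy as np

def tangent_distance_sketch(X, Y, omegas):
    """||(I - O_j O_j^T)(x - w_j)||^2 for every data point x and prototype w_j."""
    diff = X[:, None, :] - Y[None, :, :]                  # (n, m, d)
    coeffs = np.einsum('nmd,mds->nms', diff, omegas)      # O_j^T (x - w_j)
    in_span = np.einsum('nms,mds->nmd', coeffs, omegas)   # O_j O_j^T (x - w_j)
    return np.sum((diff - in_span) ** 2, axis=-1)         # keep only the orthogonal residual


rng = np.random.default_rng(1)
n, m, d, s = 6, 2, 4, 2
X = rng.normal(size=(n, d))
Y = rng.normal(size=(m, d))
# Orthonormal bases: the Q factor of a reduced QR of a random (d, s) matrix.
omegas = np.stack([np.linalg.qr(rng.normal(size=(d, s)))[0] for _ in range(m)])
D = tangent_distance_sketch(X, Y, omegas)

# Moving away from prototype 0 only along its own subspace costs nothing.
x_tangent = (Y[0] + omegas[0] @ np.array([0.7, -1.3]))[None, :]
print(np.isclose(tangent_distance_sketch(x_tangent, Y, omegas)[0, 0], 0.0))  # True
```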
- prosemble.core.distance.wasserstein2_distance_matrix(X, means, log_variances)
Compute pairwise squared 2-Wasserstein distances from points to Gaussian prototypes.
Each prototype is a diagonal Gaussian \(\mathcal{N}(\mu_k, \text{diag}(\sigma_k^2))\). Each input point \(x\) is treated as a Dirac delta distribution \(\delta_x\).
The squared 2-Wasserstein distance from a point to a diagonal Gaussian is:
\[W_2^2(\delta_x, \mathcal{N}(\mu_k, \text{diag}(\sigma_k^2))) = \sum_j (x_j - \mu_{kj})^2 + \sum_j \sigma_{kj}^2\]
This decomposes into the squared Euclidean distance from the point to the mean, plus the total variance (trace of the covariance). The variance term is a constant offset per prototype, so prototypes with smaller variance are effectively “more certain”: they yield smaller distances and win the competition for nearby points more often.
- Parameters:
X (array of shape (n, d)) – Data points.
means (array of shape (p, d)) – Prototype mean vectors.
log_variances (array of shape (p, d)) – Log of prototype variances (exponentiated internally, which keeps the variances positive).
- Returns:
D – Squared 2-Wasserstein distances.
- Return type:
array of shape (n, p)
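As a sketch of the decomposition above (NumPy, hypothetical helper name), the squared 2-Wasserstein distance is the squared Euclidean distance to each mean plus that prototype's total variance. The toy data show a point equidistant from two means being assigned to the prototype with the smaller variance.

```python
import numpy as np

def w2_point_to_diag_gaussian(X, means, log_variances):
    """Squared W2 from Dirac deltas at the rows of X to diagonal Gaussian prototypes."""
    variances = np.exp(log_variances)                                      # (p, d), positive
    sq_euclid = np.sum((X[:, None, :] - means[None, :, :]) ** 2, axis=-1)  # (n, p)
    total_var = np.sum(variances, axis=1)                                  # (p,): trace of each covariance
    return sq_euclid + total_var[None, :]


X = np.array([[0.0, 0.0], [1.0, 1.0]])
means = np.array([[0.0, 0.0], [1.0, 1.0]])
log_variances = np.log(np.array([[0.5, 0.5], [0.1, 0.1]]))
D = w2_point_to_diag_gaussian(X, means, log_variances)   # approx [[1.0, 2.2], [3.0, 0.2]]

# A point equidistant from both means goes to the lower-variance prototype.
midpoint = np.array([[0.5, 0.5]])
print(np.argmin(w2_point_to_diag_gaussian(midpoint, means, log_variances), axis=1))  # [1]
```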
- prosemble.core.distance.wasserstein2_omega_distance_matrix(X, means, log_variances, omega)
Squared 2-Wasserstein distance with global metric adaptation.
Projects data and means through \(\Omega\) before computing the Euclidean component, while variances contribute directly (with \(x\) and \(\mu_k\) as row vectors, matching omega_distance_matrix):
\[W_2^2(x, k) = \|(x - \mu_k)\,\Omega\|^2 + \sum_j \sigma_{kj}^2\]
- Parameters:
X (array of shape (n, d)) – Data points.
means (array of shape (p, d)) – Prototype mean vectors.
log_variances (array of shape (p, d)) – Log of prototype variances.
omega (array of shape (d, l)) – Global projection matrix.
- Returns:
D – Squared 2-Wasserstein distances in projected space.
- Return type:
array of shape (n, p)
- prosemble.core.distance.wasserstein2_relevance_distance_matrix(X, means, log_variances, relevances)
Squared 2-Wasserstein distance with feature relevance weighting.
Applies per-feature relevance weights \(\lambda_j\) to the Euclidean component:
\[W_2^2(x, k) = \sum_j \lambda_j (x_j - \mu_{kj})^2 + \sum_j \sigma_{kj}^2\]
where \(\lambda_j = \text{softmax}(r)_j\) ensures non-negative weights that sum to 1.
- Parameters:
X (array of shape (n, d)) – Data points.
means (array of shape (p, d)) – Prototype mean vectors.
log_variances (array of shape (p, d)) – Log of prototype variances.
relevances (array of shape (d,)) – Raw relevance logits (softmax applied internally).
- Returns:
D – Relevance-weighted squared 2-Wasserstein distances.
- Return type:
array of shape (n, p)
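The two adaptive variants can be sketched the same way. This NumPy version assumes the row-vector convention (x − μ_k)Ω used by omega_distance_matrix and a plain softmax over the relevance logits; the function names are hypothetical. Setting Ω = diag(√λ) makes the projected form coincide with the relevance-weighted one, which the last lines check.

```python
import numpy as np

def softmax(r):
    z = np.exp(r - np.max(r))        # shift by the max for numerical stability
    return z / np.sum(z)


def total_variance(log_variances):
    return np.sum(np.exp(log_variances), axis=1)[None, :]        # (1, p)


def w2_relevance(X, means, log_variances, relevances):
    """sum_j lambda_j (x_j - mu_kj)^2 + sum_j sigma_kj^2 with lambda = softmax(relevances)."""
    lam = softmax(relevances)                                     # (d,), non-negative, sums to 1
    diff_sq = (X[:, None, :] - means[None, :, :]) ** 2            # (n, p, d)
    return np.sum(lam * diff_sq, axis=-1) + total_variance(log_variances)


def w2_omega(X, means, log_variances, omega):
    """||(x - mu_k) omega||^2 + sum_j sigma_kj^2 with omega of shape (d, l)."""
    proj = (X[:, None, :] - means[None, :, :]) @ omega            # (n, p, l)
    return np.sum(proj ** 2, axis=-1) + total_variance(log_variances)


rng = np.random.default_rng(2)
X, means = rng.normal(size=(5, 3)), rng.normal(size=(2, 3))
log_var, r = rng.normal(size=(2, 3)), rng.normal(size=3)
# omega = diag(sqrt(lambda)) reproduces the relevance-weighted variant exactly.
same = np.allclose(w2_relevance(X, means, log_var, r),
                   w2_omega(X, means, log_var, np.diag(np.sqrt(softmax(r)))))
print(same)  # True
```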
- prosemble.core.distance.gaussian_kernel_matrix(X, Y, sigma)
Compute Gaussian (RBF) kernel matrix.
Formula: \(K_{ij} = \exp(-\|X_i - Y_j\|^2 / (2\sigma^2))\)
The Gaussian kernel corresponds to a feature map into an infinite-dimensional Hilbert space, enabling non-linear clustering and classification.
- Parameters:
X (array of shape (n, d))
Y (array of shape (m, d))
sigma (float) – Bandwidth parameter (\(\sigma > 0\)).
- Returns:
K – Array of shape (n, m) where \(K_{ij} \in [0, 1]\): \(K_{ij} = 1\) when \(X_i = Y_j\), and \(K_{ij} \to 0\) as \(\|X_i - Y_j\| \to \infty\).
- Return type:
array of shape (n, m)
- Complexity:
Time: O(nmd)
Space: O(nm)
- Properties:
K is positive semi-definite (valid kernel)
K is symmetric if X = Y
K[i,i] = 1 when X = Y (self-similarity)
- Example:
>>> import jax.numpy as jnp
>>> from prosemble.core.distance import gaussian_kernel_matrix
>>> X = jnp.array([[0, 0], [1, 1]])
>>> Y = jnp.array([[0, 0], [2, 2]])
>>> K = gaussian_kernel_matrix(X, Y, sigma=1.0)
>>> K[0, 0]  # X[0] and Y[0] coincide
1.0
>>> K[0, 1] < K[0, 0]  # Decreases with distance
True
- Kernel Trick:
For a feature map \(\phi\) into an infinite-dimensional Hilbert space, \(K(x, y) = \langle \phi(x), \phi(y) \rangle\). The kernel distance is
\(\|\phi(x) - \phi(y)\|^2 = K(x,x) - 2K(x,y) + K(y,y) = 2 - 2K(x,y)\) for a normalized kernel (\(K(x,x) = 1\), which holds for the Gaussian kernel).
- Use Cases:
Kernel Fuzzy C-Means (KFCM)
Kernel Possibilistic C-Means (KPCM)
Support Vector Machines (SVM)
Gaussian Processes
- Hyperparameter Tuning:
Small \(\sigma\): Tight clusters, high sensitivity to noise
Large \(\sigma\): Smooth clusters, may underfit
Rule of thumb: \(\sigma \approx \text{median}(\text{pairwise distances}) / \sqrt{2 \cdot n_\text{clusters}}\)
- Notes:
sigma is bandwidth, NOT variance (variance = \(\sigma^2\))
For numerical stability, squared distances are clamped to be non-negative with maximum() before the exponential is applied
JIT-compiled for GPU acceleration
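A minimal NumPy sketch of the Gaussian kernel matrix with the clamping described in the Notes, followed by the normalized-kernel distance 2 − 2K from the Kernel Trick paragraph. gaussian_kernel_matrix_sketch is a hypothetical name, not Prosemble's implementation.

```python
import numpy as np

def gaussian_kernel_matrix_sketch(X, Y, sigma):
    """K[i, j] = exp(-||X[i] - Y[j]||^2 / (2 sigma^2))."""
    X = np.asarray(X, dtype=np.float64)
    Y = np.asarray(Y, dtype=np.float64)
    d_sq = np.sum(X ** 2, axis=1)[:, None] + np.sum(Y ** 2, axis=1)[None, :] - 2.0 * X @ Y.T
    d_sq = np.maximum(d_sq, 0.0)                  # clamp round-off negatives before exp
    return np.exp(-d_sq / (2.0 * sigma ** 2))


X = np.array([[0, 0], [1, 1]])
Y = np.array([[0, 0], [2, 2]])
K = gaussian_kernel_matrix_sketch(X, Y, sigma=1.0)
print(K[0, 0], np.isclose(K[0, 1], np.exp(-4.0)))   # 1.0 True  (||[0,0] - [2,2]||^2 = 8)

# Kernel-trick distance for a normalized kernel (K(x, x) = 1): 2 - 2K.
kernel_dist_sq = 2.0 - 2.0 * K
```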
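- prosemble.core.distance.polynomial_kernel_matrix(X, Y, degree=3, coef0=1.0)
Compute polynomial kernel matrix.
Formula: \(K_{ij} = (X_i^T Y_j + c)^d\) where d is degree and c is coef0.
- Parameters:
X (array of shape (n, n_features))
Y (array of shape (m, n_features))
degree (int) – Exponent d of the kernel (default 3).
coef0 (float) – Constant c added to the dot product (default 1.0).
- Returns:
K – Array of shape (n, m) with polynomial kernel values
- Return type:
array of shape (n, m)
Example
>>> import jax.numpy as jnp
>>> from prosemble.core.distance import polynomial_kernel_matrix
>>> X = jnp.array([[1, 2], [3, 4]])
>>> Y = jnp.array([[1, 0], [0, 1]])
>>> K = polynomial_kernel_matrix(X, Y, degree=2, coef0=1.0)
Notes
degree=1, coef0=0: Linear kernel (dot product)
Higher degree: More complex decision boundaries
coef0: Larger values give lower-order terms more weight relative to higher-order ones, since \((x^T y + c)^d = \sum_{i=0}^{d} \binom{d}{i} c^{d-i} (x^T y)^i\)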
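Working the example above by hand: the dot products of X's rows with the unit vectors are 1, 2, 3 and 4, so with coef0=1.0 and degree=2 the kernel values are 4, 9, 16 and 25. The NumPy sketch below (hypothetical helper name) reproduces that and the linear-kernel special case from the Notes.

```python
import numpy as np

def polynomial_kernel_matrix_sketch(X, Y, degree=3, coef0=1.0):
    """K[i, j] = (X[i] . Y[j] + coef0) ** degree."""
    X = np.asarray(X, dtype=np.float64)
    Y = np.asarray(Y, dtype=np.float64)
    return (X @ Y.T + coef0) ** degree


X = np.array([[1, 2], [3, 4]])
Y = np.array([[1, 0], [0, 1]])
K = polynomial_kernel_matrix_sketch(X, Y, degree=2, coef0=1.0)
print(K)  # rows (4, 9) and (16, 25): dot products 1, 2, 3, 4, plus 1, squared

# degree=1, coef0=0 reduces to the linear kernel (plain dot products).
print(np.allclose(polynomial_kernel_matrix_sketch(X, Y, degree=1, coef0=0.0), X @ Y.T))  # True
```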
- prosemble.core.distance.euclidean_distance(x, y)
Euclidean distance between two vectors.
- Parameters:
x (array of shape (d,))
y (array of shape (d,))
- Returns:
Scalar distance
- Return type:
scalar
Example
>>> import jax.numpy as jnp
>>> from prosemble.core.distance import euclidean_distance
>>> x = jnp.array([0, 0, 0])
>>> y = jnp.array([1, 1, 1])
>>> d = euclidean_distance(x, y)
>>> d
Array(1.732..., dtype=float32)
- prosemble.core.distance.squared_euclidean_distance(x, y)
Squared Euclidean distance between two vectors.
- prosemble.core.distance.manhattan_distance(x, y)
Manhattan (L1) distance between two vectors.
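- prosemble.core.distance.omega_distance(x, y, omega)
Omega (projection-based) distance between two vectors.
Computes \(\|\text{diff} \cdot \Omega\|^2\) where \(\text{diff} = x - y\).
- Parameters:
x (array of shape (d,))
y (array of shape (d,))
omega (array of shape (d, k)) – Projection matrix.
- Returns:
Scalar squared distance in projected space
- Return type:
scalar
- prosemble.core.distance.lomega_distance(X, Y, omegas)
Local omega distance with per-prototype projection matrices.
- Parameters:
X (array of shape (n, d)) – Data points.
Y (array of shape (m, d)) – Prototypes.
omegas (array of shape (m, d, k)) – One projection matrix per prototype.
- Returns:
Distance matrix of shape (n, m)
- Return type:
array of shape (n, m)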
- prosemble.core.distance.estimate_sigma(X, percentile=50.0)
Estimate sigma for Gaussian kernel using pairwise distances.
Strategy: Use the median (or another percentile) of the pairwise distances.
- Parameters:
X (array of shape (n, d)) – Data points.
percentile (float) – Percentile of the pairwise distances to use (50.0 gives the median).
- Returns:
sigma – Estimated bandwidth parameter
- Return type:
scalar
Example
>>> import jax
>>> from prosemble.core.distance import estimate_sigma
>>> X = jax.random.normal(jax.random.PRNGKey(0), (100, 10))
>>> sigma = estimate_sigma(X, percentile=50)
Notes
Heuristic: \(\sigma = \text{median\_distance} / \sqrt{2 \cdot n_\text{clusters}}\)
For large datasets, subsample X first to avoid the O(n²) pairwise-distance computation
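One way to implement the percentile strategy, sketched in NumPy under stated assumptions: self-distances are excluded, each unordered pair is counted once, and an optional row subsample (the hypothetical max_samples argument) bounds the O(n²) cost. Prosemble's estimate_sigma may make different choices on all three points.

```python
import numpy as np

def estimate_sigma_sketch(X, percentile=50.0, max_samples=None, seed=0):
    """Percentile of the pairwise Euclidean distances between distinct rows of X."""
    X = np.asarray(X, dtype=np.float64)
    if max_samples is not None and X.shape[0] > max_samples:
        # Subsample rows so the pairwise step costs O(max_samples^2) instead of O(n^2).
        idx = np.random.default_rng(seed).choice(X.shape[0], size=max_samples, replace=False)
        X = X[idx]
    d_sq = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    upper = np.triu_indices(X.shape[0], k=1)       # each unordered pair once, no zero self-distances
    return float(np.percentile(np.sqrt(d_sq[upper]), percentile))


print(estimate_sigma_sketch([[0.0], [1.0], [3.0]]))  # 2.0, the median of {1, 3, 2}
```

Excluding the diagonal matters: with n points there are n zero self-distances, which would otherwise pull a low percentile toward zero.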
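- prosemble.core.distance.safe_divide(numerator, denominator, epsilon=1e-10)
Safe division avoiding division by zero.
- Parameters:
numerator (array)
denominator (array)
epsilon (float) – Small constant added to the denominator (default 1e-10).
- Returns:
numerator / (denominator + epsilon)
- Return type:
array
Example
>>> import jax.numpy as jnp
>>> from prosemble.core.distance import safe_divide
>>> x = jnp.array([1.0, 2.0, 3.0])
>>> y = jnp.array([2.0, 0.0, 1.0])
>>> safe_divide(x, y)  # 2.0 / 1e-10 = 2e+10 instead of inf
Array([0.5, 2e+10, 3.0], dtype=float32)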
- prosemble.core.distance.batch_squared_euclidean(X, Y)
Compute pairwise squared Euclidean distances. Its documentation (formula, parameters, complexity, example, and notes) is identical to squared_euclidean_distance_matrix().
- prosemble.core.distance.batch_euclidean(X, Y)
Compute pairwise Euclidean distances between rows of X and Y. Its documentation (formula, parameters, complexity, example, and notes) is identical to euclidean_distance_matrix().
Similarities
Similarity functions for prototype-based learning.
Similarities are the dual of distances: higher values indicate closer/more similar points.
- prosemble.core.similarities.gaussian_similarity(distances_sq, variance=1.0)
Convert squared distances to Gaussian similarities.
\(s(d) = \exp(-d^2 / (2 \cdot \text{variance}))\), where \(d^2\) is the squared distance passed in distances_sq.
- Parameters:
distances_sq (array) – Squared distances.
variance (float) – Variance (sigma^2) of the Gaussian.
- Returns:
Similarity values in (0, 1].
- Return type:
array
- prosemble.core.similarities.cosine_similarity_matrix(X, Y)
Pairwise cosine similarity between rows of X and Y.
\(\cos(x, y) = \frac{x \cdot y}{\|x\| \, \|y\|}\)
- Parameters:
X (array of shape (n, d))
Y (array of shape (m, d))
- Returns:
Cosine similarities in [-1, 1].
- Return type:
array of shape (n, m)
- prosemble.core.similarities.euclidean_similarity(X, Y, variance=1.0)
Pairwise Euclidean similarity (Gaussian of Euclidean distance).
- Parameters:
X (array of shape (n, d))
Y (array of shape (m, d))
variance (float) – Variance of the Gaussian kernel.
- Returns:
Similarity values in (0, 1].
- Return type:
array of shape (n, m)
- prosemble.core.similarities.rank_scaled_gaussian(distances, lambd=1.0)
Rank-scaled Gaussian similarity.
Combines distance magnitude with rank ordering: closer prototypes (lower rank) receive a stronger signal, while farther ones are exponentially suppressed.
\[s(d, r) = \exp(-\exp(-r / \lambda) \cdot d)\]
where r is the rank of each distance (0 = closest).
- Parameters:
distances (array of shape (n, m)) – Distance matrix (non-negative).
lambd (float) – Rank decay parameter. Larger values give more uniform weighting across ranks; smaller values concentrate on nearest neighbours.
- Returns:
Rank-scaled similarity values in (0, 1].
- Return type:
array of shape (n, m)
Notes
Used in Probabilistic LVQ (PLVQ) as a conditional distribution P(x|prototype).
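The rank r can be obtained with a double argsort, as in this NumPy sketch (hypothetical helper name). Ties are broken by column order, which is an assumption; Prosemble's tie handling is not documented here.

```python
import numpy as np

def rank_scaled_gaussian_sketch(distances, lambd=1.0):
    """exp(-exp(-r / lambd) * d), where r is the row-wise rank of d (0 = closest)."""
    ranks = np.argsort(np.argsort(distances, axis=1), axis=1)   # double argsort gives ranks
    return np.exp(-np.exp(-ranks / lambd) * distances)


D = np.array([[0.5, 2.0, 1.0]])
print(np.argsort(np.argsort(D, axis=1), axis=1))   # [[0 2 1]]
S = rank_scaled_gaussian_sketch(D, lambd=1.0)
```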
Kernel Functions
JAX-based kernel functions for kernel clustering methods.
This module provides GPU-accelerated kernel computations using JAX.
- prosemble.core.kernel.gaussian_kernel(x, y, sigma)
Compute Gaussian (RBF) kernel between two vectors.
\(K(x, y) = \exp(-\|x - y\|^2 / (2\sigma^2))\)
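- prosemble.core.kernel.batch_gaussian_kernel(X, Y, sigma)
Compute Gaussian kernel between two sets of vectors.
- Parameters:
X (array of shape (n_samples, n_features))
Y (array of shape (m_samples, n_features))
sigma (float) – Kernel bandwidth.
- Returns:
Kernel matrix of shape (n_samples, m_samples), with K[i, j] = K(X[i], Y[j]).
- Return type:
array of shape (n_samples, m_samples)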
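- prosemble.core.kernel.kernel_distance_squared(X, Y, sigma)
Compute squared distance in feature space.
For the Gaussian kernel: \(\|\phi(x) - \phi(y)\|^2 = K(x,x) + K(y,y) - 2K(x,y) = 2(1 - K(x,y))\), since \(K(x,x) = 1\).
- Parameters:
X (array of shape (n_samples, n_features))
Y (array of shape (m_samples, n_features))
sigma (float) – Kernel bandwidth.
- Returns:
Squared distances in feature space, shape (n_samples, m_samples)
- Return type:
array of shape (n_samples, m_samples)
- prosemble.core.kernel.kernel_distance(X, Y, sigma)
Compute distance in feature space.
- Parameters:
X (array of shape (n_samples, n_features))
Y (array of shape (m_samples, n_features))
sigma (float) – Kernel bandwidth.
- Returns:
Distances in feature space, shape (n_samples, m_samples)
- Return type:
array of shape (n_samples, m_samples)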
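- prosemble.core.kernel.kernel_distance_squared_per_proto(X, W, sigmas)
Squared kernel distance with per-prototype bandwidth.
\[d_\kappa^2(x, w_k) = 2\left(1 - \exp\left(-\frac{\|x - w_k\|^2}{2\sigma_k^2}\right)\right)\]
Each prototype \(w_k\) has its own bandwidth \(\sigma_k\).
- Parameters:
X (array of shape (n_samples, n_features)) – Data matrix.
W (array of shape (n_prototypes, n_features)) – Prototype matrix.
sigmas (array of shape (n_prototypes,)) – Per-prototype bandwidths.
- Returns:
Squared distances in feature space, shape (n_samples, n_prototypes).
- Return type:
array of shape (n_samples, n_prototypes)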
References
[1] Villmann, T., Haase, S., & Kaden, M. (2015). Kernelized vector quantization in gradient-descent learning. Neurocomputing.
- prosemble.core.kernel.kernel_distance_squared_relevance(X, W, sigmas, relevances)
Squared kernel distance with relevance weighting and per-prototype bandwidth.
\[d_\kappa^2(x, w_k) = 2\left(1 - \exp\left(-\frac{\sum_j \lambda_j (x_j - w_{kj})^2}{2\sigma_k^2}\right)\right)\]
- Parameters:
X (Array | ndarray | bool | number) – Data matrix, shape (n_samples, n_features).
W (Array | ndarray | bool | number) – Prototype matrix, shape (n_prototypes, n_features).
sigmas (Array | ndarray | bool | number) – Per-prototype bandwidths, shape (n_prototypes,).
relevances (Array | ndarray | bool | number) – Normalized relevance weights \(\lambda_j\), shape (n_features,).
- Returns:
Squared distances in feature space, shape (n_samples, n_prototypes).
- Return type:
array of shape (n_samples, n_prototypes)
References
[1] Villmann, T., Haase, S., & Kaden, M. (2015). Kernelized vector quantization in gradient-descent learning. Neurocomputing.
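Both per-prototype kernel distances fit in one NumPy sketch: the relevance version weights each squared feature difference by λ_j, and passing weights of 1 recovers the unweighted per_proto formula. The helper names are hypothetical; the bound d_κ² ≤ 2 follows from the exponential lying in (0, 1].

```python
import numpy as np

def kernel_distance_sq_relevance(X, W, sigmas, relevances):
    """2 * (1 - exp(-sum_j lambda_j (x_j - w_kj)^2 / (2 sigma_k^2)))."""
    diff_sq = (X[:, None, :] - W[None, :, :]) ** 2           # (n, p, d)
    weighted = np.sum(relevances * diff_sq, axis=-1)         # (n, p)
    return 2.0 * (1.0 - np.exp(-weighted / (2.0 * sigmas[None, :] ** 2)))


def kernel_distance_sq_per_proto(X, W, sigmas):
    """Per-prototype bandwidth only: weights of 1 give the plain ||x - w_k||^2."""
    return kernel_distance_sq_relevance(X, W, sigmas, np.ones(X.shape[1]))


rng = np.random.default_rng(4)
X, W = rng.normal(size=(6, 3)), rng.normal(size=(2, 3))
sigmas = np.array([0.5, 2.0])                 # prototype 1 has a wider kernel
D = kernel_distance_sq_per_proto(X, W, sigmas)
print(D.min() >= 0.0, D.max() <= 2.0)          # True True: the distance is bounded by 2
```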
- prosemble.core.kernel.exponential_kernel_distance_squared(X, W, omega_hat)
Squared distance in exponential kernel feature space.
Uses the exponential kernel \(\kappa_{\exp}(v, w, \hat\Lambda) = \exp(v^T \hat\Lambda w)\) where \(\hat\Lambda = \hat\Omega \hat\Omega^T\).
\[d_\kappa^2(x, w) = \exp(x^T \hat\Lambda x) + \exp(w^T \hat\Lambda w) - 2 \exp(x^T \hat\Lambda w)\]
Note: \(\kappa(v, v) \neq 1\) for the exponential kernel, so the full three-term formula is required (not the 2(1-K) simplification).
- Parameters:
X (Array | ndarray | bool | number) – Data matrix, shape (n_samples, n_features).
W (Array | ndarray | bool | number) – Prototype matrix, shape (n_prototypes, n_features).
omega_hat (Array | ndarray | bool | number) – Transformation matrix, shape (n_features, latent_dim). The kernel matrix is \(\hat\Lambda = \hat\Omega \hat\Omega^T\).
- Returns:
Squared distances in feature space, shape (n_samples, n_prototypes).
- Return type:
array of shape (n_samples, n_prototypes)
References
[1] Villmann, T., Haase, S., & Kaden, M. (2015). Kernelized vector quantization in gradient-descent learning. Neurocomputing.
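A NumPy sketch of the three-term formula, with hypothetical names. It projects rows through omega_hat once, so x^T Λ̂ w becomes a dot product of projected rows; the check at the end confirms the distance vanishes when x = w even though κ(v, v) ≠ 1.

```python
import numpy as np

def exponential_kernel_distance_sq(X, W, omega_hat):
    """exp(x^T L x) + exp(w^T L w) - 2 exp(x^T L w) with L = omega_hat @ omega_hat.T."""
    XP = X @ omega_hat                                 # rows x^T omega_hat, shape (n, l)
    WP = W @ omega_hat                                 # rows w^T omega_hat, shape (p, l)
    k_xx = np.exp(np.sum(XP ** 2, axis=1))[:, None]    # kappa(x, x): not 1 in general
    k_ww = np.exp(np.sum(WP ** 2, axis=1))[None, :]    # kappa(w, w)
    k_xw = np.exp(XP @ WP.T)                           # kappa(x, w)
    return k_xx + k_ww - 2.0 * k_xw


rng = np.random.default_rng(5)
X, W = 0.5 * rng.normal(size=(4, 3)), 0.5 * rng.normal(size=(2, 3))
omega_hat = 0.5 * rng.normal(size=(3, 2))
D = exponential_kernel_distance_sq(X, W, omega_hat)
print(np.allclose(exponential_kernel_distance_sq(W, W, omega_hat).diagonal(), 0.0))  # True
```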
Loss Functions
Loss functions for prototype-based learning.
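All loss functions are differentiable by jax.grad and JIT-compatible. They use jnp.where masking (not boolean indexing) for d+/d- extraction, because boolean indexing produces arrays whose shape depends on the data, which jax.jit cannot trace.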
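- prosemble.core.losses.glvq_loss(distances, target_labels, prototype_labels, margin=0.0)
Generalized LVQ loss.
\(\mu_i = \frac{d^+_i - d^-_i}{d^+_i + d^-_i}\), where \(d^+_i\) is the distance from sample i to the closest prototype with the correct label and \(d^-_i\) the distance to the closest prototype with a different label.
- Parameters:
distances (array of shape (n, p))
target_labels (array of shape (n,))
prototype_labels (array of shape (p,))
margin (float) – Margin added to mu before transfer.
- Returns:
Mean loss over samples.
- Return type:
scalar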
- prosemble.core.losses.glvq_loss_with_transfer(distances, target_labels, prototype_labels, transfer_fn=identity, margin=0.0, beta=10.0)
GLVQ loss with configurable transfer function.
\(\text{loss} = \text{mean}(f(\mu + \text{margin}, \beta))\), where f is transfer_fn (for example identity or sigmoid_beta from prosemble.core.activations) and \(\mu\) is defined as in glvq_loss().
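The masked extraction of d+ and d− mentioned at the top of this section can be sketched in NumPy as follows (jnp.where behaves the same way). glvq_loss_sketch, identity_transfer and sigmoid_transfer are hypothetical stand-ins mirroring the formulas in the Activations section; the sketch assumes every sample has at least one correct and one wrong prototype and that d+ + d− is never zero.

```python
import numpy as np

def identity_transfer(x, beta=0.0):
    return x                                           # beta ignored, as documented for identity


def sigmoid_transfer(x, beta=10.0):
    return 1.0 / (1.0 + np.exp(-beta * x))             # f(x) = 1 / (1 + exp(-beta x))


def glvq_loss_sketch(distances, target_labels, prototype_labels,
                     transfer_fn=identity_transfer, margin=0.0, beta=10.0):
    same = target_labels[:, None] == prototype_labels[None, :]      # (n, p) label-match mask
    # Masked minima (no boolean indexing), so every array keeps a static shape.
    d_plus = np.min(np.where(same, distances, np.inf), axis=1)      # closest correct prototype
    d_minus = np.min(np.where(same, np.inf, distances), axis=1)     # closest wrong prototype
    mu = (d_plus - d_minus) / (d_plus + d_minus)                    # in [-1, 1]; < 0 means correct
    return np.mean(transfer_fn(mu + margin, beta))


distances = np.array([[1.0, 3.0], [2.0, 1.0]])
prototype_labels = np.array([0, 1])
target_labels = np.array([0, 0])
print(glvq_loss_sketch(distances, target_labels, prototype_labels))  # mean of -0.5 and 1/3
```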
- prosemble.core.losses.lvq1_loss(distances, target_labels, prototype_labels)
LVQ1 loss: d+ when correct, -d- when wrong.
- Parameters:
distances (array of shape (n, p))
target_labels (array of shape (n,))
prototype_labels (array of shape (p,))
- Return type:
scalar
- prosemble.core.losses.lvq21_loss(distances, target_labels, prototype_labels)
LVQ2.1 loss: d+ - d- (unnormalized).
- Parameters:
distances (array of shape (n, p))
target_labels (array of shape (n,))
prototype_labels (array of shape (p,))
- Return type:
scalar
- prosemble.core.losses.nllr_loss(distances, target_labels, prototype_labels, sigma=1.0)
Negative Log-Likelihood Ratio loss (for SLVQ).
\(\text{loss} = -\log(P(\text{correct}) / P(\text{wrong}))\)
- Parameters:
distances (array of shape (n, p))
target_labels (array of shape (n,))
prototype_labels (array of shape (p,))
sigma (float) – Bandwidth of Gaussian mixture.
- Return type:
scalar
- prosemble.core.losses.rslvq_loss(distances, target_labels, prototype_labels, sigma=1.0)
Robust Soft LVQ loss (for RSLVQ).
\(\text{loss} = -\log(P(\text{correct}) / P(\text{all}))\)
- Parameters:
distances (array of shape (n, p))
target_labels (array of shape (n,))
prototype_labels (array of shape (p,))
sigma (float)
- Return type:
scalar
- prosemble.core.losses.ng_rslvq_loss(distances, target_labels, prototype_labels, sigma=1.0, gamma=1.0)
RSLVQ loss with Neural Gas rank-based neighborhood cooperation.
Combines Gaussian mixture prototype probabilities with NG rank weights to create a neighborhood-cooperative probabilistic assignment.
Gaussian: \(p(k|x) = \exp(-d_k / (2\sigma^2)) / \sum_j \exp(-d_j / (2\sigma^2))\)
NG weights: \(h_k = \exp(-\text{rank}_k / \gamma) / \sum_j \exp(-\text{rank}_j / \gamma)\)
Combined: \(w_k = p(k|x) \cdot h_k / \sum_j p(j|x) \cdot h_j\)
The loss is \(-\log(\sum_{k \in \text{correct}} w_k)\).
- Parameters:
distances (array of shape (n, p)) – Squared distances from samples to prototypes.
target_labels (array of shape (n,)) – True class labels for samples.
prototype_labels (array of shape (p,)) – Class labels assigned to prototypes.
sigma (float) – Bandwidth of Gaussian mixture.
gamma (float) – Neural Gas neighborhood range.
- Returns:
Mean negative log-likelihood with NG cooperation.
- Return type:
scalar
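The three probabilistic losses above share one ingredient, the Gaussian mixture logits −d_k/(2σ²), so they can be sketched together in NumPy. The helpers are hypothetical; the sketch assumes mean reduction over samples and that every sample has prototypes of both its own and another class. A masked log-sum-exp keeps small probabilities from underflowing, and adding −rank_k/γ to the logits implements the NG weighting, because the normalizers of p(k|x) and h_k cancel in w_k.

```python
import numpy as np

def masked_logsumexp(a, mask):
    """Row-wise log(sum(exp(a))) over entries where mask is True."""
    a = np.where(mask, a, -np.inf)                          # excluded entries contribute exp(-inf) = 0
    top = np.max(a, axis=1, keepdims=True)                  # assumes each row has a True entry
    return top[:, 0] + np.log(np.sum(np.exp(a - top), axis=1))


def _setup(distances, target_labels, prototype_labels, sigma):
    logits = -distances / (2.0 * sigma ** 2)                # log of the unnormalized mixture terms
    correct = target_labels[:, None] == prototype_labels[None, :]
    return logits, correct


def rslvq_loss_sketch(distances, target_labels, prototype_labels, sigma=1.0):
    """-log(P(correct) / P(all)), averaged over samples."""
    logits, correct = _setup(distances, target_labels, prototype_labels, sigma)
    everything = np.ones_like(correct)
    return np.mean(masked_logsumexp(logits, everything) - masked_logsumexp(logits, correct))


def nllr_loss_sketch(distances, target_labels, prototype_labels, sigma=1.0):
    """-log(P(correct) / P(wrong)), averaged over samples."""
    logits, correct = _setup(distances, target_labels, prototype_labels, sigma)
    return np.mean(masked_logsumexp(logits, ~correct) - masked_logsumexp(logits, correct))


def ng_rslvq_loss_sketch(distances, target_labels, prototype_labels, sigma=1.0, gamma=1.0):
    """-log(sum of w_k over correct k), with w_k proportional to p(k|x) * h_k."""
    logits, correct = _setup(distances, target_labels, prototype_labels, sigma)
    ranks = np.argsort(np.argsort(distances, axis=1), axis=1)   # 0 = closest prototype
    logits = logits - ranks / gamma                              # log(p(k|x) h_k) up to row constants
    everything = np.ones_like(correct)
    return np.mean(masked_logsumexp(logits, everything) - masked_logsumexp(logits, correct))


distances = np.array([[0.2, 1.5, 0.9, 2.0], [1.0, 0.1, 0.4, 3.0]])
prototype_labels = np.array([0, 0, 1, 1])
target_labels = np.array([0, 1])
# With a very large gamma the NG weights are uniform and the loss reduces to RSLVQ.
print(np.isclose(ng_rslvq_loss_sketch(distances, target_labels, prototype_labels, 0.7, 1e12),
                 rslvq_loss_sketch(distances, target_labels, prototype_labels, 0.7)))  # True
```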
- prosemble.core.losses.cross_entropy_lvq_loss(distances, target_labels, prototype_labels, n_classes)
Cross-entropy LVQ loss (for CELVQ), computed in three steps:
1. Take the minimum distance per class via masking
2. Negate to get logits (closer = higher)
3. Apply cross-entropy against the true labels
- Parameters:
distances (array of shape (n, p))
target_labels (array of shape (n,))
prototype_labels (array of shape (p,))
n_classes (int)
- Return type:
scalar
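The three steps can be sketched in NumPy as follows. The helper name is hypothetical, and the sketch assumes integer labels 0..n_classes−1 with at least one prototype per class (otherwise a class's minimum would be +∞).

```python
import numpy as np

def cross_entropy_lvq_loss_sketch(distances, target_labels, prototype_labels, n_classes):
    class_mask = prototype_labels[None, :] == np.arange(n_classes)[:, None]   # (C, p)
    # Step 1: per-class minimum distance; other classes' prototypes are masked with +inf.
    per_class = np.min(np.where(class_mask[None], distances[:, None, :], np.inf), axis=-1)  # (n, C)
    # Step 2: negate, so the closest class gets the largest logit.
    logits = -per_class
    # Step 3: cross-entropy of the softmax over classes against the true labels.
    logits = logits - np.max(logits, axis=1, keepdims=True)                 # stabilize exp
    log_probs = logits - np.log(np.sum(np.exp(logits), axis=1, keepdims=True))
    return -np.mean(log_probs[np.arange(len(target_labels)), target_labels])


distances = np.array([[0.1, 2.0, 0.5], [3.0, 0.2, 1.0]])
prototype_labels = np.array([0, 1, 0])
target_labels = np.array([0, 1])
print(cross_entropy_lvq_loss_sketch(distances, target_labels, prototype_labels, n_classes=2))
```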
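- prosemble.core.losses.margin_loss(y_pred, y_true_one_hot, margin=0.3)
Margin loss for CBC.
\(\text{loss} = \text{ReLU}(\max(\text{wrong}) - \text{correct} + \text{margin})\), where correct is the predicted probability of the true class and max(wrong) is the largest predicted probability among the other classes.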
- Parameters:
y_pred (array of shape (n, n_classes)) – Predicted class probabilities.
y_true_one_hot (array of shape (n, n_classes)) – One-hot encoded true labels.
margin (float)
- Return type:
scalar
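A NumPy sketch of the margin loss, reading correct and max(wrong) as described above. The mean reduction over samples and the helper name are assumptions.

```python
import numpy as np

def margin_loss_sketch(y_pred, y_true_one_hot, margin=0.3):
    correct = np.sum(y_pred * y_true_one_hot, axis=1)                        # p(true class)
    wrong = np.max(np.where(y_true_one_hot == 1, -np.inf, y_pred), axis=1)   # best competing class
    return np.mean(np.maximum(wrong - correct + margin, 0.0))                # ReLU, then mean


y_pred = np.array([[0.7, 0.2, 0.1], [0.4, 0.5, 0.1]])
y_true = np.array([[1, 0, 0], [1, 0, 0]])
print(margin_loss_sketch(y_pred, y_true))  # about 0.2: only the second sample violates the margin
```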
- prosemble.core.losses.neural_gas_energy(distances, lam)
Neural Gas energy function.
\[E = \sum_k h(\text{rank}_k, \lambda) \cdot d(x, w_k), \quad h(\text{rank}, \lambda) = \exp(-\text{rank} / \lambda)\]
- Parameters:
distances (array of shape (n, p))
lam (float) – Neighborhood range parameter.
- Return type:
scalar
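A NumPy sketch of the energy, with ranks from a double argsort. The reduction over samples (a mean here) is an assumption, since the entry only states that a scalar is returned. As λ shrinks, only the rank-0 term survives and the energy approaches the mean distance to the winning prototype.

```python
import numpy as np

def neural_gas_energy_sketch(distances, lam):
    ranks = np.argsort(np.argsort(distances, axis=1), axis=1)   # 0 = closest prototype
    h = np.exp(-ranks / lam)                                     # neighborhood weights h(rank, lambda)
    return np.mean(np.sum(h * distances, axis=1))                # sum over prototypes; mean over samples (assumed)


D = np.array([[1.0, 3.0, 2.0]])
print(neural_gas_energy_sketch(D, lam=1.0))    # 1 + 2*exp(-1) + 3*exp(-2)
print(neural_gas_energy_sketch(D, lam=1e-3))   # 1.0, the winner's distance
```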
Activations
Transfer/activation functions for prototype-based learning.
These functions are used to shape the GLVQ loss (mu values) before summation, controlling the optimization landscape.
- prosemble.core.activations.identity(x, beta=0.0)
Identity activation (passthrough).
- Parameters:
x (array) – Input values.
beta (float) – Ignored. Present for API consistency.
- Returns:
Same as input.
- Return type:
array
- prosemble.core.activations.sigmoid_beta(x, beta=10.0)
Sigmoid activation with steepness parameter.
f(x) = 1 / (1 + exp(-beta * x))
- Parameters:
x (array) – Input values.
beta (float) – Steepness parameter. Higher values give sharper transition.
- Returns:
Sigmoid-transformed values in (0, 1).
- Return type:
array
Competitions
Competition mechanisms for prototype-based classification.
These functions determine class predictions from distance matrices and prototype labels using different strategies.
- prosemble.core.competitions.wtac(distances, prototype_labels)
Winner-Takes-All Competition.
Assigns each sample the label of the closest prototype.
- Parameters:
distances (array of shape (n_samples, n_prototypes)) – Distance matrix.
prototype_labels (array of shape (n_prototypes,)) – Class label for each prototype.
- Returns:
Predicted labels.
- Return type:
array of shape (n_samples,)
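- prosemble.core.competitions.knnc(distances, prototype_labels, k=1, n_classes=None)
K-Nearest Neighbors Competition.
Assigns each sample the majority label among k closest prototypes.
- Parameters:
distances (array of shape (n_samples, n_prototypes)) – Distance matrix.
prototype_labels (array of shape (n_prototypes,)) – Class label for each prototype.
k (int) – Number of closest prototypes that vote.
n_classes (int, optional) – Number of classes.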
- Returns:
Predicted labels.
- Return type:
array of shape (n_samples,)
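Both competitions reduce to a few array operations. In this NumPy sketch (hypothetical helper names), knnc_sketch assumes labels are integers 0..n_classes−1 and resolves vote ties toward the smaller label via argmax; Prosemble's tie-breaking and its handling of n_classes=None are not documented here.

```python
import numpy as np

def wtac_sketch(distances, prototype_labels):
    """Label of the closest prototype for every sample."""
    return prototype_labels[np.argmin(distances, axis=1)]


def knnc_sketch(distances, prototype_labels, k=1, n_classes=None):
    """Majority label among the k closest prototypes."""
    if n_classes is None:
        n_classes = int(prototype_labels.max()) + 1     # assumes labels 0..n_classes-1
    nearest = np.argsort(distances, axis=1)[:, :k]      # (n, k) indices of the k closest
    votes = prototype_labels[nearest]                    # (n, k) their labels
    counts = np.stack([np.sum(votes == c, axis=1) for c in range(n_classes)], axis=1)  # (n, C)
    return np.argmax(counts, axis=1)                     # ties go to the smaller label


distances = np.array([[0.1, 0.2, 0.3, 5.0]])
prototype_labels = np.array([0, 1, 1, 0])
print(wtac_sketch(distances, prototype_labels), knnc_sketch(distances, prototype_labels, k=3))  # [0] [1]
```

With k=1, knnc_sketch returns the same labels as wtac_sketch, which matches the default k=1 in the signature above.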
- prosemble.core.competitions.cbcc(detections, reasonings)
Classification-By-Components Competition.
Computes class probability distributions using component detections and reasoning matrices.
- Parameters:
detections (array of shape (n_samples, n_components)) – Similarity/detection scores for each component.
reasonings (array of shape (n_components, n_classes, 2)) – Reasoning matrices. Last dim: [positive, negative_raw].
- Returns:
Class probability distributions.
- Return type:
array of shape (n_samples, n_classes)
Initializers
Prototype and parameter initializers for prototype-based learning.
These functions generate initial prototypes, labels, and transformation matrices for supervised and unsupervised models.
- prosemble.core.initializers.stratified_selection_init(X, y, n_per_class, key)
Initialize prototypes by randomly selecting samples per class.
- Parameters:
X (array of shape (n_samples, n_features)) – Training data.
y (array of shape (n_samples,)) – Training labels.
n_per_class (int, list, or dict) – Number of prototypes per class:
  - int: same count for all classes.
  - list: index i gives the count for class i, e.g. [2, 2, 1].
  - dict: maps class label to count, e.g. {0: 2, 1: 2, 2: 1}.
key (jax.random.PRNGKey) – Random key.
- Returns:
prototypes (array of shape (n_prototypes, n_features))
prototype_labels (array of shape (n_prototypes,))
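A NumPy sketch of the per-class selection, with a numpy.random.Generator standing in for the JAX PRNG key. stratified_selection_sketch is hypothetical; it samples without replacement within each class, which is an assumption, and handles the three accepted forms of n_per_class.

```python
import numpy as np

def stratified_selection_sketch(X, y, n_per_class, rng):
    """Pick counts[c] random samples of each class c, without replacement."""
    classes = np.unique(y)
    if isinstance(n_per_class, int):
        counts = {c: n_per_class for c in classes}
    elif isinstance(n_per_class, dict):
        counts = n_per_class
    else:                                              # list: index i gives the count for class i
        counts = {i: n for i, n in enumerate(n_per_class)}
    protos, labels = [], []
    for c in classes:
        idx = np.flatnonzero(y == c)                   # rows belonging to class c
        chosen = rng.choice(idx, size=counts[c], replace=False)
        protos.append(X[chosen])
        labels.append(np.full(counts[c], c))
    return np.concatenate(protos), np.concatenate(labels)


X = np.arange(20).reshape(10, 2)
y = np.array([0, 0, 0, 0, 1, 1, 1, 2, 2, 2])
protos, labels = stratified_selection_sketch(X, y, [2, 2, 1], np.random.default_rng(0))
print(labels)  # [0 0 1 1 2]
```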
- prosemble.core.initializers.stratified_mean_init(X, y)
Initialize prototypes at the mean of each class.
- Parameters:
X (array of shape (n_samples, n_features)) – Training data.
y (array of shape (n_samples,)) – Training labels.
- Returns:
prototypes (array of shape (n_classes, n_features))
prototype_labels (array of shape (n_classes,))
- prosemble.core.initializers.random_normal_init(n_prototypes, n_features, key, mean=0.0, std=1.0)
Initialize prototypes from a normal distribution.
- prosemble.core.initializers.identity_omega_init(n_features, n_dims=None)
Initialize omega as an identity (or truncated identity) matrix.
- prosemble.core.initializers.random_omega_init(n_features, n_dims, key)
Initialize omega as a random orthogonal matrix via QR decomposition.
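Sketches of the two omega initializers in NumPy, with hypothetical helper names. The random version multiplies each column of Q by the sign of the matching diagonal entry of R, a common refinement that makes the result uniformly distributed over matrices with orthonormal columns; whether Prosemble applies that correction is not documented here. Both sketches assume n_dims ≤ n_features.

```python
import numpy as np

def identity_omega_sketch(n_features, n_dims=None):
    """Identity matrix, truncated to its first n_dims columns when n_dims is given."""
    return np.eye(n_features, n_features if n_dims is None else n_dims)


def random_orthogonal_omega_sketch(n_features, n_dims, rng):
    """(n_features, n_dims) matrix with orthonormal columns from a QR decomposition."""
    A = rng.normal(size=(n_features, n_dims))
    Q, R = np.linalg.qr(A)                  # reduced QR: Q is (n_features, n_dims)
    return Q * np.sign(np.diag(R))          # flip column signs so diag(R) would be positive


omega = random_orthogonal_omega_sketch(5, 3, np.random.default_rng(0))
print(np.allclose(omega.T @ omega, np.eye(3)))  # True: columns are orthonormal
```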
- prosemble.core.initializers.uniform_init(n_prototypes, n_features, key, low=0.0, high=1.0)
Initialize prototypes from a uniform distribution.
- prosemble.core.initializers.zeros_init(n_prototypes, n_features)
Initialize prototypes as zeros.
Useful for checkpoint loading where shapes must be pre-allocated before restoring saved values.
- prosemble.core.initializers.ones_init(n_prototypes, n_features)
Initialize prototypes as ones.
- prosemble.core.initializers.fill_value_init(n_prototypes, n_features, value=0.0)
Initialize prototypes filled with a constant value.
- prosemble.core.initializers.selection_init(X, n_prototypes, key)
Initialize prototypes by uniformly sampling from data (classless).
Suitable for unsupervised models like Neural Gas, SOM, and fuzzy clustering where no labels are available.
- Parameters:
X (array of shape (n_samples, n_features)) – Training data.
n_prototypes (int) – Number of prototypes to select.
key (jax.random.PRNGKey)
- Return type:
array of shape (n_prototypes, n_features)
- prosemble.core.initializers.mean_init(X, n_prototypes)
Initialize all prototypes at the data mean (classless).
Suitable for unsupervised models. All prototypes start at the global mean and diverge during training.
- Parameters:
X (array of shape (n_samples, n_features)) – Training data.
n_prototypes (int) – Number of prototypes.
- Return type:
array of shape (n_prototypes, n_features)
- prosemble.core.initializers.literal_init(values)
Initialize prototypes from literal values.
Used for warm-starting from another model’s prototypes or from user-provided values.
- Parameters:
values (array-like of shape (n_prototypes, n_features)) – Literal prototype values.
- Return type:
array of shape (n_prototypes, n_features)
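- prosemble.core.initializers.stratified_noise_init(X, y, n_per_class, key, noise_std=0.1)
Initialize prototypes by selecting samples per class and adding noise.
Combines stratified selection with Gaussian noise injection for diverse initial prototypes.
- Parameters:
X (array of shape (n_samples, n_features)) – Training data.
y (array of shape (n_samples,)) – Training labels.
n_per_class (int, list, or dict) – Number of prototypes per class, as in stratified_selection_init().
key (jax.random.PRNGKey) – Random key.
noise_std (float) – Standard deviation of the added Gaussian noise.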
- Returns:
prototypes (array of shape (n_prototypes, n_features))
prototype_labels (array of shape (n_prototypes,))
- prosemble.core.initializers.pca_omega_init(X, n_dims)
Initialize omega using PCA directions from training data.
The top-n_dims principal components become the columns of omega, providing a data-driven initialization for metric learning models.
- Parameters:
X (array of shape (n_samples, n_features)) – Training data.
n_dims (int) – Number of principal components (projection dimensionality).
- Returns:
omega
- Return type:
array of shape (n_features, n_dims)
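A NumPy sketch of PCA-based initialization via the SVD of the centered data. The helper name is hypothetical, and centering before the SVD is an assumption about how Prosemble computes the principal components.

```python
import numpy as np

def pca_omega_sketch(X, n_dims):
    """Top-n_dims principal directions of X as the columns of an (n_features, n_dims) matrix."""
    Xc = X - X.mean(axis=0)                                   # center each feature
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)         # rows of Vt sorted by singular value
    return Vt[:n_dims].T


rng = np.random.default_rng(0)
X = rng.normal(size=(500, 3)) * np.array([5.0, 1.0, 0.1])    # most variance along feature 0
omega = pca_omega_sketch(X, 2)
print(omega.shape)  # (3, 2)
```

Because the columns come from an SVD they are orthonormal, so the initial projection preserves distances within the retained subspace.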
- prosemble.core.initializers.class_conditional_mean_init(X, y, n_per_class)
Initialize prototypes at class means, with n_per_class prototypes per class.
When n_per_class > 1, all prototypes of a class start at the same class mean and diverge during training.
- prosemble.core.initializers.random_reasonings_init(n_components, n_classes, key)[source]¶
Initialize CBC reasoning matrices randomly.
Pooling¶
Stratified pooling operations for prototype-based learning.
These functions aggregate per-prototype distances into per-class distances, grouping by prototype label. Essential for GLVQ and CELVQ.
- prosemble.core.pooling.stratified_min_pooling(distances, prototype_labels, n_classes)[source]¶
Per-class minimum distance pooling.
For each sample and each class, returns the minimum distance to any prototype of that class.
- Parameters:
distances (array of shape (n_samples, n_prototypes)) – Distance matrix.
prototype_labels (array of shape (n_prototypes,)) – Class label for each prototype.
n_classes (int) – Number of classes.
- Returns:
Minimum distance to each class for each sample.
- Return type:
array of shape (n_samples, n_classes)
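A sketch of per-class min pooling that uses a boolean class mask and broadcasting instead of a Python loop over classes. The helper name is hypothetical; in this version a class with no prototypes gets +inf.

```python
import jax.numpy as jnp

def stratified_min_pooling_sketch(distances, prototype_labels, n_classes):
    """Hypothetical sketch: min distance from each sample to the prototypes of each class."""
    # mask[k, j] is True when prototype j belongs to class k
    mask = prototype_labels[None, :] == jnp.arange(n_classes)[:, None]
    # Broadcast to (n_samples, n_classes, n_prototypes); hide other classes with +inf
    masked = jnp.where(mask[None, :, :], distances[:, None, :], jnp.inf)
    return masked.min(axis=-1)  # (n_samples, n_classes)

distances = jnp.array([[1.0, 4.0, 2.0], [3.0, 0.5, 5.0]])
prototype_labels = jnp.array([0, 1, 0])
pooled = stratified_min_pooling_sketch(distances, prototype_labels, 2)
```

Here pooled is [[1.0, 4.0], [3.0, 0.5]]: prototypes 0 and 2 form class 0, prototype 1 forms class 1.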
- prosemble.core.pooling.stratified_sum_pooling(distances, prototype_labels, n_classes)[source]¶
Per-class sum distance pooling.
- Parameters:
distances (array of shape (n_samples, n_prototypes))
prototype_labels (array of shape (n_prototypes,))
n_classes (int)
- Returns:
Sum of distances to each class for each sample.
- Return type:
array of shape (n_samples, n_classes)
- prosemble.core.pooling.stratified_max_pooling(distances, prototype_labels, n_classes)[source]¶
Per-class maximum distance pooling.
- Parameters:
distances (array of shape (n_samples, n_prototypes))
prototype_labels (array of shape (n_prototypes,))
n_classes (int)
- Returns:
Maximum distance to each class for each sample.
- Return type:
array of shape (n_samples, n_classes)
- prosemble.core.pooling.stratified_prod_pooling(distances, prototype_labels, n_classes)[source]¶
Per-class product distance pooling.
Uses log-sum-exp for numerical stability.
- Parameters:
distances (array of shape (n_samples, n_prototypes))
prototype_labels (array of shape (n_prototypes,))
n_classes (int)
- Returns:
Product of distances to each class for each sample.
- Return type:
array of shape (n_samples, n_classes)
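A per-class product can be formed in log space as \(\prod_j d_j = \exp(\sum_j \log d_j)\), which avoids overflow and underflow in the intermediate product. The sketch below takes that route, assuming strictly positive distances; how prosemble's log-sum-exp formulation handles zero distances is not specified here, and the function name is hypothetical.

```python
import jax.numpy as jnp

def stratified_prod_pooling_sketch(distances, prototype_labels, n_classes):
    """Hypothetical sketch: per-class product of distances, computed as exp(sum(log d))."""
    mask = prototype_labels[None, :] == jnp.arange(n_classes)[:, None]  # (K, P)
    log_d = jnp.log(distances)[:, None, :]  # (N, 1, P); assumes distances > 0
    # Prototypes of other classes contribute log(1) = 0 to the sum
    summed = jnp.where(mask[None], log_d, 0.0).sum(axis=-1)  # (N, K)
    return jnp.exp(summed)

distances = jnp.array([[1.0, 4.0, 2.0], [3.0, 0.5, 5.0]])
prototype_labels = jnp.array([0, 1, 0])
prod = stratified_prod_pooling_sketch(distances, prototype_labels, 2)
```

For this input prod is [[2.0, 4.0], [15.0, 0.5]].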
Quantization¶
Shared mixins for model base classes.
- class prosemble.core.quantization.MetadataCollectorMixin[source]¶
Mixin that auto-collects _hyperparams and _fitted_array_names from MRO.
Any base class that declares per-class _hyperparams and _fitted_array_names tuples can inherit this mixin to get _all_hyperparams and _all_fitted_array_names aggregated automatically across the entire class hierarchy.
- class prosemble.core.quantization.QuantizationMixin[source]¶
Mixin for quantizing/dequantizing fitted model parameters.
Supports float16, bfloat16, and int8 (with per-tensor scale factors).
Subclasses override _get_quantizable_attrs to declare which fitted attributes are eligible for quantization.
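To illustrate what int8 quantization with a per-tensor scale factor means, here is a symmetric quantization sketch: one scale per array maps the largest magnitude to 127. The helpers quantize_int8 and dequantize_int8 are hypothetical and are not QuantizationMixin methods. float16/bfloat16 quantization, by contrast, is a plain cast such as x.astype(jnp.bfloat16).

```python
import jax.numpy as jnp

def quantize_int8(x):
    """Hypothetical helper: symmetric int8 quantization with one scale per tensor."""
    max_abs = jnp.max(jnp.abs(x))
    scale = jnp.where(max_abs > 0, max_abs / 127.0, 1.0)  # avoid dividing by zero
    q = jnp.clip(jnp.round(x / scale), -127, 127).astype(jnp.int8)
    return q, scale

def dequantize_int8(q, scale):
    return q.astype(jnp.float32) * scale

x = jnp.array([[-1.0, 0.5], [0.25, 2.0]])
q, scale = quantize_int8(x)
x_hat = dequantize_int8(q, scale)  # rounding error is at most scale / 2 per entry
```

A single scale per tensor keeps storage overhead to one float, at the cost of precision for small entries when the tensor also contains large ones.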
Utilities¶
JAX utility functions for data preprocessing and manipulation.
Replaces common sklearn/numpy operations with JAX-native implementations.
- prosemble.core.utils.train_test_split_jax(X, y=None, test_size=0.2, random_seed=42)[source]¶
JAX-native train/test split.
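- Parameters:
X (array-like of shape (n_samples, n_features)) – Feature matrix
y (array-like of shape (n_samples,), optional) – Labels
test_size (float, default=0.2) – Fraction of samples assigned to the test split
random_seed (int, default=42) – Random seed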
- Returns:
Split arrays. If y is provided, returns 4 arrays, else 2.
- Return type:
X_train, X_test[, y_train, y_test]
Examples
>>> X = jnp.array([[1, 2], [3, 4], [5, 6], [7, 8]])
>>> y = jnp.array([0, 1, 0, 1])
>>> X_train, X_test, y_train, y_test = train_test_split_jax(X, y, test_size=0.5)
- prosemble.core.utils.standardize(X, mean=None, std=None)[source]¶
Standardize features (zero mean, unit variance).
- Parameters:
X (array-like of shape (n_samples, n_features)) – Data to standardize
mean (array-like of shape (n_features,), optional) – Pre-computed mean (for test data)
std (array-like of shape (n_features,), optional) – Pre-computed std (for test data)
- Returns:
X_scaled (array-like) – Standardized data
mean (array-like) – Mean used for scaling
std (array-like) – Std used for scaling
- Return type:
tuple[Array | ndarray | bool | number, Array | ndarray | bool | number, Array | ndarray | bool | number]
Examples
>>> X_train = jnp.array([[1, 2], [3, 4], [5, 6]])
>>> X_scaled, mean, std = standardize(X_train)
>>> # For test data, reuse the training statistics
>>> X_test = jnp.array([[2, 3]])
>>> X_test_scaled, _, _ = standardize(X_test, mean=mean, std=std)
- prosemble.core.utils.min_max_scale(X, min_val=None, max_val=None)[source]¶
Scale features to [0, 1] range.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Data to scale
min_val (array-like of shape (n_features,), optional) – Pre-computed min (for test data)
max_val (array-like of shape (n_features,), optional) – Pre-computed max (for test data)
- Returns:
X_scaled (array-like) – Scaled data
min_val (array-like) – Min values used
max_val (array-like) – Max values used
- Return type:
tuple[Array | ndarray | bool | number, Array | ndarray | bool | number, Array | ndarray | bool | number]
- prosemble.core.utils.pca_jax(X, n_components=2)[source]¶
Principal Component Analysis using JAX.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Data matrix
n_components (int, default=2) – Number of principal components
- Returns:
X_transformed (array-like of shape (n_samples, n_components)) – Transformed data
components (array-like of shape (n_components, n_features)) – Principal components
- Return type:
tuple[Array | ndarray | bool | number, Array | ndarray | bool | number]
Examples
>>> X = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> X_pca, components = pca_jax(X, n_components=2)
- prosemble.core.utils.accuracy_score_jax(y_true, y_pred)[source]¶
Compute classification accuracy.
- Parameters:
y_true (array-like of shape (n_samples,)) – True labels
y_pred (array-like of shape (n_samples,)) – Predicted labels
- Returns:
accuracy – Accuracy score
- Return type:
Examples
>>> y_true = jnp.array([0, 1, 1, 0])
>>> y_pred = jnp.array([0, 1, 0, 0])
>>> accuracy_score_jax(y_true, y_pred)
0.75
- prosemble.core.utils.confusion_matrix_jax(y_true, y_pred, n_classes)[source]¶
Compute confusion matrix.
- Parameters:
y_true (array-like of shape (n_samples,)) – True labels
y_pred (array-like of shape (n_samples,)) – Predicted labels
n_classes (int) – Number of classes
- Returns:
conf_matrix – Confusion matrix
- Return type:
array-like of shape (n_classes, n_classes)
Examples
>>> y_true = jnp.array([0, 1, 2, 0, 1, 2])
>>> y_pred = jnp.array([0, 2, 2, 0, 0, 1])
>>> confusion_matrix_jax(y_true, y_pred, n_classes=3)
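For reference, a confusion matrix can be built with a single scatter-add. This sketch assumes the scikit-learn orientation (rows are true labels, columns are predictions); the orientation used by confusion_matrix_jax is not stated above, and confusion_matrix_sketch is a hypothetical name.

```python
import jax.numpy as jnp

def confusion_matrix_sketch(y_true, y_pred, n_classes):
    # Assumed orientation: cm[i, j] counts samples with true label i predicted as j
    cm = jnp.zeros((n_classes, n_classes), dtype=jnp.int32)
    # Scatter-add accumulates repeated (i, j) pairs instead of overwriting them
    return cm.at[y_true, y_pred].add(1)

y_true = jnp.array([0, 1, 2, 0, 1, 2])
y_pred = jnp.array([0, 2, 2, 0, 0, 1])
cm = confusion_matrix_sketch(y_true, y_pred, n_classes=3)
accuracy = jnp.trace(cm) / cm.sum()  # correct predictions lie on the diagonal
```

For the example inputs, this sketch yields [[2, 0, 0], [1, 0, 1], [0, 1, 1]] and an accuracy of 0.5.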
- prosemble.core.utils.shuffle_jax(X, y=None, random_seed=42)[source]¶
Shuffle arrays in unison.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Feature matrix
y (array-like of shape (n_samples,), optional) – Labels
random_seed (int, default=42) – Random seed
- Returns:
Shuffled arrays
- Return type:
X_shuffled[, y_shuffled]
Examples
>>> X = jnp.array([[1, 2], [3, 4], [5, 6]])
>>> y = jnp.array([0, 1, 0])
>>> X_shuf, y_shuf = shuffle_jax(X, y, random_seed=42)
- prosemble.core.utils.k_fold_split_jax(n_samples, n_folds=5, random_seed=42)[source]¶
Generate K-fold cross-validation indices.
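- Parameters:
n_samples (int) – Number of samples to split
n_folds (int, default=5) – Number of folds
random_seed (int, default=42) – Random seed
- Yields:
train_indices, test_indices – Indices for each fold
Examples
>>> X = jnp.ones((100, 3))
>>> for train_idx, test_idx in k_fold_split_jax(100, n_folds=5):
...     X_train, X_test = X[train_idx], X[test_idx]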
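A minimal generator with the same yield structure, assuming the folds come from a single random permutation split as evenly as possible with jnp.array_split. This is an illustrative sketch named k_fold_indices_sketch, not prosemble's implementation.

```python
import jax
import jax.numpy as jnp

def k_fold_indices_sketch(n_samples, n_folds=5, random_seed=42):
    """Hypothetical sketch: shuffle once, then hold out one chunk per fold."""
    perm = jax.random.permutation(jax.random.PRNGKey(random_seed), n_samples)
    chunks = jnp.array_split(perm, n_folds)  # chunk sizes differ by at most 1
    for i in range(n_folds):
        test_idx = chunks[i]
        train_idx = jnp.concatenate([c for j, c in enumerate(chunks) if j != i])
        yield train_idx, test_idx

splits = list(k_fold_indices_sketch(10, n_folds=3))
```

Each sample lands in exactly one test fold, and within a fold the train and test indices are disjoint.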
- prosemble.core.utils.orthogonalize(matrix)[source]¶
Orthogonalize a matrix via polar decomposition (SVD).
Given a matrix A of shape (d, s), computes the closest matrix Q with orthonormal columns (Q^T Q = I), using the polar factor of the SVD:
\(A = U S V^T\), \(Q = U V^T\)
Supports batched input via jax.vmap.
- Parameters:
matrix (array of shape (d, s) or (n, d, s)) – Matrix or batch of matrices to orthogonalize. For batched input, use jax.vmap(orthogonalize).
- Returns:
Q – Orthogonal matrix (columns are orthonormal).
- Return type:
array of same shape as input
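The polar factor takes only a few lines in JAX. This sketch (polar_orthogonalize is a hypothetical name) assumes d >= s and shows both a single-matrix call and a vmap-batched call.

```python
import jax
import jax.numpy as jnp

def polar_orthogonalize(A):
    """Sketch: polar factor of A (shape (d, s), d >= s) via the thin SVD."""
    U, _, Vt = jnp.linalg.svd(A, full_matrices=False)  # A = U diag(S) Vt
    return U @ Vt  # nearest matrix with orthonormal columns (Frobenius norm)

A = jax.random.normal(jax.random.PRNGKey(0), (5, 3))
Q = polar_orthogonalize(A)
# Batched input: map over the leading axis
Qs = jax.vmap(polar_orthogonalize)(jax.random.normal(jax.random.PRNGKey(1), (4, 5, 3)))
```

Dropping the singular values is what makes the result orthonormal; a matrix that already has orthonormal columns is returned unchanged.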
- prosemble.core.utils.class_priors(labels, n_classes=None)[source]¶
Compute class prior probabilities from labels.
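\(P(\text{class}=k) = n_k / n\), where \(n_k\) is the number of samples with label k and n is the total number of samples.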
- Parameters:
labels (array of shape (n_samples,)) – Integer class labels.
n_classes (int, optional) – Number of classes. Inferred from labels if not provided.
- Returns:
priors – Prior probability for each class, sums to 1.
- Return type:
array of shape (n_classes,)
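Both class_priors and prototype_priors reduce to normalized label counts. A sketch assuming integer labels in 0..K-1 (class_priors_sketch is a hypothetical name); passing prototype labels instead of sample labels gives the prototype-based class prior.

```python
import jax.numpy as jnp

def class_priors_sketch(labels, n_classes=None):
    """Hypothetical sketch: P(class = k) = n_k / n from integer labels 0..K-1."""
    if n_classes is None:
        n_classes = int(labels.max()) + 1  # infer K from the largest label
    counts = jnp.bincount(labels, length=n_classes)
    return counts / counts.sum()

sample_priors = class_priors_sketch(jnp.array([0, 0, 1, 2]))  # [0.5, 0.25, 0.25]
proto_priors = class_priors_sketch(jnp.array([0, 0, 1]))      # [2/3, 1/3]
```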
- prosemble.core.utils.prototype_priors(prototype_labels, n_classes=None)[source]¶
Compute class priors from prototype label distribution.
Used in probabilistic LVQ models where the prior over prototypes is uniform (1/n_prototypes) and the class prior is the fraction of prototypes assigned to each class.
\(P(\text{class}=k) = |\{j : \text{label}_j = k\}| / n_{\text{prototypes}}\)
- Parameters:
prototype_labels (array of shape (n_prototypes,)) – Class label for each prototype.
n_classes (int, optional) – Number of classes. Inferred if not provided.
- Returns:
priors – Class prior probabilities, sums to 1.
- Return type:
array of shape (n_classes,)
- prosemble.core.utils.uniform_prototype_prior(n_prototypes)[source]¶
Uniform prior over prototypes: P(prototype_j) = 1/n.
This is the standard prior used in probabilistic LVQ (SLVQ/RSLVQ).
- Parameters:
n_prototypes (int) – Number of prototypes.
- Returns:
prior – Uniform probability vector, sums to 1.
- Return type:
array of shape (n_prototypes,)
Pipeline¶
- class prosemble.core.pipeline.NotFittedError[source]¶
Raised when calling transform/predict on an unfitted estimator.
- class prosemble.core.pipeline.StandardScaler[source]¶
Standardize features to zero mean and unit variance.
Wraps prosemble.core.utils.standardize().
Examples
>>> scaler = StandardScaler()
>>> X_scaled = scaler.fit_transform(X_train)
>>> X_test_scaled = scaler.transform(X_test)
- class prosemble.core.pipeline.MinMaxScaler[source]¶
Scale features to [0, 1] range.
Wraps prosemble.core.utils.min_max_scale().
Examples
>>> scaler = MinMaxScaler()
>>> X_scaled = scaler.fit_transform(X_train)
- class prosemble.core.pipeline.PCA(n_components=2)[source]¶
Principal Component Analysis.
Wraps prosemble.core.utils.pca_jax().
- Parameters:
n_components (int, default=2) – Number of principal components to keep.
Examples
>>> pca = PCA(n_components=2)
>>> X_reduced = pca.fit_transform(X)
- class prosemble.core.pipeline.Pipeline(steps)[source]¶
Chain transformers with a final estimator.
All operations use pure JAX arrays (no NumPy round trips).
- Parameters:
steps (list of (name, estimator) tuples) – Sequence of transforms with a final estimator. All but the last must implement fit() and transform(). The last step can be any estimator or transformer.
Examples
>>> pipe = Pipeline([
...     ('scaler', StandardScaler()),
...     ('pca', PCA(n_components=2)),
...     ('model', GLVQ(n_prototypes_per_class=1, max_iter=50)),
... ])
>>> pipe.fit(X_train, y_train)
>>> preds = pipe.predict(X_test)
- fit(X, y=None, **fit_params)[source]¶
Fit all steps.
Calls fit_transform on intermediate steps, fit on the final step.
- Parameters:
X (array of shape (n_samples, n_features))
y (array of shape (n_samples,), optional)
**fit_params – Forwarded to the final estimator's fit().
- Return type:
Model Selection¶
- prosemble.core.model_selection.clone(estimator, **override_params)[source]¶
Create a fresh (unfitted) copy of an estimator.
- Parameters:
estimator (object) – Estimator with the get_params() protocol.
**override_params – Parameters to override in the clone.
- Returns:
new_estimator – Fresh, unfitted instance.
- Return type:
same type as estimator
Examples
>>> model = GLVQ(n_prototypes_per_class=2, lr=0.01)
>>> model2 = clone(model, lr=0.05)
- prosemble.core.model_selection.cross_val_score(estimator, X, y=None, cv=5, scoring='accuracy', random_seed=42)[source]¶
Evaluate estimator with cross-validation.
- Parameters:
estimator (object) – Estimator with fit/predict and get_params.
X (array of shape (n_samples, n_features))
y (array of shape (n_samples,), optional) – Required for supervised estimators.
cv (int, default=5) – Number of cross-validation folds.
scoring (str or callable, default='accuracy') – ‘accuracy’ or a callable scorer(y_true, y_pred) -> float. For unsupervised estimators without y: scorer(estimator, X_test) -> float.
random_seed (int, default=42)
- Returns:
scores – Score for each fold.
- Return type:
jnp.ndarray of shape (cv,)
Examples
>>> scores = cross_val_score(GLVQ(max_iter=30), X, y, cv=5)
>>> print(f"Mean: {scores.mean():.3f} +/- {scores.std():.3f}")
- class prosemble.core.model_selection.GridSearchCV(estimator, param_grid, cv=5, scoring='accuracy', random_seed=42, refit=True, verbose=0)[source]¶
Exhaustive search over a parameter grid with cross-validation.
- Parameters:
estimator (object) – Base estimator with fit/predict and get_params/set_params. Can be a Pipeline or any prosemble model.
param_grid (dict) – Maps parameter names to lists of values to try. For Pipeline steps, use step_name__param notation.
cv (int, default=5) – Number of cross-validation folds.
scoring (str or callable, default='accuracy') – ‘accuracy’ or a callable scorer(y_true, y_pred) -> float.
random_seed (int, default=42)
refit (bool, default=True) – If True, refit the best model on the full dataset after search.
verbose (int, default=0) – 0=silent, 1=per-combo summary, 2=per-fold detail.
Examples
>>> gs = GridSearchCV(
...     GLVQ(max_iter=30),
...     {'n_prototypes_per_class': [1, 2], 'lr': [0.01, 0.05]},
...     cv=3,
... )
>>> gs.fit(X, y)
>>> print(gs.best_params_, gs.best_score_)
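GridSearchCV evaluates every combination in param_grid. The sketch below shows one way to expand a grid into candidate settings and to split step_name__param keys; both helpers are hypothetical, and the iteration order (sorted keys) is an assumption rather than prosemble's documented behavior.

```python
import itertools

def expand_param_grid(param_grid):
    """Hypothetical helper: yield one dict per combination of grid values."""
    keys = sorted(param_grid)  # fixed order so results are reproducible
    for values in itertools.product(*(param_grid[k] for k in keys)):
        yield dict(zip(keys, values))

def split_step_param(name):
    """Hypothetical helper: 'model__lr' -> ('model', 'lr'); 'lr' -> (None, 'lr')."""
    step, sep, param = name.partition("__")
    return (step, param) if sep else (None, name)

grid = {"n_prototypes_per_class": [1, 2], "lr": [0.01, 0.05]}
combos = list(expand_param_grid(grid))  # 2 x 2 = 4 candidate settings
```

With cv=3, this grid implies 4 x 3 = 12 fits before the optional refit.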
Datasets¶
JAX-compatible dataset module.
Provides dataset loaders that return JAX arrays instead of NumPy arrays, optimized for use with JAX-based clustering models.
- class prosemble.datasets.dataset.DATASET_JAX(input_data, labels)[source]
JAX-compatible dataset container.
- input_data
Feature matrix as JAX array
- Type:
jax.numpy.ndarray
- labels
Labels as JAX array
- Type:
jax.numpy.ndarray
- input_data: Array
- labels: Array
- to_numpy()[source]
Convert JAX arrays back to NumPy arrays.
- prosemble.datasets.dataset.iris_dataset_jax(dtype=None)[source]
Load Iris dataset as JAX arrays.
- Parameters:
dtype (jax.numpy.dtype, default=jnp.float32) – Data type for features
- Returns:
Dataset with JAX arrays (150 samples, 4 features, 3 classes)
- Return type:
DATASET_JAX
Examples
>>> from prosemble.datasets import iris_dataset_jax
>>> dataset = iris_dataset_jax()
>>> dataset.input_data.shape
(150, 4)
- prosemble.datasets.dataset.breast_cancer_dataset_jax(dtype=None)[source]
Load Wisconsin Breast Cancer dataset as JAX arrays.
- Parameters:
dtype (jax.numpy.dtype, default=jnp.float32) – Data type for features
- Returns:
Dataset with JAX arrays (569 samples, 30 features, 2 classes)
- Return type:
DATASET_JAX
Examples
>>> from prosemble.datasets import breast_cancer_dataset_jax
>>> dataset = breast_cancer_dataset_jax()
>>> dataset.input_data.shape
(569, 30)
- prosemble.datasets.dataset.moons_dataset_jax(n_samples=150, noise=None, random_state=None, dtype=None)[source]
Generate two interleaving half circles (moons) as JAX arrays.
- Parameters:
n_samples (int, default=150) – Total number of points generated
noise (float, optional) – Standard deviation of Gaussian noise added to data
random_state (int, optional) – Random seed for reproducibility
dtype (jax.numpy.dtype, default=jnp.float32) – Data type for features
- Returns:
Dataset with JAX arrays (n_samples, 2 features, 2 classes)
- Return type:
DATASET_JAX
Examples
>>> from prosemble.datasets import moons_dataset_jax
>>> dataset = moons_dataset_jax(n_samples=200, noise=0.1, random_state=42)
>>> dataset.input_data.shape
(200, 2)
- prosemble.datasets.dataset.blobs_dataset_jax(n_samples=None, centers=None, cluster_std=None, random_state=None, dtype=None)[source]
Generate isotropic Gaussian blobs as JAX arrays.
- Parameters:
n_samples (list, default=[120, 80]) – Number of samples per cluster
centers (list, default=[[0.0, 0.0], [2.0, 2.0]]) – Centers of the clusters
cluster_std (list, default=[1.2, 0.5]) – Standard deviation of clusters
random_state (int, optional) – Random seed for reproducibility
dtype (jax.numpy.dtype, default=jnp.float32) – Data type for features
- Returns:
Dataset with JAX arrays (sum(n_samples), 2 features)
- Return type:
DATASET_JAX
Examples
>>> from prosemble.datasets import blobs_dataset_jax
>>> dataset = blobs_dataset_jax(random_state=42)
>>> dataset.input_data.shape
(200, 2)
- class prosemble.datasets.dataset.DATA_JAX(random=4, sample_size=1000)[source]
JAX dataset collection with convenient property access.
- Parameters:
random (int, default=4) – Random seed for dataset generation
sample_size (int, default=1000) – Default sample size (not currently used)
dtype (jax.numpy.dtype, default=jnp.float32) – Data type for features
Examples
>>> from prosemble.datasets import DATA_JAX
>>> data = DATA_JAX(random=42)
>>> moons = data.S_1  # Moons dataset
>>> blobs = data.S_2  # Blobs dataset
>>> cancer = data.breast_cancer  # Breast cancer dataset
- random: int = 4
- sample_size: int = 1000
- dtype
alias of float32
- property S_1: DATASET_JAX
Moons dataset.
- property S_2: DATASET_JAX
Blobs dataset.
- property iris: DATASET_JAX
Iris dataset.
- property breast_cancer: DATASET_JAX
Breast cancer dataset.
- property moons: DATASET_JAX
Alias for S_1 (moons dataset).
- property blobs: DATASET_JAX
Alias for S_2 (blobs dataset).
- prosemble.datasets.dataset.load_iris_jax(dtype=None)[source]
Quick loader for iris dataset.
- Return type:
DATASET_JAX
- prosemble.datasets.dataset.load_breast_cancer_jax(dtype=None)[source]
Quick loader for breast cancer dataset.
- Return type:
DATASET_JAX
- prosemble.datasets.dataset.load_moons_jax(n_samples=150, noise=None, random_state=None, dtype=None)[source]
Quick loader for moons dataset.
- prosemble.datasets.dataset.load_blobs_jax(random_state=None, dtype=None)[source]
Quick loader for blobs dataset.
- Parameters:
random_state (int | None)
- Return type:
DATASET_JAX
- prosemble.datasets.dataset.DATA
alias of DATA_JAX
ONNX Export¶
ONNX export for prosemble prototype-based models.
Converts a fitted model’s predict function into an ONNX graph.
Only supports models whose distance function can be expressed with
standard ONNX operators. Unsupported models raise
NotImplementedError with a clear message.
Supported distance functions:
squared_euclidean_distance_matrix
euclidean_distance_matrix
manhattan_distance_matrix
omega_distance_matrix (global projection matrix)
lomega_distance_matrix (per-prototype local matrices)
tangent_distance_matrix (per-prototype tangent subspace)
relevance_weighted (per-feature relevance weighting)
Supported decision patterns:
WTAC (supervised classification)
ArgMin (unsupervised clustering)
One-class hard nearest (OCGLVQ family)
One-class Gaussian soft (OCRSLVQ family)
One-class Gaussian+NG soft (OCRSLVQ_NG family)
SVQ-OCC response model (SVQOCC family)
CBC reasoning (CBC, ImageCBC)
PLVQ Gaussian mixture soft assignment
Supported encoder models:
MLP encoder (SiameseGLVQ, SiameseGMLVQ, SiameseGTLVQ, LVQMLN, PLVQ)
CNN encoder (ImageGLVQ, ImageGMLVQ, ImageGTLVQ, ImageCBC)
Not supported:
gaussian_kernel_matrix, polynomial_kernel_matrix
Riemannian manifold distances (logm, expm have no ONNX equivalent)
- prosemble.core.onnx_export.export_onnx(model, batch_size=1, opset_version=17, path=None, reject_threshold=None)[source]¶
Export a fitted model’s predict function to ONNX format.
Builds an ONNX graph that reproduces the model’s predict() output. Supports supervised (WTAC), unsupervised (ArgMin), one-class (threshold-based), SVQ-OCC (response model), CBC (reasoning matrices), PLVQ (Gaussian mixture), encoder models (MLP/CNN backbone), Differentiating Kernel models (Gaussian, relevance-weighted, and exponential kernels), and Riemannian models on SO(n)/Grassmannian manifolds (88 of 102 models total).
- Parameters:
model (SupervisedPrototypeModel or UnsupervisedPrototypeModel) – A fitted prosemble model.
batch_size (int) – Fixed batch dimension for the input. Use -1 for a dynamic batch size (ONNX symbolic dimension).
opset_version (int) – ONNX opset version. Default: 17.
path (str, optional) – If provided, save the ONNX model to this file path.
reject_threshold (float, optional) – If provided, enables the reject option for supervised models. Samples with confidence below this threshold are assigned prediction -1 (rejected). The exported model then produces two outputs: predictions (INT64, with -1 for rejected samples) and confidence (FLOAT, in [-1, 1]). Only supported for supervised models (WTAC decision).
- Returns:
The exported ONNX model.
- Return type:
onnx.ModelProto
- Raises:
NotImplementedError – If the model’s distance function or decision pattern is not supported.
ImportError – If the onnx package is not installed.
ValueError – If reject_threshold is used with a non-supervised model.
Manifolds¶
Riemannian manifold primitives for prototype learning.
Provides geodesic distance, exponential map, and logarithmic map for manifolds commonly arising in machine learning:
SO(n): Special orthogonal group (rotation matrices)
SPD(n): Symmetric positive definite matrices
Grassmannian Gr(n, k): k-dimensional subspaces of R^n
HyperbolicPoincare(d): Poincaré ball model of hyperbolic space
All operations are JIT-compilable and vmap-compatible.
References
Schwarz, Psenickova, Villmann, Röhrbein (2026). Topology-Preserving Prototype Learning on Riemannian Manifolds. ESANN 2026.
Ganea, O., Bécigneul, G., & Hofmann, T. (2018). Hyperbolic Neural Networks. NeurIPS 2018.
- prosemble.core.manifolds.logm_safe(A)[source]¶
Matrix logarithm via funm, with numerical safety.
Uses JAX’s funm to compute logm. For complex eigenvalues (e.g. rotation matrices), operates in the complex domain and returns the real part.
- Parameters:
A (array of shape (..., n, n))
- Returns:
logA
- Return type:
array of shape (…, n, n)
- prosemble.core.manifolds.sqrt_spd(A)[source]¶
Matrix square root for symmetric positive definite matrices.
Uses eigendecomposition: \(A^{1/2} = V \operatorname{diag}(\sqrt{\lambda}) V^T\).
- Parameters:
A (array of shape (n, n), symmetric positive definite)
- Returns:
A_sqrt
- Return type:
array of shape (n, n)
- prosemble.core.manifolds.inv_sqrt_spd(A)[source]¶
Inverse matrix square root for symmetric positive definite matrices.
Uses eigendecomposition: \(A^{-1/2} = V \operatorname{diag}(1/\sqrt{\lambda}) V^T\).
- Parameters:
A (array of shape (n, n), symmetric positive definite)
- Returns:
A_inv_sqrt
- Return type:
array of shape (n, n)
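Both square roots follow from one eigendecomposition, as in the formulas above. A sketch (function names hypothetical) that assumes A is symmetric with strictly positive eigenvalues:

```python
import jax.numpy as jnp

def sqrt_spd_sketch(A):
    lam, V = jnp.linalg.eigh(A)  # A = V diag(lam) V^T, lam > 0 assumed
    return (V * jnp.sqrt(lam)) @ V.T  # V * s scales column j of V by s[j]

def inv_sqrt_spd_sketch(A):
    lam, V = jnp.linalg.eigh(A)
    return (V / jnp.sqrt(lam)) @ V.T

A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
S = sqrt_spd_sketch(A)
S_inv = inv_sqrt_spd_sketch(A)
```

S @ S reproduces A, and S_inv @ A @ S_inv is the identity, the whitening step used by affine-invariant SPD computations.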
- class prosemble.core.manifolds.SO(n)[source]¶
Special orthogonal group SO(n): rotation matrices.
Points are \(n \times n\) orthogonal matrices with det = +1. Geodesic distance uses the bi-invariant metric.
- Parameters:
n (int) – Dimension of the rotation group.
- log_map(R, S)[source]¶
Logarithmic map: Log_R(S) = R @ logm(R^T @ S).
Maps point S on the manifold to a tangent vector at R.
- exp_map(R, V)[source]¶
Exponential map: Exp_R(V) = R @ expm(R^T @ V).
Maps tangent vector V at R back to the manifold.
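A sketch of the exponential map above using jax.scipy.linalg.expm. A tangent vector at R has the form V = R Ω with Ω skew-symmetric, so R^T V = Ω and the result stays in SO(n). In 2D, the matrix exponential of [[0, -θ], [θ, 0]] is a rotation by θ, which gives an easy check. The names so_exp_map and rotation_2d are hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm

def so_exp_map(R, V):
    """Exp_R(V) = R expm(R^T V), following the formula above."""
    return R @ expm(R.T @ V)

def rotation_2d(theta):
    c, s = jnp.cos(theta), jnp.sin(theta)
    return jnp.array([[c, -s], [s, c]])

R0 = rotation_2d(0.5)
Omega = jnp.array([[0.0, -0.3], [0.3, 0.0]])  # skew-symmetric generator
S = so_exp_map(R0, R0 @ Omega)  # tangent vector V = R0 @ Omega
```

Here S equals rotation_2d(0.8): in SO(2), moving along the geodesic simply adds the angles.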
- class prosemble.core.manifolds.SPD(n)[source]¶
Manifold of \(n \times n\) symmetric positive definite matrices.
Uses the affine-invariant Riemannian metric.
- Parameters:
n (int) – Matrix dimension.
- class prosemble.core.manifolds.Grassmannian(n, k)[source]¶
Grassmannian manifold Gr(n, k): k-dimensional subspaces of R^n.
Points are represented as orthonormal bases Q of shape (n, k) with Q^T Q = I_k.
- distance(Q1, Q2)[source]¶
Geodesic distance via principal angles.
\(d(Q_1, Q_2) = \|\theta\|_2\), where \(\theta = \arccos(\sigma(Q_1^T Q_2))\) and \(\sigma(\cdot)\) denotes the singular values.
- log_map(Q1, Q2)[source]¶
Logarithmic map on the Grassmannian.
Computes the tangent vector at Q1 pointing toward Q2 using aligned principal angle decomposition (Edelman et al. 1998).
The key insight is to align both subspaces via the SVD of Q1^T Q2, ensuring the principal angles and directions correspond.
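The principal-angle distance needs only the singular values of Q1^T Q2. This sketch (grassmann_distance is a hypothetical name) clips them to [-1, 1] before arccos to guard against round-off. Because the distance depends only on the spanned subspaces, Q and -Q are at distance 0.

```python
import jax.numpy as jnp

def grassmann_distance(Q1, Q2):
    """Sketch: d = ||theta||_2 with theta the principal angles between span(Q1), span(Q2)."""
    s = jnp.linalg.svd(Q1.T @ Q2, compute_uv=False)  # cosines of the principal angles
    theta = jnp.arccos(jnp.clip(s, -1.0, 1.0))  # clip guards against s slightly above 1
    return jnp.linalg.norm(theta)

e1 = jnp.array([[1.0], [0.0]])  # the x-axis in R^2
e2 = jnp.array([[0.0], [1.0]])  # the y-axis in R^2
d12 = grassmann_distance(e1, e2)  # orthogonal lines: pi / 2
```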
- class prosemble.core.manifolds.HyperbolicPoincare(d, eps=1e-05)[source]¶
Poincaré ball model of hyperbolic space \(\mathbb{H}^d\).
Points are vectors in \(\mathbb{B}^d = \{x \in \mathbb{R}^d : \|x\| < 1\}\). The Riemannian metric is \(g_x = \lambda_x^2 g_E\) where \(\lambda_x = 2/(1-\|x\|^2)\) is the conformal factor.
This manifold is the natural geometry for hierarchical and tree-structured data — it can embed trees with arbitrarily low distortion (Sarkar 2011).
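- Parameters:
d (int) – Dimension of the ball.
eps (float, default=1e-05) – Small constant for numerical stability near the ball boundary.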
- distance(x, y)[source]¶
Geodesic distance in the Poincaré ball.
\[d(x, y) = \operatorname{arcosh}\!\left(1 + \frac{2\|x - y\|^2} {(1 - \|x\|^2)(1 - \|y\|^2)}\right)\]
- log_map(x, y)[source]¶
Logarithmic map: tangent vector at x pointing toward y.
\[\text{Log}_x(y) = \frac{2}{\lambda_x} \operatorname{arctanh}(\|-x \oplus y\|) \frac{-x \oplus y}{\|-x \oplus y\|}\]
- exp_map(x, v)[source]¶
Exponential map: move from x along tangent vector v.
\[\text{Exp}_x(v) = x \oplus \left( \tanh\!\left(\frac{\lambda_x \|v\|}{2}\right) \frac{v}{\|v\|}\right)\]
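The three formulas above rely on Möbius addition ⊕. A sketch for curvature -1 (matching \(\lambda_x = 2/(1-\|x\|^2)\)), using the standard Möbius addition formula from Ganea et al. (2018); function names are hypothetical. Because \(x \oplus (-x \oplus y) = y\), Exp_x(Log_x(y)) recovers y, which makes a convenient check. poincare_log_map assumes y ≠ x and poincare_exp_map assumes v ≠ 0, since both divide by a norm.

```python
import jax.numpy as jnp

def mobius_add(x, y):
    """Möbius addition x ⊕ y in the Poincaré ball (curvature -1)."""
    xy, x2, y2 = jnp.dot(x, y), jnp.dot(x, x), jnp.dot(y, y)
    num = (1 + 2 * xy + y2) * x + (1 - x2) * y
    return num / (1 + 2 * xy + x2 * y2)

def conformal_factor(x):
    return 2.0 / (1.0 - jnp.dot(x, x))  # lambda_x

def poincare_distance(x, y):
    denom = (1 - jnp.dot(x, x)) * (1 - jnp.dot(y, y))
    return jnp.arccosh(1 + 2 * jnp.sum((x - y) ** 2) / denom)

def poincare_log_map(x, y):
    u = mobius_add(-x, y)  # assumes y != x
    nu = jnp.linalg.norm(u)
    return (2.0 / conformal_factor(x)) * jnp.arctanh(nu) * u / nu

def poincare_exp_map(x, v):
    nv = jnp.linalg.norm(v)  # assumes v != 0
    return mobius_add(x, jnp.tanh(conformal_factor(x) * nv / 2) * v / nv)

x = jnp.array([0.1, -0.2])
y = jnp.array([0.3, 0.4])
```

The distance also satisfies \(d(x, y) = 2\operatorname{artanh}(\|-x \oplus y\|)\); from the origin this reduces to \(2\operatorname{artanh}(\|y\|)\).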
- random_point(key)[source]¶
Generate a random point uniformly in the Poincaré ball.
Uses the radial transform: sample a direction uniformly on S^{d-1}, then a radius r ~ Beta(d, 1) (equivalently r = U^{1/d} with U ~ Uniform(0, 1)) to get a uniform distribution in B^d, scaled to stay safely inside the ball.
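For a point uniform in the Euclidean unit d-ball, the radius needs density proportional to \(r^{d-1}\), i.e. \(r = U^{1/d}\) with \(U \sim \text{Uniform}(0, 1)\), which is the Beta(d, 1) distribution. A sketch with a hypothetical max_norm shrink factor standing in for the "stay safely inside" scaling:

```python
import jax
import jax.numpy as jnp

def random_ball_point(key, d, max_norm=1.0 - 1e-5):
    """Sketch: uniform sample in the open unit d-ball, shrunk by a hypothetical max_norm."""
    k_dir, k_rad = jax.random.split(key)
    z = jax.random.normal(k_dir, (d,))
    direction = z / jnp.linalg.norm(z)  # isotropic Gaussian -> uniform direction on S^{d-1}
    r = jax.random.uniform(k_rad) ** (1.0 / d)  # r = U^{1/d}, i.e. r ~ Beta(d, 1)
    return max_norm * r * direction

keys = jax.random.split(jax.random.PRNGKey(0), 2000)
points = jax.vmap(lambda k: random_ball_point(k, 3))(keys)
```

For d = 3, about 1/8 of the points should fall within radius 0.5, since the inner ball holds \(0.5^3\) of the volume.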
Protocols¶
Structural typing protocols for prosemble interfaces.
Defines typing.Protocol contracts for duck-typed interfaces used
across the library. These enable static type checking (mypy / pyright)
and IDE auto-completion without requiring inheritance.
Note
Runtime-checkable protocols (Manifold, CallbackLike) support
isinstance() checks. Type aliases (DistanceMatrixFn, etc.)
are for annotation only.
- class prosemble.core.protocols.Manifold(*args, **kwargs)[source]¶
Protocol for Riemannian manifold implementations.
Any object exposing the methods below can be used wherever a manifold is expected (e.g. RiemannianNeuralGas, RiemannianSRNG).
The concrete implementations SO, SPD, and Grassmannian all satisfy this protocol structurally; no explicit subclassing is required.
- distance(p, q)[source]¶
Geodesic distance between two points.
- Parameters:
p (array of shape point_shape)
q (array of shape point_shape)
- Return type:
scalar
- distance_squared(p, q)[source]¶
Squared geodesic distance between two points.
- Parameters:
p (array of shape point_shape)
q (array of shape point_shape)
- Return type:
scalar
- log_map(base, target)[source]¶
Logarithmic map: tangent vector at base pointing toward target.
- Parameters:
base (array of shape point_shape) – Base point on the manifold.
target (array of shape point_shape) – Target point on the manifold.
- Returns:
tangent – Tangent vector in \(T_{\text{base}} M\).
- Return type:
array of shape point_shape
- exp_map(base, tangent)[source]¶
Exponential map: move along tangent from base back to the manifold.
- Parameters:
base (array of shape point_shape)
tangent (array of shape point_shape)
- Returns:
point
- Return type:
array of shape point_shape
- random_point(key)[source]¶
Sample a random point on the manifold.
- Parameters:
key (JAX PRNG key)
- Returns:
point
- Return type:
array of shape point_shape
- belongs(point)[source]¶
Check whether point lies on the manifold.
- Parameters:
point (array of shape point_shape)
- Return type:
bool or bool-valued array
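Because Manifold is a typing.Protocol, any class with matching methods conforms without inheriting from it. The sketch defines its own reduced, hypothetical protocol (ManifoldLike, with three of the methods listed above) and a flat EuclideanSpace to show the isinstance behavior. Note that runtime_checkable only checks that the methods exist, not their signatures or return types.

```python
from typing import Protocol, runtime_checkable

import jax.numpy as jnp

@runtime_checkable
class ManifoldLike(Protocol):
    """Hypothetical reduced protocol: three of the methods listed above."""
    def distance(self, p, q): ...
    def log_map(self, base, target): ...
    def exp_map(self, base, tangent): ...

class EuclideanSpace:
    """Flat space; note that it does not inherit from ManifoldLike."""
    def distance(self, p, q):
        return jnp.linalg.norm(p - q)

    def log_map(self, base, target):
        return target - base  # tangent vector is just the difference

    def exp_map(self, base, tangent):
        return base + tangent  # geodesics are straight lines

m = EuclideanSpace()
```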
- class prosemble.core.protocols.CallbackLike(*args, **kwargs)[source]¶
Protocol for training callbacks.
Any object with the three hook methods below can be passed in the callbacks list of a model’s constructor. The existing Callback base class already satisfies this protocol.
- prosemble.core.protocols.DistanceMatrixFn¶
(X, Y) -> distances. X has shape (n_samples, n_features), Y has shape (n_prototypes, n_features), and the result has shape (n_samples, n_prototypes).
- Type:
Distance-matrix function
- prosemble.core.protocols.DistancePairwiseFn¶
(x, y) -> scalar.
- Type:
Pairwise distance function
Distributed Training¶
Distributed training utilities for multi-device data parallelism.
Provides functions for sharding data across devices and replicating model parameters, enabling data-parallel training on multi-GPU/TPU setups.
- prosemble.core.distributed.create_mesh(devices=None)[source]¶
Create a 1D device mesh for data parallelism.
- Parameters:
devices (list of jax.Device or None) – Devices to use. If None, returns None (single-device mode).
- Return type:
Mesh or None
- prosemble.core.distributed.shard_data(X, y, mesh)[source]¶
Shard data arrays along the batch dimension across devices.
- Parameters:
X (jnp.ndarray of shape (n_samples, ...)) – Input data.
y (jnp.ndarray of shape (n_samples,)) – Labels.
mesh (Mesh) – Device mesh from create_mesh().
- Returns:
X_sharded, y_sharded
- Return type:
tuple of sharded arrays
- prosemble.core.distributed.replicate_params(params, mesh)[source]¶
Replicate params across all devices (no partitioning).
- Parameters:
params (dict (pytree)) – Model parameters.
mesh (Mesh) – Device mesh from create_mesh().
- Return type:
params replicated across mesh
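The sketch below shows the kind of layout these helpers set up, written directly against jax.sharding rather than prosemble's functions (whose internals are assumed, not shown): a 1D mesh with a hypothetical axis name "data", batch arrays split along axis 0, and parameters replicated with an empty PartitionSpec. The batch size must divide evenly across devices. It also runs on a single CPU device.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = jax.devices()
n_dev = len(devices)
mesh = Mesh(np.array(devices), axis_names=("data",))  # 1D mesh; axis name is arbitrary

X = jnp.ones((8 * n_dev, 3))  # batch size divisible by the device count
y = jnp.arange(8 * n_dev)

# Shard axis 0 (samples) over the "data" axis; the feature axis is not split
X_sharded = jax.device_put(X, NamedSharding(mesh, PartitionSpec("data", None)))
y_sharded = jax.device_put(y, NamedSharding(mesh, PartitionSpec("data")))

# An empty PartitionSpec replicates each parameter on every device
replicated = NamedSharding(mesh, PartitionSpec())
params = {"prototypes": jnp.zeros((4, 3)), "omega": jnp.eye(3)}
params_rep = jax.tree_util.tree_map(lambda a: jax.device_put(a, replicated), params)
```

With data sharded and parameters replicated, a jitted loss over the full batch can let JAX insert the cross-device gradient reduction automatically.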
Data Utilities¶
Data loading and batching utilities for prosemble.
Provides composable primitives for mini-batch iteration, suitable for
custom training loops. The built-in fit() methods already handle
batching internally — these utilities are for advanced users who need
explicit control over data iteration.
- prosemble.core.data.shuffle_arrays(key, *arrays)[source]
Shuffle multiple arrays with the same random permutation.
- Parameters:
key (JAX PRNG key) – Random key for generating the permutation.
*arrays (jnp.ndarray) – Arrays to shuffle. All must have the same length along axis 0.
- Returns:
Shuffled arrays in the same order as the inputs.
- Return type:
tuple of jnp.ndarray
Examples
>>> key = jax.random.PRNGKey(0)
>>> X = jnp.arange(12).reshape(4, 3)
>>> y = jnp.array([0, 1, 0, 1])
>>> X_s, y_s = shuffle_arrays(key, X, y)
- prosemble.core.data.padded_batches(X, y=None, batch_size=32, key=None)[source]
Split data into static-shaped batches, padding the last batch.
Returns arrays of shape (n_batches, batch_size, ...) suitable for jax.lax.scan. If the data length is not divisible by batch_size, the last batch is padded by repeating initial samples.
- Parameters:
X (jnp.ndarray of shape (n_samples, ...)) – Feature array.
y (jnp.ndarray of shape (n_samples,), optional) – Label array.
batch_size (int) – Number of samples per batch.
key (JAX PRNG key, optional) – If provided, data is shuffled before batching.
- Returns:
X_batches (jnp.ndarray of shape (n_batches, batch_size, …))
y_batches (jnp.ndarray of shape (n_batches, batch_size) or None)
- Return type:
tuple
Examples
>>> X = jnp.ones((10, 3))
>>> y = jnp.arange(10)
>>> X_b, y_b = padded_batches(X, y, batch_size=4)
>>> X_b.shape  # 10 samples padded to 12
(3, 4, 3)
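A sketch of the padding scheme, assuming "repeating initial samples" means wrapping around to the start of the data (padded_batches_sketch is hypothetical and omits labels and shuffling). The fixed (n_batches, batch_size, ...) shape is what lets lax.scan iterate over batches.

```python
import jax
import jax.numpy as jnp

def padded_batches_sketch(X, batch_size):
    """Hypothetical sketch: pad to a multiple of batch_size by wrapping to the start."""
    n = X.shape[0]
    n_batches = -(-n // batch_size)  # ceiling division
    pad = n_batches * batch_size - n
    idx = jnp.concatenate([jnp.arange(n), jnp.arange(pad) % n])  # reuse samples 0, 1, ...
    return X[idx].reshape((n_batches, batch_size) + X.shape[1:])

X = jnp.arange(30.0).reshape(10, 3)
Xb = padded_batches_sketch(X, batch_size=4)  # shape (3, 4, 3)
# Static batch shapes let lax.scan loop over batches in one compiled program
total, _ = jax.lax.scan(lambda acc, b: (acc + b.sum(), None), jnp.zeros(()), Xb)
```

Note that padded samples are seen twice per epoch; here rows 0 and 1 fill the last batch, so total is 450 rather than 435.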
- prosemble.core.data.batched_iterator(X, y=None, batch_size=32, shuffle=True, key=None, drop_last=False)[source]
Yield mini-batches from data arrays.
For use in custom Python training loops.
- Parameters:
X (jnp.ndarray of shape (n_samples, ...)) – Feature array.
y (jnp.ndarray of shape (n_samples,), optional) – Label array.
batch_size (int) – Number of samples per batch.
shuffle (bool) – If True and key is provided, shuffle before iterating.
key (JAX PRNG key, optional) – Required if shuffle=True.
drop_last (bool) – If True, drop the last batch when it is smaller than batch_size. If False (default), the last batch may be smaller.
- Yields:
X_batch (jnp.ndarray of shape (batch_size, …) or smaller)
y_batch (jnp.ndarray of shape (batch_size,) or None)
Examples
>>> key = jax.random.PRNGKey(0)
>>> X = jnp.ones((10, 3))
>>> for X_b, y_b in batched_iterator(X, batch_size=4, key=key):
...     print(X_b.shape)
(4, 3)
(4, 3)
(2, 3)
Custom Optimizers¶
Custom optax-compatible optimizers for prototype-based learning.
Provides specialized gradient transformations designed for the geometry and parameter structure of LVQ models:
per_group_clip: Per-parameter-group gradient norm clipping.
hypergradient_descent: Adaptive per-parameter learning rates via gradient correlation (Baydin et al. 2017).
riemannian_nesterov: Nesterov accelerated gradient with manifold-aware momentum (parallel transport on Riemannian manifolds).
All transformations follow the optax GradientTransformation interface
and can be composed via optax.chain() or passed directly to any model’s
optimizer parameter.
References
Baydin, A. G., et al. (2017). Online learning rate adaptation with hypergradient descent. arXiv:1703.04782.
Absil, P.-A., Mahony, R., & Sepulchre, R. (2008). Optimization Algorithms on Matrix Manifolds. Princeton University Press.
- class prosemble.core.optimizers.PerGroupClipState[source]¶
State for per-group gradient clipping (stateless).
- prosemble.core.optimizers.per_group_clip(max_norms)[source]¶
Clip gradient norms independently per parameter group.
Different parameter types (prototypes, omega matrices, relevances, sigmas) have different natural gradient scales. A single global clip threshold is either too loose for the small-scale groups or too tight for the large-scale ones; this transformation clips each group independently.
- Parameters:
max_norms (dict) – Mapping from parameter key name to maximum gradient norm. Keys not present in this dict are left unclipped. Example: {'prototypes': 1.0, 'omega': 0.5, 'sigmas': 0.1}
- Returns:
Composable gradient transformation.
- Return type:
optax.GradientTransformation
Examples
>>> import optax
>>> from prosemble.core.optimizers import per_group_clip
>>> optimizer = optax.chain(
...     per_group_clip({'prototypes': 1.0, 'omega': 0.5, 'sigmas': 0.1}),
...     optax.adam(0.01),
... )
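The per-group rule can be illustrated on a plain dict of gradient arrays. This NumPy sketch assumes each group is clipped by the L2 (Frobenius) norm of its whole array, a common convention that is not spelled out above; clip_per_group is a hypothetical name, not the optax transformation itself.

```python
import numpy as np

def clip_per_group(grads, max_norms, eps=1e-12):
    """Rescale each gradient group so its L2 norm does not exceed max_norms[key].

    Groups whose key is absent from max_norms are returned unchanged.
    """
    clipped = {}
    for key, g in grads.items():
        if key not in max_norms:
            clipped[key] = g                          # unclipped group
            continue
        norm = np.linalg.norm(g)                      # L2 norm over the whole array
        scale = min(1.0, max_norms[key] / (norm + eps))
        clipped[key] = g * scale
    return clipped

grads = {
    "prototypes": np.full((2, 2), 2.5),   # norm 5.0, clipped down to 1.0
    "omega": np.array([0.3, 0.4]),        # norm 0.5, already within 0.5
    "other": np.array([10.0]),            # no entry in max_norms, left alone
}
out = clip_per_group(grads, {"prototypes": 1.0, "omega": 0.5})
```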
- class prosemble.core.optimizers.HypergradientState(learning_rates, prev_grads, base_opt_state)[source]¶
State for hypergradient descent optimizer.
- prosemble.core.optimizers.hypergradient_descent(init_lr=0.01, hyper_lr=0.0001, inner_optimizer='sgd', min_lr=1e-06, max_lr=1.0)[source]¶
Adaptive per-parameter learning rates via hypergradient descent.
If consecutive gradients point in the same direction (positive dot product), the learning rate is increased; if they oscillate (negative dot product), it is decreased. This lets each parameter group adapt its own learning rate.
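The update rule for the learning rate eta_k of parameter group k at step t, with beta = hyper_lr and clipping to [min_lr, max_lr]:
\[\eta_k^{t+1} = \text{clip}\left( \eta_k^t + \beta \cdot \langle g_k^t, g_k^{t-1} \rangle,\ \eta_{\min},\ \eta_{\max} \right)\]
- Parameters: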
init_lr (float) – Initial learning rate for all parameter groups. Default: 0.01.
hyper_lr (float) – Learning rate for the learning rate update (meta-learning rate). Default: 1e-4.
inner_optimizer (str) – Base optimizer to use (‘sgd’ applies raw scaled gradients). Default: ‘sgd’.
min_lr (float) – Minimum allowed learning rate. Default: 1e-6.
max_lr (float) – Maximum allowed learning rate. Default: 1.0.
- Return type:
optax.GradientTransformation
References
[1] Baydin, A. G., et al. (2017). Online learning rate adaptation with hypergradient descent. arXiv:1703.04782.
Examples
>>> from prosemble.core.optimizers import hypergradient_descent
>>> optimizer = hypergradient_descent(init_lr=0.01, hyper_lr=1e-4)
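The update rule can be traced with a small NumPy loop that keeps one learning rate per parameter group and uses plain SGD as the inner step. This is a sketch under stated assumptions (dict-of-arrays parameters, a user-supplied gradient function, the hypothetical name hypergradient_sgd), not the optax transformation returned by hypergradient_descent.

```python
import numpy as np

def hypergradient_sgd(grad_fn, params, init_lr=0.01, hyper_lr=1e-4,
                      min_lr=1e-6, max_lr=1.0, n_steps=100):
    """SGD with one learning rate per parameter group, adapted by hypergradient descent."""
    params = {k: np.asarray(v, dtype=float) for k, v in params.items()}
    lrs = {k: init_lr for k in params}
    prev_grads = {k: np.zeros_like(v) for k, v in params.items()}
    for _ in range(n_steps):
        grads = grad_fn(params)
        for k in params:
            # <g^t, g^{t-1}> > 0 (consistent direction) raises eta; < 0 (oscillation) lowers it
            corr = float(np.vdot(grads[k], prev_grads[k]))
            lrs[k] = float(np.clip(lrs[k] + hyper_lr * corr, min_lr, max_lr))
            params[k] = params[k] - lrs[k] * grads[k]
        prev_grads = grads
    return params, lrs

# f(w) = 0.5 * ||w||^2 has gradient w; consecutive gradients stay aligned, so eta grows.
params, lrs = hypergradient_sgd(lambda p: {"w": p["w"]}, {"w": np.ones(2)}, hyper_lr=1e-3)
```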
- class prosemble.core.optimizers.RiemannianNesterovState(velocity, step)[source]¶
State for Riemannian Nesterov momentum.
- prosemble.core.optimizers.riemannian_nesterov(learning_rate=0.01, momentum=0.9)[source]¶
Nesterov accelerated gradient adapted for prototype-based models.
Implements Nesterov momentum in Euclidean parameter space. For Riemannian models, the prototypes are stored in flattened form and the manifold projection is handled by _post_update().
On smooth convex objectives, Nesterov's method (with a suitable momentum schedule) converges at rate O(1/t^2), versus O(1/t) for plain gradient descent.
Update rule:
\[\begin{aligned} v_{t+1} &= \mu \cdot v_t + g_t \\ \theta_{t+1} &= \theta_t - \eta \cdot (\mu \cdot v_{t+1} + g_t) \end{aligned}\]
This is the Nesterov variant in which the lookahead is folded into the update (the Sutskever et al. 2013 reformulation).
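- Parameters:
learning_rate (float) – Step size eta. Default: 0.01.
momentum (float) – Momentum coefficient mu. Default: 0.9.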
- Return type:
optax.GradientTransformation
Notes
For Riemannian models where prototypes live on manifolds, the manifold retraction (projection back to manifold) is handled by the model’s
_post_update()method. This optimizer provides the accelerated gradient direction; the model ensures the result stays on the manifold.For true Riemannian Nesterov (with parallel transport), use this optimizer with Riemannian models that implement
_post_updatewith manifold projection.Examples
>>> from prosemble.core.optimizers import riemannian_nesterov >>> optimizer = riemannian_nesterov(learning_rate=0.01, momentum=0.9)
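The two-line update rule and the projection hand-off described in the Notes can be sketched in NumPy. Here project_to_unit_sphere is a hypothetical stand-in for a model's _post_update() projection, and the unit sphere is a simple manifold chosen for illustration (it is not one of prosemble's SO, SPD, or Grassmannian classes).

```python
import numpy as np

def nesterov_step(theta, velocity, grad, learning_rate=0.01, momentum=0.9):
    """v_{t+1} = mu * v_t + g_t;  theta_{t+1} = theta_t - eta * (mu * v_{t+1} + g_t)."""
    velocity = momentum * velocity + grad
    theta = theta - learning_rate * (momentum * velocity + grad)
    return theta, velocity

def project_to_unit_sphere(theta):
    """Hypothetical stand-in for a model's _post_update() manifold projection."""
    return theta / np.linalg.norm(theta)

# Minimize f(theta) = 0.5 * ||theta - target||^2 subject to ||theta|| = 1.
target = np.array([0.0, 1.0])
theta = np.array([1.0, 0.0])
v = np.zeros_like(theta)
for _ in range(200):
    grad = theta - target                             # Euclidean gradient
    theta, v = nesterov_step(theta, v, grad, learning_rate=0.1)
    theta = project_to_unit_sphere(theta)             # retraction after the Euclidean step
print(np.round(theta, 3))                             # approaches the target [0, 1]
```

The velocity v lives in Euclidean coordinates; only the projection keeps the iterate on the manifold.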
Regularization¶
Regularization techniques for prototype-based models.
Provides reusable regularization terms that can be composed with any model’s training step:
prototype_diversity_loss: DPP-inspired repulsion between same-class prototypes, added to the loss in _compute_loss.
sparse_relevance_proximal: L1 proximal step (soft-thresholding) for relevance vectors, applied after the gradient step.
elastic_net_proximal: Elastic net (L1 + L2) proximal step for relevance vectors.
References
Kulesza, A. & Taskar, B. (2012). Determinantal Point Processes for Machine Learning. Foundations and Trends in Machine Learning.
Tibshirani, R. (1996). Regression shrinkage and selection via the LASSO. JRSS-B.
- prosemble.core.regularization.prototype_diversity_loss(prototypes, proto_labels, sigma_div=1.0)[source]¶
Compute DPP-inspired diversity regularization for same-class prototypes.
Encourages multiple prototypes per class to spread out rather than collapsing to the same point. Uses the log-determinant of the RBF kernel matrix between same-class prototypes.
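\[L_{\text{diversity}} = -\sum_c \log\det(K_c + \epsilon I)\]
where \(K_c[i,j] = \exp(-\|w_i - w_j\|^2 / (2\sigma^2))\), i and j range over the prototypes of class c, \(\sigma\) is sigma_div, and \(\epsilon I\) is a small diagonal jitter.
The determinant grows as the prototypes spread out (DPP theory). When two prototypes collapse, K_c becomes singular, det(K_c + εI) drops toward its minimum, and -log det grows sharply; without the εI term it would diverge to infinity.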
- Parameters:
prototypes (array of shape (n_prototypes, n_features)) – Prototype positions.
proto_labels (array of shape (n_prototypes,)) – Class label for each prototype.
sigma_div (float) – Bandwidth for the RBF kernel. Controls the scale at which diversity is measured. Default: 1.0.
- Returns:
loss – Diversity penalty (lower is more diverse).
- Return type:
scalar
Notes
Only active when n_prototypes_per_class > 1. For single-prototype classes, the contribution is zero (log(det(1x1 identity)) = 0).
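A loop-based NumPy version of the formula makes the collapse penalty concrete. The jitter eps=1e-6 is an assumption (the documentation does not give the value of epsilon), and diversity_loss_sketch is a hypothetical name.

```python
import numpy as np

def diversity_loss_sketch(prototypes, proto_labels, sigma_div=1.0, eps=1e-6):
    """-sum_c log det(K_c + eps * I) with an RBF kernel over same-class prototypes."""
    loss = 0.0
    for c in np.unique(proto_labels):
        W = prototypes[proto_labels == c]
        sq = np.sum((W[:, None, :] - W[None, :, :]) ** 2, axis=-1)  # ||w_i - w_j||^2
        K = np.exp(-sq / (2.0 * sigma_div ** 2))
        _, logdet = np.linalg.slogdet(K + eps * np.eye(len(W)))    # K + eps*I is positive definite
        loss -= logdet
    return loss

labels = np.array([0, 0])
spread = np.array([[0.0, 0.0], [3.0, 0.0]])
collapsed = np.array([[0.0, 0.0], [0.01, 0.0]])
print(diversity_loss_sketch(spread, labels))     # close to 0
print(diversity_loss_sketch(collapsed, labels))  # much larger (about 9)
```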
- prosemble.core.regularization.prototype_diversity_loss_vectorized(prototypes, proto_labels, sigma_div=1.0, n_classes=None, max_protos_per_class=None)[source]¶
JIT-friendly vectorized diversity loss without Python loops.
For use within lax.scan training (use_scan=True). Requires static shape information.
- Parameters:
prototypes (array of shape (n_prototypes, n_features))
proto_labels (array of shape (n_prototypes,))
sigma_div (float) – RBF bandwidth. Default: 1.0.
n_classes (int) – Number of classes (must be known at compile time).
max_protos_per_class (int) – Maximum prototypes in any class (must be known at compile time).
- Returns:
loss
- Return type:
scalar
- prosemble.core.regularization.sparse_relevance_proximal(relevances, l1_weight, lr=1.0)[source]¶
Apply proximal operator for L1 regularization (soft-thresholding).
Enforces true sparsity in relevance vectors by applying the proximal mapping of the L1 norm after the gradient step:
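\[\text{prox}_{\alpha\|\cdot\|_1}(x)_j = \text{sign}(x_j) \cdot \max(|x_j| - \alpha, 0)\]
where \(\alpha\) = l1_weight * lr. Unlike a subgradient step on an L1 penalty, which only pushes small entries toward zero, the proximal step sets them exactly to zero, giving genuine feature selection as in the LASSO.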
- Parameters:
relevances (array of shape (n_features,) or (n_prototypes, n_features)) – Relevance logits (before softmax normalization).
l1_weight (float) – L1 penalty strength. Larger values give sparser solutions.
lr (float) – Learning rate (used to scale the threshold: threshold = l1_weight * lr). Default: 1.0.
- Returns:
sparse_relevances – Thresholded relevances with exact zeros.
- Return type:
array, same shape as input
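Soft-thresholding is a one-liner in NumPy; soft_threshold is a hypothetical helper name used for illustration. Entries whose magnitude is at most the threshold come out as exact zeros, and the rest are shrunk toward zero by the threshold.

```python
import numpy as np

def soft_threshold(relevances, l1_weight, lr=1.0):
    """Proximal map of alpha * ||x||_1 with alpha = l1_weight * lr."""
    alpha = l1_weight * lr
    return np.sign(relevances) * np.maximum(np.abs(relevances) - alpha, 0.0)

r = np.array([0.8, -0.05, 0.02, -0.6])
sparse = soft_threshold(r, l1_weight=0.1)   # entries with |x| <= 0.1 become exact zeros
```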
- prosemble.core.regularization.elastic_net_proximal(relevances, l1_weight, l2_weight, lr=1.0)[source]¶
Apply elastic net proximal operator (L1 + L2 regularization).
Combines soft-thresholding (L1 sparsity) with L2 shrinkage:
\[\text{prox}(x)_j = \frac{\text{sign}(x_j) \max(|x_j| - \alpha, 0)}{1 + \beta}\]
where \(\alpha\) = l1_weight * lr and \(\beta\) = l2_weight * lr.
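The elastic net step is the soft-thresholding step followed by a uniform shrinkage by 1 / (1 + beta). A hypothetical NumPy sketch of the formula (elastic_net_prox is an illustrative name):

```python
import numpy as np

def elastic_net_prox(relevances, l1_weight, l2_weight, lr=1.0):
    """Soft-threshold by alpha, then divide by (1 + beta)."""
    alpha = l1_weight * lr
    beta = l2_weight * lr
    shrunk = np.sign(relevances) * np.maximum(np.abs(relevances) - alpha, 0.0)
    return shrunk / (1.0 + beta)

out = elastic_net_prox(np.array([1.0, -0.05]), l1_weight=0.1, l2_weight=1.0)  # about [0.45, 0.0]
```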
Curriculum Learning¶
Curriculum learning (self-paced learning) for prototype-based models.
Implements the self-paced learning framework (Kumar et al. 2010) adapted for LVQ models. Samples are presented in order of difficulty — easy samples first, hard samples later.
For LVQ, sample difficulty is measured by the absolute value of the mu-ratio, which lies in [-1, 1]: |mu(x)| close to 0 means the sample is on the decision boundary (hard), while |mu(x)| close to 1 means it is far from the boundary (easy).
Usage¶
Curriculum learning is applied via sample masking in the loss function. At each training step, a difficulty threshold lambda_t determines which samples participate. The threshold increases over training to gradually include harder samples.
References
Kumar, M. P., Packer, B., & Koller, D. (2010). Self-paced learning for latent variable models. NeurIPS.
Bengio, Y., Louradour, J., Collobert, R., & Weston, J. (2009). Curriculum learning. ICML.
- prosemble.core.curriculum.curriculum_weights(per_sample_loss, threshold, mode='hard')[source]¶
Compute per-sample curriculum weights based on loss magnitude.
- Parameters:
per_sample_loss (array of shape (n_samples,)) – Per-sample loss values (e.g., from GLVQ mu-ratio).
threshold (float) – Difficulty threshold lambda_t. Samples with loss below this threshold are included. Increases over training.
mode (str, {'hard', 'soft', 'linear'}) – Weight assignment mode:
- ‘hard’: binary weights (0 or 1); a sample is either in or out.
- ‘soft’: smooth exponential weighting. w_i = exp(-loss_i / threshold) if loss_i > threshold, else 1.
- ‘linear’: linear ramp-down. w_i = max(0, 1 - (loss_i - threshold) / threshold).
Default: ‘hard’.
- Returns:
weights – Per-sample weights in [0, 1]. The sum of the weights is used to normalize the loss.
- Return type:
array of shape (n_samples,)
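The three documented modes can be written out directly. This NumPy sketch is hypothetical in two places: it treats loss == threshold as included, and it sets the 'linear' weight to 1 below the threshold (the ramp formula alone would exceed 1 there, while the weights are documented to lie in [0, 1]). The last line applies the documented normalization by the weight sum.

```python
import numpy as np

def curriculum_weights_sketch(per_sample_loss, threshold, mode="hard"):
    loss = np.asarray(per_sample_loss, dtype=float)
    included = loss <= threshold                     # assumed tie handling
    if mode == "hard":
        return included.astype(float)
    if mode == "soft":
        return np.where(included, 1.0, np.exp(-loss / threshold))
    if mode == "linear":
        ramp = 1.0 - (loss - threshold) / threshold
        return np.where(included, 1.0, np.maximum(0.0, ramp))
    raise ValueError(f"unknown mode: {mode!r}")

loss = np.array([0.1, 0.5, 0.75, 2.0])
w = curriculum_weights_sketch(loss, threshold=0.5, mode="linear")   # [1, 1, 0.5, 0]
weighted_loss = np.sum(w * loss) / np.sum(w)                        # normalized by the weight sum
```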
- prosemble.core.curriculum.curriculum_threshold(iteration, max_iter, init_threshold=0.3, final_threshold=None, schedule='linear')[source]¶
Compute the difficulty threshold at a given training iteration.
The threshold increases over training: initially only easy samples are included, and harder samples are gradually added.
- Parameters:
iteration (int or array) – Current training iteration.
max_iter (int) – Maximum number of iterations.
init_threshold (float) – Initial threshold (fraction of max loss to include). At the start, only samples with loss < init_threshold * max_loss are included. Default: 0.3.
final_threshold (float, optional) – Final threshold. Default: None (uses a large value to include all samples by end of training).
schedule (str, {'linear', 'exponential', 'cosine'}) – How the threshold grows over training:
- ‘linear’: lambda_t = init + (final - init) * t / T
- ‘exponential’: lambda_t = init * (final / init) ^ (t / T)
- ‘cosine’: cosine annealing from init to final.
Default: ‘linear’.
- Returns:
threshold – Difficulty threshold at the current iteration.
- Return type:
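The linear and exponential schedules documented above translate directly into code. In this NumPy sketch the explicit final_threshold=1.0 stands in for the library's unspecified large default, and the exact cosine formula is an assumption, since the documentation only says "cosine annealing from init to final". threshold_schedule is a hypothetical name.

```python
import numpy as np

def threshold_schedule(iteration, max_iter, init_threshold=0.3, final_threshold=1.0,
                       schedule="linear"):
    """Difficulty threshold lambda_t; final_threshold=1.0 is an illustrative choice."""
    frac = iteration / max_iter                      # t / T
    if schedule == "linear":
        return init_threshold + (final_threshold - init_threshold) * frac
    if schedule == "exponential":
        return init_threshold * (final_threshold / init_threshold) ** frac
    if schedule == "cosine":
        # Assumed form: half-cosine ramp, init at t = 0 and final at t = T.
        half_cos = 0.5 * (1.0 + np.cos(np.pi * frac))
        return final_threshold + (init_threshold - final_threshold) * half_cos
    raise ValueError(f"unknown schedule: {schedule!r}")

print(threshold_schedule(50, 100))                        # linear midpoint, about 0.65
print(threshold_schedule(50, 100, schedule="exponential"))  # geometric midpoint, sqrt(0.3)
```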
- prosemble.core.curriculum.apply_curriculum_to_loss(per_sample_losses, iteration, max_iter, init_percentile=0.3, schedule='linear', mode='soft')[source]¶
Full curriculum pipeline: compute threshold and weighted loss.
Convenience function that combines threshold scheduling and weight computation. Use inside _compute_loss to add curriculum learning to any model.
- Parameters:
per_sample_losses (array of shape (n_samples,)) – Per-sample loss values.
iteration (int or array) – Current training step.
max_iter (int) – Total training steps.
init_percentile (float) – Initial fraction of samples to include (by loss quantile). Default: 0.3 (include the easiest 30% of samples initially).
schedule (str) – Threshold growth schedule. Default: ‘linear’.
mode (str) – Weighting mode. Default: ‘soft’.
- Returns:
weighted_loss – Curriculum-weighted mean loss.
- Return type:
scalar
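One way to assemble the pipeline is shown below. Several details are assumptions rather than documented behavior: the included quantile grows linearly from init_percentile to 1.0, the threshold is that quantile of the current losses, and 'hard' weighting is used for brevity (the library default is 'soft'). curriculum_loss_sketch is a hypothetical name.

```python
import numpy as np

def curriculum_loss_sketch(per_sample_losses, iteration, max_iter, init_percentile=0.3):
    losses = np.asarray(per_sample_losses, dtype=float)
    frac = iteration / max_iter
    # Assumed schedule: quantile rises linearly from init_percentile (t = 0) to 1.0 (t = T).
    q = 1.0 - (1.0 - init_percentile) * (1.0 - frac)
    threshold = np.quantile(losses, q)
    weights = (losses <= threshold).astype(float)     # 'hard' weighting for brevity
    return float(np.sum(weights * losses) / np.sum(weights))

losses = np.arange(10.0)
print(curriculum_loss_sketch(losses, iteration=0, max_iter=100))    # 1.0 (losses 0, 1, 2)
print(curriculum_loss_sketch(losses, iteration=100, max_iter=100))  # 4.5 (all samples)
```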
Reject Option¶
Reject option with calibrated uncertainty for prototype classifiers.
Implements Chow’s optimal reject rule (1970) adapted for LVQ models. When the confidence in a prediction is below a threshold, the model abstains from classifying the sample (returns -1 or a reject label).
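For LVQ, the confidence measure is the negative mu-ratio:
\[\text{confidence}(x) = -\mu(x) = \frac{d^-(x) - d^+(x)}{d^-(x) + d^+(x)}\]
where d^+(x) and d^-(x) are the distances to the closest prototype of the same class and of a different class, respectively. This is the natural confidence measure from the GLVQ loss. Values near 0 mean the sample is on the decision boundary (uncertain); values near 1 mean high confidence.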
References
Chow, C. K. (1970). On optimum recognition error and reject tradeoff. IEEE Transactions on Information Theory.
Fischer, L., Hammer, B., & Wersing, H. (2015). Rejection strategies for learning vector quantization. Neurocomputing.
- class prosemble.core.reject.RejectOptionMixin[source]¶
Mixin providing a reject option for any prototype-based classifier.
Adds predict_with_rejection(), confidence(), accuracy_coverage_curve(), and optimal_threshold() methods to any supervised model that has fitted prototypes and prototype labels.
The reject decision is based on the relative margin:
\[\text{confidence}(x) = \frac{d^-(x) - d^+(x)}{d^-(x) + d^+(x)}\]
If confidence < threshold, the sample is rejected (label = -1).
- predict_with_rejection(X, threshold) array[source]¶
Predict with rejection. Returns -1 for rejected samples.
- optimal_threshold(X, y, cost_reject, cost_error) float[source]¶
Find the optimal rejection threshold minimizing total risk.
- confidence(X)[source]¶
Compute confidence scores based on the relative margin.
The confidence is the negative of the GLVQ mu-ratio: confidence(x) = (d_minus - d_plus) / (d_minus + d_plus)
Values range from -1 (maximally wrong) to +1 (maximally confident). Values near 0 indicate the sample is on the decision boundary.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Input data.
- Returns:
scores – Confidence scores in [-1, 1].
- Return type:
array of shape (n_samples,)
- predict_with_rejection(X, threshold=0.0)[source]¶
Predict class labels with rejection option.
Samples with confidence below the threshold are assigned the reject label (-1) instead of a class prediction.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Input data.
threshold (float) – Rejection threshold. Samples with confidence < threshold are rejected. Default: 0.0 (reject uncertain samples). Reasonable range: [0.0, 0.5].
- Returns:
labels – Predicted labels. Rejected samples have label -1.
- Return type:
array of shape (n_samples,)
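The relative margin and the reject rule can be sketched in NumPy for a nearest-prototype classifier. Assumptions: squared Euclidean distance stands in for the model's distance, and at prediction time (when true labels are unknown) d^+ is measured to the winning prototype, which makes the score non-negative; with true labels, as during training, the same function spans the full [-1, 1] range. relative_margin and predict_with_rejection_sketch are hypothetical names.

```python
import numpy as np

def relative_margin(X, prototypes, proto_labels, labels):
    """(d^- - d^+) / (d^- + d^+), with d^+ measured to the closest prototype of labels[i]."""
    D = np.sum((X[:, None, :] - prototypes[None, :, :]) ** 2, axis=-1)  # (n_samples, n_protos)
    same = proto_labels[None, :] == labels[:, None]
    d_plus = np.where(same, D, np.inf).min(axis=1)
    d_minus = np.where(same, np.inf, D).min(axis=1)
    return (d_minus - d_plus) / (d_minus + d_plus)

def predict_with_rejection_sketch(X, prototypes, proto_labels, threshold=0.0):
    D = np.sum((X[:, None, :] - prototypes[None, :, :]) ** 2, axis=-1)
    predicted = proto_labels[np.argmin(D, axis=1)]            # winner-takes-all label
    conf = relative_margin(X, prototypes, proto_labels, predicted)
    return np.where(conf < threshold, -1, predicted)

prototypes = np.array([[0.0, 0.0], [4.0, 0.0]])
proto_labels = np.array([0, 1])
X = np.array([[0.5, 0.0], [2.1, 0.0], [3.9, 0.0]])
print(predict_with_rejection_sketch(X, prototypes, proto_labels, threshold=0.5))  # [ 0 -1  1]
```

The middle sample sits almost halfway between the two prototypes, so its margin (about 0.1) falls below the threshold and it is rejected.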
- accuracy_coverage_curve(X, y, n_thresholds=50)[source]¶
Compute accuracy-coverage curve (reject curve).
Shows the trade-off between classification accuracy and coverage (1 - rejection_rate) as the threshold varies.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Input data.
y (array-like of shape (n_samples,)) – True labels.
n_thresholds (int) – Number of threshold values to evaluate. Default: 50.
- Returns:
thresholds (array of shape (n_thresholds,)) – Threshold values evaluated.
accuracies (array of shape (n_thresholds,)) – Accuracy on non-rejected samples at each threshold.
coverages (array of shape (n_thresholds,)) – Fraction of non-rejected samples at each threshold.
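Given confidence scores and predictions, the curve is a sweep over thresholds. The sketch below works on precomputed arrays rather than on a fitted model; accuracy_coverage_sketch is a hypothetical name, and returning NaN accuracy when every sample is rejected is an assumption.

```python
import numpy as np

def accuracy_coverage_sketch(confidences, predictions, y, thresholds):
    """Accuracy on accepted samples and coverage for each rejection threshold."""
    accuracies, coverages = [], []
    for t in thresholds:
        accepted = confidences >= t                      # confidence < t is rejected
        coverages.append(accepted.mean())
        # Assumption: accuracy is undefined (NaN) when every sample is rejected.
        accuracies.append((predictions[accepted] == y[accepted]).mean()
                          if accepted.any() else np.nan)
    return np.array(accuracies), np.array(coverages)

conf = np.array([0.9, 0.1, 0.6, 0.05])
pred = np.array([0, 1, 1, 0])
y = np.array([0, 0, 1, 1])
acc, cov = accuracy_coverage_sketch(conf, pred, y, thresholds=[0.0, 0.5, 0.95])
```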
- optimal_threshold(X, y, cost_reject=0.5, cost_error=1.0, n_thresholds=100)[source]¶
Find the optimal rejection threshold minimizing total risk.
- Minimizes Chow’s risk:
Risk = cost_error * P(error | accepted) * coverage + cost_reject * (1 - coverage)
- Parameters:
X (array-like of shape (n_samples, n_features))
y (array-like of shape (n_samples,)) – True labels.
cost_reject (float) – Cost of rejecting a sample (relative to cost of error). Default: 0.5 (rejection costs half of a misclassification).
cost_error (float) – Cost of a misclassification. Default: 1.0.
n_thresholds (int) – Number of thresholds to evaluate. Default: 100.
- Returns:
optimal_threshold – The threshold that minimizes total risk.
- Return type:
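The risk minimization can be sketched as a grid search over thresholds on precomputed confidences, using Chow's risk with the error term weighted by cost_error. The evenly spaced grid between the smallest and largest confidence is an assumption, and chow_risk and optimal_threshold_sketch are hypothetical names.

```python
import numpy as np

def chow_risk(confidences, predictions, y, threshold, cost_reject=0.5, cost_error=1.0):
    accepted = confidences >= threshold
    coverage = accepted.mean()
    error_rate = (predictions[accepted] != y[accepted]).mean() if accepted.any() else 0.0
    return cost_error * error_rate * coverage + cost_reject * (1.0 - coverage)

def optimal_threshold_sketch(confidences, predictions, y, cost_reject=0.5,
                             cost_error=1.0, n_thresholds=100):
    # Assumed grid: evenly spaced between the smallest and largest confidence.
    thresholds = np.linspace(confidences.min(), confidences.max(), n_thresholds)
    risks = [chow_risk(confidences, predictions, y, t, cost_reject, cost_error)
             for t in thresholds]
    return float(thresholds[int(np.argmin(risks))])

conf = np.array([0.9, 0.1, 0.6, 0.05])
pred = np.array([0, 1, 1, 0])
y = np.array([0, 0, 1, 1])
t_opt = optimal_threshold_sketch(conf, pred, y)   # rejects the two low-confidence errors
```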
Geodesic Interpolation¶
Geodesic interpolation and boundary visualization for Riemannian models.
Computes geodesic paths between prototypes on Riemannian manifolds for decision boundary visualization and model interpretation.
On curved manifolds, the decision boundary (the locus where d(x, w_i) = d(x, w_j)) is in general not a hyperplane but a curved surface. Geodesic interpolation allows us to:
1. Visualize the path between prototypes along the manifold.
2. Find the approximate decision boundary point along geodesics.
3. Compute prototype midpoints respecting manifold geometry.
References
Absil, P.-A., Mahony, R., & Sepulchre, R. (2008). Optimization Algorithms on Matrix Manifolds. Princeton University Press.
- prosemble.core.geodesic.geodesic_interpolation(manifold, point_a, point_b, n_points=50)[source]¶
Compute a geodesic path between two points on a manifold.
Uses the logarithmic and exponential maps to trace the geodesic (locally shortest path) from point_a (w_a) to point_b (w_b):
\[\gamma(t) = \text{Exp}_{w_a}(t \cdot \text{Log}_{w_a}(w_b)), \quad t \in [0, 1]\]
- Parameters:
manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance with exp_map and log_map.
point_a (array) – Starting point on the manifold.
point_b (array) – End point on the manifold.
n_points (int) – Number of interpolation points along the geodesic. Default: 50.
- Returns:
path – Points along the geodesic from point_a to point_b. Each point lies on the manifold.
- Return type:
array of shape (n_points, …)
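The formula can be tried on the unit sphere, whose exponential and logarithmic maps have closed forms. This is a hypothetical sketch: UnitSphere is not one of prosemble's SO, SPD, or Grassmannian classes, and the exp_map(base, v) / log_map(base, x) signatures are assumptions about the interface the function expects.

```python
import numpy as np

class UnitSphere:
    """Hypothetical manifold exposing exp_map and log_map on the unit sphere."""

    @staticmethod
    def exp_map(base, v):
        norm_v = np.linalg.norm(v)
        if norm_v < 1e-12:
            return base
        return np.cos(norm_v) * base + np.sin(norm_v) * v / norm_v

    @staticmethod
    def log_map(base, x):
        cos_theta = np.clip(np.dot(base, x), -1.0, 1.0)
        theta = np.arccos(cos_theta)                  # geodesic distance
        u = x - cos_theta * base                      # component of x tangent at base
        norm_u = np.linalg.norm(u)
        if norm_u < 1e-12:                            # x == base (or antipodal: undefined)
            return np.zeros_like(base)
        return theta * u / norm_u

def geodesic_path(manifold, point_a, point_b, n_points=50):
    """gamma(t) = Exp_a(t * Log_a(b)) sampled at n_points values of t in [0, 1]."""
    v = manifold.log_map(point_a, point_b)
    return np.stack([manifold.exp_map(point_a, t * v) for t in np.linspace(0.0, 1.0, n_points)])

a = np.array([1.0, 0.0, 0.0])
b = np.array([0.0, 1.0, 0.0])
path = geodesic_path(UnitSphere, a, b, n_points=3)   # a, the 45-degree point, b
```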
- prosemble.core.geodesic.geodesic_midpoint(manifold, point_a, point_b)[source]¶
Compute the geodesic midpoint between two manifold points.
The midpoint is at t = 0.5 along the geodesic:
\[m = \text{Exp}_{w_a}(0.5 \cdot \text{Log}_{w_a}(w_b))\]
- Parameters:
manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance.
point_a (array) – First manifold point.
point_b (array) – Second manifold point.
- Returns:
midpoint – Geodesic midpoint, guaranteed to lie on the manifold.
- Return type:
array
- prosemble.core.geodesic.decision_boundary_point(manifold, proto_a, proto_b, n_search=100)[source]¶
Find the decision boundary point along the geodesic between prototypes.
This function evaluates n_search candidate points along the geodesic and returns the one at which the distances to both prototypes are closest to equal. When d is the manifold’s geodesic distance and the geodesic is minimizing, this point lies at t = 0.5, because a constant-speed geodesic satisfies \(d(\gamma(t), w_a) = t \, d(w_a, w_b)\); it can move away from t = 0.5 when the dissimilarity used differs from the geodesic distance.
- Parameters:
manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance.
proto_a (array) – First prototype (on manifold).
proto_b (array) – Second prototype (on manifold).
n_search (int) – Number of candidate points to evaluate. Default: 100.
- Returns:
boundary_point (array) – Point on the geodesic where d(point, proto_a) = d(point, proto_b), up to the resolution set by n_search.
t_boundary (float) – Parameter value t in [0, 1] where the boundary lies (0.5 for a minimizing geodesic under the geodesic distance; other dissimilarities can shift it).
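The search can be illustrated on the unit sphere, where great-circle arcs are the geodesics. In this hypothetical sketch (slerp, sphere_dist, and boundary_point_search are illustrative names, and the sphere stands in for SO, SPD, or Grassmannian), the equidistant candidate lands next to t = 0.5 because the comparison uses the geodesic distance itself.

```python
import numpy as np

def slerp(a, b, t):
    """Point at fraction t along the great-circle geodesic between unit vectors a and b (a != +-b)."""
    omega = np.arccos(np.clip(np.dot(a, b), -1.0, 1.0))
    return (np.sin((1.0 - t) * omega) * a + np.sin(t * omega) * b) / np.sin(omega)

def sphere_dist(x, y):
    """Geodesic (great-circle) distance between unit vectors."""
    return np.arccos(np.clip(np.dot(x, y), -1.0, 1.0))

def boundary_point_search(proto_a, proto_b, dist=sphere_dist, n_search=100):
    ts = np.linspace(0.0, 1.0, n_search)
    candidates = [slerp(proto_a, proto_b, t) for t in ts]
    gaps = [abs(dist(p, proto_a) - dist(p, proto_b)) for p in candidates]
    k = int(np.argmin(gaps))                       # candidate closest to equidistant
    return candidates[k], float(ts[k])

a = np.array([1.0, 0.0, 0.0])
b = np.array([0.0, 0.0, 1.0])
point, t = boundary_point_search(a, b)             # t is within one grid step of 0.5
```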
- prosemble.core.geodesic.prototype_geodesic_distances(manifold, prototypes, proto_labels)[source]¶
Compute pairwise geodesic distances between prototypes.
- Parameters:
manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance.
prototypes (array of shape (n_prototypes, ...)) – Prototype points on the manifold (flattened).
proto_labels (array of shape (n_prototypes,)) – Class labels for prototypes.
- Returns:
distances – Pairwise geodesic distance matrix.
- Return type:
array of shape (n_prototypes, n_prototypes)
- prosemble.core.geodesic.inter_class_geodesics(manifold, prototypes, proto_labels, n_points=50)[source]¶
Compute geodesic paths between all inter-class prototype pairs.
Useful for visualizing decision boundaries: the boundary between two classes lies somewhere along the geodesic connecting their closest prototypes.
- Parameters:
manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance.
prototypes (array of shape (n_prototypes, ...)) – Prototype positions on manifold.
proto_labels (array of shape (n_prototypes,)) – Class labels.
n_points (int) – Points per geodesic path. Default: 50.
- Returns:
geodesics – One dict per inter-class prototype pair. Each dict contains:
- ‘path’: array of shape (n_points, …), the geodesic path
- ‘proto_a_idx’: int, index of the first prototype
- ‘proto_b_idx’: int, index of the second prototype
- ‘class_a’: int, class of the first prototype
- ‘class_b’: int, class of the second prototype
- ‘boundary_t’: float, approximate boundary location along the path
- Return type: