Core API Reference¶
Core modules provide the building blocks used by all models: distance functions, loss functions, initializers, and utilities.
Distance Functions¶
JAX-based distance functions for Prosemble.
This module provides GPU-accelerated, vectorized distance computations using JAX. All functions are JIT-compiled for maximum performance.
Mathematical Background¶
Distance metrics are fundamental to prototype-based learning algorithms. This implementation focuses on:
1. Batch/matrix operations (no Python loops)
2. GPU compatibility
3. JIT compilation for speed
4. Numerical stability
Author: Prosemble Contributors
License: MIT
- prosemble.core.distance.euclidean_distance_matrix(X, Y)[source]¶
Compute pairwise Euclidean distances between rows of X and Y.
Uses the expansion trick: \(\|x - y\|^2 = \|x\|^2 + \|y\|^2 - 2 x^T y\).
Formula: \(D_{ij} = \|X_i - Y_j\| = \sqrt{\sum_k (X_{ik} - Y_{jk})^2}\)
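- Parameters:
X (array of shape (n, d)) – Data points.
Y (array of shape (m, d)) – Prototypes or centroids.
- Returns:
D – Array of shape (n, m) where \(D_{ij} = \|X_i - Y_j\|\).
- Return type:
array of shape (n, m)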
- Complexity:
Time: O(nmd), dominated by a single matrix multiplication
Space: O(nm) for the output matrix
Example
>>> import jax.numpy as jnp
>>> from prosemble.core.distance import euclidean_distance_matrix
>>> X = jnp.array([[0, 0], [1, 1], [2, 2]])
>>> Y = jnp.array([[0, 0], [3, 3]])
>>> D = euclidean_distance_matrix(X, Y)
>>> D.shape
(3, 2)
>>> float(D[0, 0])  # Distance from X[0] to Y[0]
0.0
>>> float(D[2, 1])  # Distance from X[2] to Y[1]
1.414...
Notes
Numerically stable: Uses maximum(D_sq, 0) to avoid sqrt of negatives
GPU-compatible: All operations are JAX primitives
JIT-compiled: First call compiles, subsequent calls are fast
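The expansion trick and the clamp described above can be sketched in plain jax.numpy as follows. This is an illustrative reimplementation assuming float inputs, not Prosemble's source; `euclidean_expansion` is a hypothetical name, and the broadcast reference exists only to check the result.

```python
import jax
import jax.numpy as jnp


@jax.jit
def euclidean_expansion(X, Y):
    """Pairwise Euclidean distances via ||x||^2 + ||y||^2 - 2 x^T y (sketch)."""
    x_sq = jnp.sum(X**2, axis=1, keepdims=True)   # (n, 1) squared row norms of X
    y_sq = jnp.sum(Y**2, axis=1)[None, :]          # (1, m) squared row norms of Y
    d_sq = x_sq + y_sq - 2.0 * X @ Y.T             # (n, m); rounding can push entries slightly below 0
    return jnp.sqrt(jnp.maximum(d_sq, 0.0))        # clamp first so sqrt never sees a negative


X = jnp.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
Y = jnp.array([[0.0, 0.0], [3.0, 3.0]])
D = euclidean_expansion(X, Y)

# Reference via explicit broadcasting, which needs an (n, m, d) intermediate
D_ref = jnp.linalg.norm(X[:, None, :] - Y[None, :, :], axis=-1)
```

The gradient of `jnp.sqrt` is infinite at zero, which is one reason to prefer squared distances whenever they are sufficient.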
- prosemble.core.distance.squared_euclidean_distance_matrix(X, Y)[source]¶
Compute pairwise squared Euclidean distances.
More efficient than euclidean_distance_matrix(X, Y)**2 because it avoids the sqrt operation entirely.
Formula: \(D^2_{ij} = \|X_i - Y_j\|^2 = \sum_k (X_{ik} - Y_{jk})^2\)
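- Parameters:
X (array of shape (n, d)) – Data points.
Y (array of shape (m, d)) – Prototypes or centroids.
- Returns:
D – Array of shape (n, m) where \(D^2_{ij} = \|X_i - Y_j\|^2\).
- Return type:
array of shape (n, m)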
- Complexity:
Time: O(nmd)
Space: O(nm)
Example
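>>> import jax.numpy as jnp
>>> from prosemble.core.distance import squared_euclidean_distance_matrix
>>> X = jnp.array([[0, 0], [1, 1]])
>>> Y = jnp.array([[0, 0], [2, 2]])
>>> D_sq = squared_euclidean_distance_matrix(X, Y)
>>> float(D_sq[1, 1])  # Squared distance from [1,1] to [2,2]
2.0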
Notes
Preferred over euclidean when squared distances are sufficient
Many algorithms (FCM, PCM) use squared distances directly
More numerically stable than squaring euclidean distances
- prosemble.core.distance.manhattan_distance_matrix(X, Y)[source]¶
Compute pairwise Manhattan (L1) distances.
Formula: \(D_{ij} = \|X_i - Y_j\|_1 = \sum_k |X_{ik} - Y_{jk}|\)
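- Parameters:
X (array of shape (n, d)) – Data points.
Y (array of shape (m, d)) – Prototypes or centroids.
- Returns:
D – Array of shape (n, m) where D[i, j] is the Manhattan distance between X[i] and Y[j].
- Return type:
array of shape (n, m)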
- Complexity:
Time: O(nmd)
Space: O(nmd) for the broadcast intermediate
Example
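>>> import jax.numpy as jnp
>>> from prosemble.core.distance import manhattan_distance_matrix
>>> X = jnp.array([[0, 0], [1, 1]])
>>> Y = jnp.array([[0, 0], [2, 2]])
>>> D = manhattan_distance_matrix(X, Y)
>>> float(D[1, 1])  # Manhattan distance from [1,1] to [2,2]
2.0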
- Implementation:
Uses broadcasting: X[:, None, :] - Y[None, :, :] creates an (n, m, d) array of differences; the absolute differences are then summed along the feature axis.
Notes
Also known as “taxicab” or “city block” distance
More robust to outliers than Euclidean distance
Natural for sparse/binary features
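A minimal sketch of the broadcasting implementation described above; `manhattan_broadcast` is a hypothetical name used for illustration, not the library function.

```python
import jax
import jax.numpy as jnp


@jax.jit
def manhattan_broadcast(X, Y):
    """Pairwise L1 distances via broadcasting (sketch)."""
    diff = X[:, None, :] - Y[None, :, :]      # (n, m, d) intermediate: the O(nmd) memory cost
    return jnp.sum(jnp.abs(diff), axis=-1)    # sum absolute differences over features -> (n, m)


X = jnp.array([[0.0, 0.0], [1.0, 1.0]])
Y = jnp.array([[0.0, 0.0], [2.0, 2.0]])
D = manhattan_broadcast(X, Y)
```

The (n, m, d) intermediate is what makes the space complexity O(nmd); for large inputs, computing the matrix in row blocks bounds the peak memory.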
- prosemble.core.distance.lpnorm_distance_matrix(X, Y, p)[source]¶
Compute pairwise L-p norm distances.
Formula: \(D_{ij} = \|X_i - Y_j\|_p = (\sum_k |X_{ik} - Y_{jk}|^p)^{1/p}\)
- Special Cases:
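p = 1: Manhattan distance
p = 2: Euclidean distance
p = \(\infty\): Chebyshev distance (max absolute difference)
- Parameters:
X (array of shape (n, d)) – Data points.
Y (array of shape (m, d)) – Prototypes or centroids.
p (float) – Order of the norm; jnp.inf selects the Chebyshev distance.
- Returns:
D – Array of shape (n, m) where D[i, j] is the L-p distance between X[i] and Y[j].
- Return type:
array of shape (n, m)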
- Complexity:
Time: O(nmd)
Space: O(nmd)
Example
>>> import jax.numpy as jnp
>>> from prosemble.core.distance import lpnorm_distance_matrix
>>> X = jnp.array([[0, 0], [1, 1]])
>>> Y = jnp.array([[0, 0], [3, 4]])
>>> D = lpnorm_distance_matrix(X, Y, p=2)  # Euclidean
>>> D = lpnorm_distance_matrix(X, Y, p=1)  # Manhattan
>>> D = lpnorm_distance_matrix(X, Y, p=jnp.inf)  # Chebyshev
Notes
For p=1, use manhattan_distance_matrix for better performance
For p=2, use euclidean_distance_matrix for better performance
For p=inf, computes max(|x - y|) over the feature axis
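The special cases listed above can be checked with a short sketch. `lp_distances` is a hypothetical helper, not the library function; it branches on `p` in Python, so under `jax.jit` it would need `p` marked static (for example via `static_argnums`).

```python
import jax.numpy as jnp


def lp_distances(X, Y, p):
    """Pairwise L-p distances; p = jnp.inf is treated as Chebyshev (sketch)."""
    diff = jnp.abs(X[:, None, :] - Y[None, :, :])   # (n, m, d) absolute differences
    if p == jnp.inf:
        return jnp.max(diff, axis=-1)               # largest coordinate-wise difference
    return jnp.sum(diff**p, axis=-1) ** (1.0 / p)


X = jnp.array([[0.0, 0.0], [1.0, 1.0]])
Y = jnp.array([[0.0, 0.0], [3.0, 4.0]])
D1 = lp_distances(X, Y, 1)          # Manhattan: |3| + |4| = 7 for X[0], Y[1]
D2 = lp_distances(X, Y, 2)          # Euclidean: sqrt(9 + 16) = 5 for X[0], Y[1]
Dinf = lp_distances(X, Y, jnp.inf)  # Chebyshev: max(3, 4) = 4 for X[0], Y[1]
```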
- prosemble.core.distance.omega_distance_matrix(X, Y, omega)[source]¶
Compute distances in projected space using projection matrix Omega.
Formula: \(D_{ij} = \|X_i \Omega - Y_j \Omega\|^2\) where \(\Omega\) is a projection matrix that transforms the feature space.
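- Parameters:
X (array of shape (n, d)) – Data points.
Y (array of shape (m, d)) – Prototypes or centroids.
omega (array of shape (d, k)) – Projection matrix.
- Returns:
D – Array of shape (n, m) with squared distances in the projected space.
- Return type:
array of shape (n, m)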
- Complexity:
Time: O(ndk + mdk + nmk) = O((n+m)dk + nmk)
Space: O(nk + mk + nm)
- Use Cases:
Dimensionality reduction for distance computation
Learning relevance of features (omega learned from data)
Mahalanobis-like distances (when \(\Omega = L\) with \(\Sigma^{-1} = LL^T\), the projected squared distance equals the squared Mahalanobis distance)
Example
>>> import jax.numpy as jnp
>>> from prosemble.core.distance import omega_distance_matrix
>>> X = jnp.array([[1, 2, 3], [4, 5, 6]])
>>> Y = jnp.array([[0, 0, 0], [1, 1, 1]])
>>> omega = jnp.array([[1, 0], [0, 1], [0, 0]])  # Project to first 2 dims
>>> D = omega_distance_matrix(X, Y, omega)
>>> D.shape
(2, 2)
Notes
When omega is identity, reduces to squared Euclidean distance
When omega is learned, enables adaptive distance metrics
Used in GMLVQ (Generalized Matrix LVQ), the metric-learning extension of GLVQ
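A minimal sketch of the projected distance, written with broadcasting in the k-dimensional projected space for readability rather than the expansion trick. `omega_distances` is a hypothetical name. The block checks two properties numerically: an identity omega reproduces squared Euclidean distance, and choosing omega as a Cholesky factor of \(\Sigma^{-1}\) reproduces the squared Mahalanobis distance.

```python
import jax.numpy as jnp


def omega_distances(X, Y, omega):
    """Squared distances after projecting with omega of shape (d, k) (sketch)."""
    PX = X @ omega                            # (n, k): each data point projected once
    PY = Y @ omega                            # (m, k): each prototype projected once
    diff = PX[:, None, :] - PY[None, :, :]    # (n, m, k)
    return jnp.sum(diff**2, axis=-1)          # (n, m)


X = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
Y = jnp.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])

# Identity omega: plain squared Euclidean distance
D_id = omega_distances(X, Y, jnp.eye(3))
D_sq = jnp.sum((X[:, None, :] - Y[None, :, :]) ** 2, axis=-1)

# Mahalanobis: Sigma^{-1} = L L^T (lower Cholesky factor), then omega = L
Sigma = jnp.array([[2.0, 0.5, 0.0], [0.5, 1.0, 0.0], [0.0, 0.0, 3.0]])
L = jnp.linalg.cholesky(jnp.linalg.inv(Sigma))
D_mahal = omega_distances(X, Y, L)
```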
- prosemble.core.distance.lomega_distance_matrix(X, Y, omegas)[source]¶
Compute distances using multiple projection matrices (Local Omega).
Formula: \(D_{ij} = \|X_i \Omega_j - Y_j \Omega_j\|^2\) where \(\Omega_j\) is the projection matrix attached to prototype \(Y_j\) (one per prototype or cluster).
- Parameters:
X (Array | ndarray | bool | number) – Array of shape (n, d) - data points
Y (Array | ndarray | bool | number) – Array of shape (m, d) - prototypes/centroids
omegas (Array | ndarray | bool | number) – Array of shape (m, d, k) - m projection matrices of size (d, k) Each Y[j] has its own projection matrix omegas[j]
- Returns:
D – Array of shape (n, m); D[i, j] is the squared distance in the projected space of prototype j.
- Return type:
array of shape (n, m)
- Complexity:
Time: O(nmdk)
Space: O(nmk)
- Use Cases:
Local relevance learning (each prototype has its own metric)
Adaptive local distance metrics in localized GMLVQ (LGMLVQ)
Cluster-specific feature weighting
Example
>>> import jax
>>> from prosemble.core.distance import lomega_distance_matrix
>>> n, m, d, k = 10, 3, 5, 2
>>> X = jax.random.normal(jax.random.PRNGKey(0), (n, d))
>>> Y = jax.random.normal(jax.random.PRNGKey(1), (m, d))
>>> omegas = jax.random.normal(jax.random.PRNGKey(2), (m, d, k))
>>> D = lomega_distance_matrix(X, Y, omegas)
>>> D.shape
(10, 3)
- Implementation:
Uses einsum for efficient tensor contraction:
1. Project X through each omega: X @ omegas[j] for all j
2. Extract the diagonal for Y projections (each Y[j] uses only omegas[j])
3. Compute squared differences and sum over the k projected dimensions
Notes
Generalizes omega_distance to local (per-prototype) metrics
More flexible, but more expensive: O(nmdk) time versus O((n+m)dk + nmk) for a single shared omega
Enables learning which features matter for each cluster
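The three einsum steps above can be sketched as follows, assuming (per the parameter description) that prototype j uses only omegas[j]. `local_omega_distances` is a hypothetical name; the single-entry loop reference checks the contraction.

```python
import jax
import jax.numpy as jnp


def local_omega_distances(X, Y, omegas):
    """D[i, j] = ||(X[i] - Y[j]) @ omegas[j]||^2, omegas of shape (m, d, k) (sketch)."""
    PX = jnp.einsum("nd,mdk->nmk", X, omegas)   # step 1: project every x with every omega
    PY = jnp.einsum("md,mdk->mk", Y, omegas)    # step 2: "diagonal", Y[j] uses only omegas[j]
    return jnp.sum((PX - PY[None, :, :]) ** 2, axis=-1)   # step 3: sum over k -> (n, m)


n, m, d, k = 10, 3, 5, 2
X = jax.random.normal(jax.random.PRNGKey(0), (n, d))
Y = jax.random.normal(jax.random.PRNGKey(1), (m, d))
omegas = jax.random.normal(jax.random.PRNGKey(2), (m, d, k))
D = local_omega_distances(X, Y, omegas)

# Direct computation of one entry for comparison
i, j = 4, 2
ref = jnp.sum(((X[i] - Y[j]) @ omegas[j]) ** 2)
```

The (n, m, k) tensor `PX` is the source of the O(nmk) space term in the complexity above.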
- prosemble.core.distance.tangent_distance_matrix(X, Y, omegas)[source]¶
Compute pairwise localized tangent distances.
Each prototype j has an orthogonal subspace basis \(\Omega_j\) of shape (d, s). The tangent distance projects out the subspace directions:
\[d(x, w_j) = \|(I - \Omega_j \Omega_j^T)(x - w_j)\|^2\]
- Parameters:
X (array of shape (n, d)) – Data points.
Y (array of shape (m, d)) – Prototypes.
omegas (array of shape (m, d, s)) – Orthogonal subspace bases per prototype, where s is the subspace dimension.
- Returns:
D – Squared tangent distances.
- Return type:
array of shape (n, m)
Notes
Based on Saralajew, S., & Villmann, T. (2016). Adaptive tangent distances in generalized learning vector quantization.
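A minimal sketch of the projection formula, assuming each omegas[j] has orthonormal columns as stated above. `tangent_distances` is a hypothetical name. The example builds orthonormal bases with a batched QR decomposition and checks the defining invariance: a point displaced from a prototype only along that prototype's subspace has tangent distance close to zero.

```python
import jax
import jax.numpy as jnp


def tangent_distances(X, Y, omegas):
    """Squared norm of (I - Omega_j Omega_j^T)(x - w_j) for all pairs (sketch)."""
    diff = X[:, None, :] - Y[None, :, :]                  # (n, m, d)
    coeffs = jnp.einsum("nmd,mds->nms", diff, omegas)     # coordinates in each subspace
    in_span = jnp.einsum("nms,mds->nmd", coeffs, omegas)  # Omega_j Omega_j^T (x - w_j)
    residual = diff - in_span                             # component outside the subspace
    return jnp.sum(residual**2, axis=-1)                  # (n, m)


d, s = 4, 2
raw = jax.random.normal(jax.random.PRNGKey(0), (3, d, s))
omegas = jnp.linalg.qr(raw)[0]                 # (3, d, s) with orthonormal columns
Y = jax.random.normal(jax.random.PRNGKey(1), (3, d))

# Move away from prototype 0 only along its own subspace
x_on = Y[0] + omegas[0] @ jnp.array([1.5, -0.7])
D = tangent_distances(x_on[None, :], Y, omegas)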
- prosemble.core.distance.gaussian_kernel_matrix(X, Y, sigma)[source]¶
Compute Gaussian (RBF) kernel matrix.
Formula: \(K_{ij} = \exp(-\|X_i - Y_j\|^2 / (2\sigma^2))\)
The Gaussian kernel implicitly maps data into an infinite-dimensional Hilbert space, enabling non-linear clustering and classification.
- Parameters:
X (array of shape (n, d))
Y (array of shape (m, d))
sigma (float) – Bandwidth parameter (\(\sigma > 0\)).
- Returns:
K – Array of shape (n, m) with \(K_{ij} \in (0, 1]\): \(K_{ij} = 1\) when \(X_i = Y_j\), and \(K_{ij} \to 0\) as \(\|X_i - Y_j\| \to \infty\).
- Complexity:
Time: O(nmd)
Space: O(nm)
- Properties:
For X = Y, K is symmetric and positive semi-definite (the Gaussian kernel is a valid Mercer kernel)
For X = Y, K[i, i] = 1 (self-similarity)
- Example:
>>> import jax.numpy as jnp
>>> from prosemble.core.distance import gaussian_kernel_matrix
>>> X = jnp.array([[0, 0], [1, 1]])
>>> Y = jnp.array([[0, 0], [2, 2]])
>>> K = gaussian_kernel_matrix(X, Y, sigma=1.0)
>>> float(K[0, 0])  # X[0] == Y[0]
1.0
>>> bool(K[0, 1] < K[0, 0])  # Decreases with distance
True
- Kernel Trick:
For a feature map \(\phi\) into an infinite-dimensional Hilbert space, \(K(x, y) = \langle \phi(x), \phi(y) \rangle\). Kernel distance:
\(\|\phi(x) - \phi(y)\|^2 = K(x,x) - 2K(x,y) + K(y,y) = 2 - 2K(x,y)\) for a normalized kernel (\(K(x, x) = 1\)).
- Use Cases:
Kernel Fuzzy C-Means (KFCM)
Kernel Possibilistic C-Means (KPCM)
Support Vector Machines (SVM)
Gaussian Processes
- Hyperparameter Tuning:
Small \(\sigma\): Tight clusters, high sensitivity to noise
Large \(\sigma\): Smooth clusters, may underfit
Rule of thumb: \(\sigma \approx \text{median}(\text{pairwise\_distances}) / \sqrt{2 \cdot n_\text{clusters}}\)
- Notes:
sigma is bandwidth, NOT variance (variance = \(\sigma^2\))
For numerical stability, squared distances are clamped with maximum(D_sq, 0) so they stay non-negative
JIT-compiled for GPU acceleration
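The kernel formula, the kernel-trick distance \(2 - 2K(x, y)\), and the positive semi-definite property can be illustrated with a short sketch. `gaussian_kernel` here is a hypothetical stand-in written for illustration, not a call into Prosemble.

```python
import jax.numpy as jnp


def gaussian_kernel(X, Y, sigma):
    """K[i, j] = exp(-||X[i] - Y[j]||^2 / (2 sigma^2)) (sketch)."""
    d_sq = jnp.sum((X[:, None, :] - Y[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-jnp.maximum(d_sq, 0.0) / (2.0 * sigma**2))   # clamp for stability


X = jnp.array([[0.0, 0.0], [1.0, 1.0]])
Y = jnp.array([[0.0, 0.0], [2.0, 2.0]])
K = gaussian_kernel(X, Y, sigma=1.0)

# Squared distance in feature space for a normalized kernel (K(x, x) = 1)
feature_dist_sq = 2.0 - 2.0 * K

# Gram matrix of X with itself: symmetric PSD, so eigenvalues are non-negative
G = gaussian_kernel(X, X, sigma=1.0)
eigs = jnp.linalg.eigvalsh(G)
```

Here \(\|[0,0] - [2,2]\|^2 = 8\), so K[0, 1] = exp(-4).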
- prosemble.core.distance.polynomial_kernel_matrix(X, Y, degree=3, coef0=1.0)[source]¶
Compute polynomial kernel matrix.
Formula: \(K_{ij} = (X_i^T Y_j + c)^d\) where d is degree and c is coef0.
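- Parameters:
X (array of shape (n, d))
Y (array of shape (m, d))
degree (int) – Polynomial degree d.
coef0 (float) – Constant term c.
- Returns:
K – Array of shape (n, m) with polynomial kernel values.
- Return type:
array of shape (n, m)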
Example
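>>> import jax.numpy as jnp
>>> from prosemble.core.distance import polynomial_kernel_matrix
>>> X = jnp.array([[1, 2], [3, 4]])
>>> Y = jnp.array([[1, 0], [0, 1]])
>>> K = polynomial_kernel_matrix(X, Y, degree=2, coef0=1.0)
>>> # X @ Y.T == X here, so K == (X + 1) ** 2 == [[4, 9], [16, 25]]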
Notes
degree=1, coef0=0: Linear kernel (dot product)
Higher degree: More complex decision boundaries
coef0: Influences importance of lower vs higher order terms
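A one-line sketch of the formula reproduces the example values and the linear-kernel special case; `polynomial_kernel` is a hypothetical stand-in, not the library function.

```python
import jax.numpy as jnp


def polynomial_kernel(X, Y, degree=3, coef0=1.0):
    """K[i, j] = (X[i] . Y[j] + coef0) ** degree (sketch)."""
    return (X @ Y.T + coef0) ** degree


X = jnp.array([[1.0, 2.0], [3.0, 4.0]])
Y = jnp.array([[1.0, 0.0], [0.0, 1.0]])
K = polynomial_kernel(X, Y, degree=2, coef0=1.0)    # (X + 1) ** 2 because Y is the identity
lin = polynomial_kernel(X, Y, degree=1, coef0=0.0)  # degree=1, coef0=0: plain dot products
```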
- prosemble.core.distance.euclidean_distance(x, y)[source]¶
Euclidean distance between two vectors.
- Parameters:
x (array of shape (d,))
y (array of shape (d,))
- Returns:
Scalar distance
- Return type:
scalar array
Example
>>> import jax.numpy as jnp
>>> from prosemble.core.distance import euclidean_distance
>>> x = jnp.array([0, 0, 0])
>>> y = jnp.array([1, 1, 1])
>>> d = euclidean_distance(x, y)
>>> d
Array(1.732..., dtype=float32)
- prosemble.core.distance.squared_euclidean_distance(x, y)[source]¶
Squared Euclidean distance between two vectors.
- prosemble.core.distance.manhattan_distance(x, y)[source]¶
Manhattan (L1) distance between two vectors.
- prosemble.core.distance.omega_distance(x, y, omega)[source]¶
Omega (projection-based) distance between two vectors.
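Computes \(\|(x - y)\,\Omega\|^2\), where \(\Omega\) has shape (d, k).
- Parameters:
x (array of shape (d,))
y (array of shape (d,))
omega (array of shape (d, k)) – Projection matrix.
- Returns:
Scalar squared distance in the projected space
- Return type:
scalar array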
- prosemble.core.distance.lomega_distance(X, Y, omegas)[source]¶
Local omega distance with per-prototype projection matrices.
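- Parameters:
X (array of shape (n, d)) – Data points.
Y (array of shape (m, d)) – Prototypes.
omegas (array of shape (m, d, k)) – One projection matrix per prototype.
- Returns:
Distance matrix of shape (n, m)
- Return type:
array of shape (n, m)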
- prosemble.core.distance.estimate_sigma(X, percentile=50.0)[source]¶
Estimate sigma for Gaussian kernel using pairwise distances.
Strategy: Use median (or other percentile) of pairwise distances.
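- Parameters:
X (array of shape (n, d)) – Data points.
percentile (float) – Percentile of the pairwise distances to use (50.0 gives the median).
- Returns:
sigma – Estimated bandwidth parameter.
- Return type:
scalar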
Example
>>> import jax
>>> from prosemble.core.distance import estimate_sigma
>>> X = jax.random.normal(jax.random.PRNGKey(0), (100, 10))
>>> sigma = estimate_sigma(X, percentile=50)
Notes
Heuristic: \(\sigma = \text{median\_distance} / \sqrt{2 \cdot n_\text{clusters}}\)
For large datasets, pass a random subsample of X to avoid the O(n²) pairwise computation
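A minimal sketch of the percentile strategy and of subsampling. Two assumptions are labeled: `estimate_sigma_sketch` is a hypothetical name, and excluding self-pairs (zero distances) from the percentile is a choice of this sketch that the documentation above does not state.

```python
import jax
import jax.numpy as jnp


def estimate_sigma_sketch(X, percentile=50.0):
    """Percentile of pairwise Euclidean distances; self-pairs excluded (assumption)."""
    n = X.shape[0]
    d_sq = jnp.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)   # O(n^2 d) memory
    iu = jnp.triu_indices(n, k=1)                                    # each unordered pair once
    return jnp.percentile(jnp.sqrt(d_sq[iu]), percentile)


X = jax.random.normal(jax.random.PRNGKey(0), (100, 10))
sigma = estimate_sigma_sketch(X, percentile=50)

# Subsampling bounds the quadratic cost on large datasets
idx = jax.random.choice(jax.random.PRNGKey(1), X.shape[0], shape=(50,), replace=False)
sigma_sub = estimate_sigma_sketch(X[idx])
```

For 10-dimensional standard normal data, the median pairwise distance is a little above 4, which gives a sanity check on the estimate.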
- prosemble.core.distance.safe_divide(numerator, denominator, epsilon=1e-10)[source]¶
Safe division avoiding division by zero.
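- Parameters:
numerator (array)
denominator (array)
epsilon (float) – Small constant added to the denominator.
- Returns:
numerator / (denominator + epsilon)
- Return type:
array
Example
>>> import jax.numpy as jnp
>>> from prosemble.core.distance import safe_divide
>>> x = jnp.array([1.0, 2.0, 3.0])
>>> y = jnp.array([2.0, 0.0, 1.0])
>>> z = safe_divide(x, y)  # roughly [0.5, 2e+10, 3.0]: finite where y == 0 instead of inf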
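A minimal sketch of the documented return value, numerator / (denominator + epsilon), showing the size of the result where the denominator is zero. `safe_divide_sketch` is a hypothetical name.

```python
import jax.numpy as jnp


def safe_divide_sketch(numerator, denominator, epsilon=1e-10):
    """numerator / (denominator + epsilon), as documented above (sketch)."""
    return numerator / (denominator + epsilon)


x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 0.0, 1.0])
out = safe_divide_sketch(x, y)   # 2.0 / 1e-10 = 2e10: large but finite
```

Adding epsilon instead of branching keeps the function smooth for `jax.grad`, but it only shifts the singularity: a denominator equal to -epsilon still divides by zero.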
- prosemble.core.distance.batch_squared_euclidean(X, Y)¶
Documented identically to squared_euclidean_distance_matrix(X, Y): computes pairwise squared Euclidean distances. See that entry for the formula, complexity, and example.
- prosemble.core.distance.batch_euclidean(X, Y)¶
Documented identically to euclidean_distance_matrix(X, Y): computes pairwise Euclidean distances between rows of X and Y. See that entry for the formula, complexity, and example.
Similarities¶
Similarity functions for prototype-based learning.
Similarities are the dual of distances: higher values indicate closer/more similar points.
- prosemble.core.similarities.gaussian_similarity(distances_sq, variance=1.0)[source]¶
Convert squared distances to Gaussian similarities.
\(s(d) = \exp(-d^2 / (2 \cdot \text{variance}))\), where \(d^2\) is the squared distance passed in distances_sq.
- Parameters:
distances_sq (array) – Squared distances.
variance (float) – Variance (sigma^2) of the Gaussian.
- Returns:
Similarity values in (0, 1].
- Return type:
array
- prosemble.core.similarities.cosine_similarity_matrix(X, Y)[source]¶
Pairwise cosine similarity between rows of X and Y.
\(\cos(x, y) = \frac{x \cdot y}{\|x\| \cdot \|y\|}\)
- Parameters:
X (array of shape (n, d))
Y (array of shape (m, d))
- Returns:
Cosine similarities in [-1, 1].
- Return type:
array of shape (n, m)
- prosemble.core.similarities.euclidean_similarity(X, Y, variance=1.0)[source]¶
Pairwise Euclidean similarity (Gaussian of Euclidean distance).
- Parameters:
X (array of shape (n, d))
Y (array of shape (m, d))
variance (float) – Variance of the Gaussian kernel.
- Returns:
Similarity values in (0, 1].
- Return type:
array of shape (n, m)
- prosemble.core.similarities.rank_scaled_gaussian(distances, lambd=1.0)[source]¶
Rank-scaled Gaussian similarity.
Combines distance magnitude with rank ordering: each distance is multiplied by \(\exp(-r / \lambda)\), so the closest prototype (rank 0) contributes its full distance while the distance terms of farther-ranked prototypes are exponentially damped.
\[s(d, r) = \exp(-\exp(-r / \lambda) \cdot d)\]
where r is the rank of each distance within its row (0 = closest).
- Parameters:
distances (array of shape (n, m)) – Distance matrix (non-negative).
lambd (float) – Rank decay parameter. Larger values give more uniform weighting across ranks; smaller values concentrate on nearest neighbours.
- Returns:
Rank-scaled similarity values in (0, 1].
- Return type:
array of shape (n, m)
Notes
Used in Probabilistic LVQ (PLVQ) as a conditional distribution P(x|prototype).
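Per-row ranks can be obtained with a double argsort, a common idiom; the sketch below applies the formula above with that idiom. `rank_scaled_gaussian_sketch` is a hypothetical name, and whether Prosemble computes ranks the same way is an assumption.

```python
import jax.numpy as jnp


def rank_scaled_gaussian_sketch(distances, lambd=1.0):
    """s = exp(-exp(-r / lambd) * d), r = rank within each row (0 = closest) (sketch)."""
    order = jnp.argsort(distances, axis=1)   # column indices sorted by distance
    ranks = jnp.argsort(order, axis=1)       # inverse permutation: rank of every column
    return jnp.exp(-jnp.exp(-ranks / lambd) * distances)


D = jnp.array([[0.5, 2.0, 1.0]])
S = rank_scaled_gaussian_sketch(D, lambd=1.0)
# ranks are [[0, 2, 1]], so S = exp(-[0.5, e^-2 * 2.0, e^-1 * 1.0])
```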
Kernel Functions¶
JAX-based kernel functions for kernel clustering methods.
This module provides GPU-accelerated kernel computations using JAX.
- prosemble.core.kernel.gaussian_kernel(x, y, sigma)[source]¶
Compute Gaussian (RBF) kernel between two vectors.
\(K(x, y) = \exp(-\|x - y\|^2 / (2\sigma^2))\)
- prosemble.core.kernel.batch_gaussian_kernel(X, Y, sigma)[source]¶
Compute Gaussian kernel between two sets of vectors.
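- Parameters:
X (array of shape (n_samples, n_features))
Y (array of shape (m_samples, n_features))
sigma (float) – Kernel bandwidth.
- Returns:
Kernel matrix of shape (n_samples, m_samples), with K[i, j] = K(X[i], Y[j])
- Return type:
array of shape (n_samples, m_samples)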
- prosemble.core.kernel.kernel_distance_squared(X, Y, sigma)[source]¶
Compute squared distance in feature space.
For Gaussian kernel: \(\|\phi(x) - \phi(y)\|^2 = K(x,x) + K(y,y) - 2K(x,y) = 2(1 - K(x,y))\), since \(K(x,x) = 1\).
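- Parameters:
X (array of shape (n_samples, n_features))
Y (array of shape (m_samples, n_features))
sigma (float) – Kernel bandwidth.
- Returns:
Squared distances in feature space, shape (n_samples, m_samples)
- Return type:
array of shape (n_samples, m_samples)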
- prosemble.core.kernel.kernel_distance(X, Y, sigma)[source]¶
Compute distance in feature space.
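- Parameters:
X (array of shape (n_samples, n_features))
Y (array of shape (m_samples, n_features))
sigma (float) – Kernel bandwidth.
- Returns:
Distances in feature space, shape (n_samples, m_samples)
- Return type:
array of shape (n_samples, m_samples)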
- prosemble.core.kernel.kernel_distance_squared_per_proto(X, W, sigmas)[source]¶
Squared kernel distance with per-prototype bandwidth.
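\[d_\kappa^2(x, w_k) = 2\left(1 - \exp\left(-\frac{\|x - w_k\|^2}{2\sigma_k^2}\right)\right)\]
Each prototype \(w_k\) has its own bandwidth \(\sigma_k\).
- Parameters:
X (array of shape (n_samples, n_features)) – Data matrix.
W (array of shape (n_prototypes, n_features)) – Prototype matrix.
sigmas (array of shape (n_prototypes,)) – Per-prototype bandwidths.
- Returns:
Squared distances in feature space, shape (n_samples, n_prototypes).
- Return type:
array of shape (n_samples, n_prototypes)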
References
- prosemble.core.kernel.kernel_distance_squared_relevance(X, W, sigmas, relevances)[source]¶
Squared kernel distance with relevance weighting and per-prototype bandwidth.
\[d_\kappa^2(x, w_k) = 2\left(1 - \exp\left(-\frac{\sum_j \lambda_j (x_j - w_{kj})^2}{2\sigma_k^2}\right)\right)\]
- Parameters:
X (Array | ndarray | bool | number) – Data matrix, shape (n_samples, n_features).
W (Array | ndarray | bool | number) – Prototype matrix, shape (n_prototypes, n_features).
sigmas (Array | ndarray | bool | number) – Per-prototype bandwidths, shape (n_prototypes,).
relevances (Array | ndarray | bool | number) – Normalized relevance weights, shape (n_features,).
- Returns:
Squared distances in feature space, shape (n_samples, n_prototypes).
- Return type:
array of shape (n_samples, n_prototypes)
References
[1] Villmann, T., Haase, S., & Kaden, M. (2015). Kernelized vector quantization in gradient-descent learning. Neurocomputing.
- prosemble.core.kernel.exponential_kernel_distance_squared(X, W, omega_hat)[source]¶
Squared distance in exponential kernel feature space.
Uses the exponential kernel \(\kappa_{\exp}(v, w, \hat\Lambda) = \exp(v^T \hat\Lambda w)\) where \(\hat\Lambda = \hat\Omega \hat\Omega^T\).
\[d_\kappa^2(x, w) = \exp(x^T \hat\Lambda x) + \exp(w^T \hat\Lambda w) - 2 \exp(x^T \hat\Lambda w)\]
Note: \(\kappa(v, v) \neq 1\) for the exponential kernel, so the full three-term formula is required (not the 2(1-K) simplification).
- Parameters:
X (Array | ndarray | bool | number) – Data matrix, shape (n_samples, n_features).
W (Array | ndarray | bool | number) – Prototype matrix, shape (n_prototypes, n_features).
omega_hat (Array | ndarray | bool | number) – Transformation matrix, shape (n_features, latent_dim). The kernel matrix is \(\hat\Lambda = \hat\Omega \hat\Omega^T\).
- Returns:
Squared distances in feature space, shape (n_samples, n_prototypes).
- Return type:
array of shape (n_samples, n_prototypes)
References
[1] Villmann, T., Haase, S., & Kaden, M. (2015). Kernelized vector quantization in gradient-descent learning. Neurocomputing.
Loss Functions¶
Loss functions for prototype-based learning.
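All loss functions are differentiable with jax.grad and JIT-compatible. They extract d+/d- with jnp.where masking rather than boolean indexing, because boolean indexing produces data-dependent array shapes that cannot be traced under jax.jit.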
- prosemble.core.losses.glvq_loss(distances, target_labels, prototype_labels, margin=0.0)[source]¶
Generalized LVQ loss.
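\(\mu_i = \frac{d^+_i - d^-_i}{d^+_i + d^-_i}\), where \(d^+_i\) is the distance from sample i to the closest prototype with the correct label and \(d^-_i\) the distance to the closest prototype with a different label.
- Parameters:
distances (array of shape (n, p))
target_labels (array of shape (n,))
prototype_labels (array of shape (p,))
margin (float) – Margin added to mu before transfer.
- Returns:
Mean loss over samples.
- Return type:
scalar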
- prosemble.core.losses.glvq_loss_with_transfer(distances, target_labels, prototype_labels, transfer_fn=identity, margin=0.0, beta=10.0)[source]¶
GLVQ loss with configurable transfer function.
\(\text{loss} = \text{mean}(f(\mu + \text{margin}, \beta))\), where f is transfer_fn.
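A minimal sketch of \(\mu\) extraction with jnp.where masks and a sigmoid transfer analogous to sigmoid_beta. `glvq_mu` and `glvq_loss_sketch` are hypothetical names, not Prosemble's functions. The mask keeps array shapes fixed, so the same code works under `jax.grad` and `jax.jit`; boolean indexing such as `distances[same]` would not trace under `jax.jit` because its output shape depends on the data.

```python
import jax
import jax.numpy as jnp


def glvq_mu(distances, target_labels, prototype_labels):
    """mu = (d+ - d-) / (d+ + d-), extracted with jnp.where masks (sketch)."""
    same = target_labels[:, None] == prototype_labels[None, :]        # (n, p) boolean mask
    d_plus = jnp.min(jnp.where(same, distances, jnp.inf), axis=1)     # closest correct prototype
    d_minus = jnp.min(jnp.where(same, jnp.inf, distances), axis=1)    # closest wrong prototype
    return (d_plus - d_minus) / (d_plus + d_minus)


def glvq_loss_sketch(distances, target_labels, prototype_labels, margin=0.0, beta=10.0):
    """Mean sigmoid transfer of mu + margin (sketch)."""
    mu = glvq_mu(distances, target_labels, prototype_labels)
    return jnp.mean(jax.nn.sigmoid(beta * (mu + margin)))


distances = jnp.array([[1.0, 3.0], [4.0, 2.0]])
targets = jnp.array([0, 0])
proto_labels = jnp.array([0, 1])

mu = glvq_mu(distances, targets, proto_labels)   # sample 0 correct (mu < 0), sample 1 wrong (mu > 0)
grads = jax.grad(glvq_loss_sketch)(distances, targets, proto_labels)
loss_jit = jax.jit(glvq_loss_sketch)(distances, targets, proto_labels)
```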
- prosemble.core.losses.lvq1_loss(distances, target_labels, prototype_labels)[source]¶
LVQ1 loss: d+ when correct, -d- when wrong.
- Parameters:
distances (array of shape (n, p))
target_labels (array of shape (n,))
prototype_labels (array of shape (p,))
- Return type:
scalar
- prosemble.core.losses.lvq21_loss(distances, target_labels, prototype_labels)[source]¶
LVQ2.1 loss: d+ - d- (unnormalized).
- Parameters:
distances (array of shape (n, p))
target_labels (array of shape (n,))
prototype_labels (array of shape (p,))
- Return type:
scalar
- prosemble.core.losses.nllr_loss(distances, target_labels, prototype_labels, sigma=1.0)[source]¶
Negative Log-Likelihood Ratio loss (for SLVQ).
\(\text{loss} = -\log(P(\text{correct}) / P(\text{wrong}))\)
- Parameters:
distances (array of shape (n, p))
target_labels (array of shape (n,))
prototype_labels (array of shape (p,))
sigma (float) – Bandwidth of Gaussian mixture.
- Return type:
scalar
- prosemble.core.losses.rslvq_loss(distances, target_labels, prototype_labels, sigma=1.0)[source]¶
Robust Soft LVQ loss (for RSLVQ).
\(\text{loss} = -\log(P(\text{correct}) / P(\text{all}))\)
- Parameters:
distances (array of shape (n, p))
target_labels (array of shape (n,))
prototype_labels (array of shape (p,))
sigma (float) – Bandwidth of Gaussian mixture.
- Return type:
scalar
- prosemble.core.losses.ng_rslvq_loss(distances, target_labels, prototype_labels, sigma=1.0, gamma=1.0)[source]¶
RSLVQ loss with Neural Gas rank-based neighborhood cooperation.
Combines Gaussian mixture prototype probabilities with NG rank weights to create a neighborhood-cooperative probabilistic assignment.
Gaussian: \(p(k|x) = \exp(-d_k / (2\sigma^2)) / \sum_j \exp(-d_j / (2\sigma^2))\)
NG weights: \(h_k = \exp(-\text{rank}_k / \gamma) / \sum_j \exp(-\text{rank}_j / \gamma)\)
Combined: \(w_k = p(k|x) \cdot h_k / \sum_j p(j|x) \cdot h_j\)
The loss is \(-\log(\sum_{k \in \text{correct}} w_k)\).
- Parameters:
distances (array of shape (n, p)) – Squared distances from samples to prototypes.
target_labels (array of shape (n,)) – True class labels for samples.
prototype_labels (array of shape (p,)) – Class labels assigned to prototypes.
sigma (float) – Bandwidth of Gaussian mixture.
gamma (float) – Neural Gas neighborhood range.
- Returns:
Mean negative log-likelihood with NG cooperation.
- Return type:
scalar
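The three formulas above translate almost line for line into the following sketch. `ng_rslvq_loss_sketch` is a hypothetical name; computing ranks with a double argsort and the normalizations with `jax.nn.softmax` are choices of this sketch. As a consistency check, a very large gamma makes the rank weights nearly uniform, so the loss should approach the plain RSLVQ form \(-\log(P(\text{correct}) / P(\text{all}))\).

```python
import jax
import jax.numpy as jnp


def ng_rslvq_loss_sketch(distances, target_labels, prototype_labels, sigma=1.0, gamma=1.0):
    """Mean -log of the NG-weighted posterior mass on correct prototypes (sketch)."""
    # p(k|x): softmax of -d / (2 sigma^2) is the normalized Gaussian mixture term
    p = jax.nn.softmax(-distances / (2.0 * sigma**2), axis=1)
    # h_k: normalized exp(-rank_k / gamma), ranks per sample (0 = closest)
    ranks = jnp.argsort(jnp.argsort(distances, axis=1), axis=1)
    h = jax.nn.softmax(-ranks / gamma, axis=1)
    # w_k = p(k|x) h_k / sum_j p(j|x) h_j
    w = p * h
    w = w / jnp.sum(w, axis=1, keepdims=True)
    correct = target_labels[:, None] == prototype_labels[None, :]
    return -jnp.mean(jnp.log(jnp.sum(jnp.where(correct, w, 0.0), axis=1)))


distances = jnp.array([[1.0, 4.0, 2.5], [3.0, 0.5, 2.0]])   # squared distances
targets = jnp.array([0, 1])
proto_labels = jnp.array([0, 1, 0])
loss = ng_rslvq_loss_sketch(distances, targets, proto_labels, sigma=1.0, gamma=1.0)

# Plain RSLVQ reference, reached as gamma grows
p = jax.nn.softmax(-distances / 2.0, axis=1)
correct = targets[:, None] == proto_labels[None, :]
rslvq = -jnp.mean(jnp.log(jnp.sum(jnp.where(correct, p, 0.0), axis=1)))
loss_wide = ng_rslvq_loss_sketch(distances, targets, proto_labels, sigma=1.0, gamma=1e6)
```

A sample whose class has no prototype yields log(0) = -inf, so every class needs at least one prototype.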
- prosemble.core.losses.cross_entropy_lvq_loss(distances, target_labels, prototype_labels, n_classes)[source]¶
Cross-entropy LVQ loss (for CELVQ).
1. Take the minimum distance per class via masking
2. Negate to get logits (closer = higher)
3. Apply cross-entropy against the true labels
- Parameters:
distances (array of shape (n, p))
target_labels (array of shape (n,))
prototype_labels (array of shape (p,))
n_classes (int)
- Return type:
scalar
- prosemble.core.losses.margin_loss(y_pred, y_true_one_hot, margin=0.3)[source]¶
Margin loss for CBC.
\(\text{loss} = \text{ReLU}(\max(\text{wrong}) - \text{correct} + \text{margin})\)
- Parameters:
y_pred (array of shape (n, n_classes)) – Predicted class probabilities.
y_true_one_hot (array of shape (n, n_classes)) – One-hot encoded true labels.
margin (float)
- Return type:
scalar
- prosemble.core.losses.neural_gas_energy(distances, lam)[source]¶
Neural Gas energy function.
\[E = \sum_k h(\text{rank}_k, \lambda) \cdot d(x, w_k), \quad h(\text{rank}, \lambda) = \exp(-\text{rank} / \lambda)\]
- Parameters:
distances (array of shape (n, p))
lam (float) – Neighborhood range parameter.
- Return type:
scalar
Activations¶
Transfer/activation functions for prototype-based learning.
These functions are used to shape the GLVQ loss (mu values) before summation, controlling the optimization landscape.
- prosemble.core.activations.identity(x, beta=0.0)[source]¶
Identity activation (passthrough).
- Parameters:
x (array) – Input values.
beta (float) – Ignored. Present for API consistency.
- Returns:
Same as input.
- Return type:
array
- prosemble.core.activations.sigmoid_beta(x, beta=10.0)[source]¶
Sigmoid activation with steepness parameter.
f(x) = 1 / (1 + exp(-beta * x))
- Parameters:
x (array) – Input values.
beta (float) – Steepness parameter. Higher values give sharper transition.
- Returns:
Sigmoid-transformed values in (0, 1).
- Return type:
array
Competitions¶
Competition mechanisms for prototype-based classification.
These functions determine class predictions from distance matrices and prototype labels using different strategies.
- prosemble.core.competitions.wtac(distances, prototype_labels)[source]¶
Winner-Takes-All Competition.
Assigns each sample the label of the closest prototype.
- Parameters:
distances (array of shape (n_samples, n_prototypes)) – Distance matrix.
prototype_labels (array of shape (n_prototypes,)) – Class label for each prototype.
- Returns:
Predicted labels.
- Return type:
array of shape (n_samples,)
- prosemble.core.competitions.knnc(distances, prototype_labels, k=1, n_classes=None)[source]¶
K-Nearest Neighbors Competition.
Assigns each sample the majority label among k closest prototypes.
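- Parameters:
distances (array of shape (n_samples, n_prototypes)) – Distance matrix.
prototype_labels (array of shape (n_prototypes,)) – Class label for each prototype.
k (int) – Number of nearest prototypes that vote.
n_classes (int or None) – Number of classes.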
- Returns:
Predicted labels.
- Return type:
array of shape (n_samples,)
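The difference between the two competitions above shows up when the nearest prototype disagrees with the majority of the k nearest. The sketch below is illustrative only: `wtac_sketch` and `knnc_sketch` are hypothetical names, inferring `n_classes` from the labels assumes labels 0..C-1 (and needs concrete arrays, so not inside `jax.jit`), and ties going to the smallest label is a property of `jnp.argmax`, not a documented Prosemble behaviour.

```python
import jax
import jax.numpy as jnp


def wtac_sketch(distances, prototype_labels):
    """Label of the nearest prototype for each sample (sketch)."""
    return prototype_labels[jnp.argmin(distances, axis=1)]


def knnc_sketch(distances, prototype_labels, k=1, n_classes=None):
    """Majority label among the k nearest prototypes (sketch)."""
    if n_classes is None:
        n_classes = int(jnp.max(prototype_labels)) + 1    # assumption: labels are 0..C-1
    _, idx = jax.lax.top_k(-distances, k)                 # indices of the k smallest distances
    votes = jax.nn.one_hot(prototype_labels[idx], n_classes).sum(axis=1)   # (n, n_classes)
    return jnp.argmax(votes, axis=1)                      # ties resolve to the smallest label


distances = jnp.array([[0.1, 0.2, 0.3, 5.0], [2.0, 0.5, 0.4, 0.3]])
proto_labels = jnp.array([0, 1, 1, 0])
pred_wta = wtac_sketch(distances, proto_labels)        # nearest prototype only
pred_knn = knnc_sketch(distances, proto_labels, k=3)   # majority of the 3 nearest
```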
- prosemble.core.competitions.cbcc(detections, reasonings)[source]¶
Classification-By-Components Competition.
Computes class probability distributions using component detections and reasoning matrices.
- Parameters:
detections (array of shape (n_samples, n_components)) – Similarity/detection scores for each component.
reasonings (array of shape (n_components, n_classes, 2)) – Reasoning matrices. Last dim: [positive, negative_raw].
- Returns:
Class probability distributions.
- Return type:
array of shape (n_samples, n_classes)
Initializers¶
Prototype and parameter initializers for prototype-based learning.
These functions generate initial prototypes, labels, and transformation matrices for supervised and unsupervised models.
- prosemble.core.initializers.stratified_selection_init(X, y, n_per_class, key)[source]¶
Initialize prototypes by randomly selecting samples per class.
- Parameters:
X (array of shape (n_samples, n_features)) – Training data.
y (array of shape (n_samples,)) – Training labels.
n_per_class (int, list, or dict) – Number of prototypes per class:
- int: same count for all classes.
- list: index i gives the count for class i, e.g. [2, 2, 1].
- dict: maps class label to count, e.g. {0: 2, 1: 2, 2: 1}.
key (jax.random.PRNGKey) – Random key.
- Returns:
prototypes (array of shape (n_prototypes, n_features))
prototype_labels (array of shape (n_prototypes,))
- prosemble.core.initializers.stratified_mean_init(X, y)[source]¶
Initialize prototypes at the mean of each class.
- Parameters:
X (array of shape (n_samples, n_features)) – Training data.
y (array of shape (n_samples,)) – Training labels.
- Returns:
prototypes (array of shape (n_classes, n_features))
prototype_labels (array of shape (n_classes,))
- prosemble.core.initializers.random_normal_init(n_prototypes, n_features, key, mean=0.0, std=1.0)[source]¶
Initialize prototypes from a normal distribution.
- prosemble.core.initializers.identity_omega_init(n_features, n_dims=None)[source]¶
Initialize omega as an identity (or truncated identity) matrix.
- prosemble.core.initializers.random_omega_init(n_features, n_dims, key)[source]¶
Initialize omega as a random orthogonal matrix via QR decomposition.
- prosemble.core.initializers.uniform_init(n_prototypes, n_features, key, low=0.0, high=1.0)[source]¶
Initialize prototypes from a uniform distribution.
- prosemble.core.initializers.zeros_init(n_prototypes, n_features)[source]¶
Initialize prototypes as zeros.
Useful for checkpoint loading where shapes must be pre-allocated before restoring saved values.
- prosemble.core.initializers.ones_init(n_prototypes, n_features)[source]¶
Initialize prototypes as ones.
- prosemble.core.initializers.fill_value_init(n_prototypes, n_features, value=0.0)[source]¶
Initialize prototypes filled with a constant value.
- prosemble.core.initializers.selection_init(X, n_prototypes, key)[source]¶
Initialize prototypes by uniformly sampling from data (classless).
Suitable for unsupervised models like Neural Gas, SOM, and fuzzy clustering where no labels are available.
- Parameters:
X (array of shape (n_samples, n_features)) – Training data.
n_prototypes (int) – Number of prototypes to select.
key (jax.random.PRNGKey)
- Return type:
array of shape (n_prototypes, n_features)
- prosemble.core.initializers.mean_init(X, n_prototypes)[source]¶
Initialize all prototypes at the data mean (classless).
Suitable for unsupervised models. All prototypes start at the global mean and diverge during training.
- Parameters:
X (array of shape (n_samples, n_features)) – Training data.
n_prototypes (int) – Number of prototypes.
- Return type:
array of shape (n_prototypes, n_features)
- prosemble.core.initializers.literal_init(values)[source]¶
Initialize prototypes from literal values.
Used for warm-starting from another model’s prototypes or from user-provided values.
- Parameters:
values (array-like of shape (n_prototypes, n_features)) – Literal prototype values.
- Return type:
array of shape (n_prototypes, n_features)
- prosemble.core.initializers.stratified_noise_init(X, y, n_per_class, key, noise_std=0.1)[source]¶
Initialize prototypes by selecting samples per class and adding noise.
Combines stratified selection with Gaussian noise injection for diverse initial prototypes.
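- Parameters:
X (array of shape (n_samples, n_features)) – Training data.
y (array of shape (n_samples,)) – Training labels.
n_per_class – Number of prototypes per class.
key (jax.random.PRNGKey) – Random key.
noise_std (float) – Standard deviation of the Gaussian noise added to the selected samples.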
- Returns:
prototypes (array of shape (n_prototypes, n_features))
prototype_labels (array of shape (n_prototypes,))
- prosemble.core.initializers.pca_omega_init(X, n_dims)[source]¶
Initialize omega using PCA directions from training data.
The top-n_dims principal components become the columns of omega, providing a data-driven initialization for metric learning models.
- Parameters:
X (array of shape (n_samples, n_features)) – Training data.
n_dims (int) – Number of principal components (projection dimensionality).
- Returns:
omega
- Return type:
array of shape (n_features, n_dims)
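One standard way to obtain the top principal directions is the SVD of the centered data, sketched below. `pca_omega_sketch` is a hypothetical name, and whether Prosemble uses SVD or an eigendecomposition of the covariance is not stated here. The columns come out orthonormal and ordered by explained variance.

```python
import jax
import jax.numpy as jnp


def pca_omega_sketch(X, n_dims):
    """Top-n_dims principal directions as the columns of omega, shape (d, n_dims) (sketch)."""
    Xc = X - jnp.mean(X, axis=0)                          # center the data
    _, _, Vt = jnp.linalg.svd(Xc, full_matrices=False)    # rows of Vt sorted by singular value
    return Vt[:n_dims].T


# Data stretched along the first axis, so that axis should dominate the first component
X = jax.random.normal(jax.random.PRNGKey(0), (200, 3)) * jnp.array([5.0, 1.0, 0.2])
omega = pca_omega_sketch(X, n_dims=2)
```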
- prosemble.core.initializers.class_conditional_mean_init(X, y, n_per_class)[source]¶
Initialize prototypes at class means, replicated per n_per_class.
When n_per_class > 1, each class gets multiple prototypes all initialized at the class mean (they will diverge during training).
- prosemble.core.initializers.random_reasonings_init(n_components, n_classes, key)[source]¶
Initialize CBC reasoning matrices randomly.
Pooling¶
Stratified pooling operations for prototype-based learning.
These functions aggregate per-prototype distances into per-class distances, grouping by prototype label. Essential for GLVQ and CELVQ.
- prosemble.core.pooling.stratified_min_pooling(distances, prototype_labels, n_classes)[source]¶
Per-class minimum distance pooling.
For each sample and each class, returns the minimum distance to any prototype of that class.
- Parameters:
distances (array of shape (n_samples, n_prototypes)) – Distance matrix.
prototype_labels (array of shape (n_prototypes,)) – Class label for each prototype.
n_classes (int) – Number of classes.
- Returns:
Minimum distance to each class for each sample.
- Return type:
array of shape (n_samples, n_classes)
- prosemble.core.pooling.stratified_sum_pooling(distances, prototype_labels, n_classes)[source]¶
Per-class sum distance pooling.
- Parameters:
distances (array of shape (n_samples, n_prototypes))
prototype_labels (array of shape (n_prototypes,))
n_classes (int)
- Returns:
Sum of distances to each class for each sample.
- Return type:
array of shape (n_samples, n_classes)
- prosemble.core.pooling.stratified_max_pooling(distances, prototype_labels, n_classes)[source]¶
Per-class maximum distance pooling.
- Parameters:
distances (array of shape (n_samples, n_prototypes))
prototype_labels (array of shape (n_prototypes,))
n_classes (int)
- Returns:
Maximum distance to each class for each sample.
- Return type:
array of shape (n_samples, n_classes)
- prosemble.core.pooling.stratified_prod_pooling(distances, prototype_labels, n_classes)[source]¶
Per-class product distance pooling.
Uses log-sum-exp for numerical stability.
- Parameters:
distances (array of shape (n_samples, n_prototypes))
prototype_labels (array of shape (n_prototypes,))
n_classes (int)
- Returns:
Product of distances to each class for each sample.
- Return type:
array of shape (n_samples, n_classes)
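Min and sum pooling can both be written with a fixed-shape (n_prototypes, n_classes) mask, in the same jnp.where style the loss functions use. The function names are hypothetical and the code is a sketch, not the library's implementation. A class with no prototypes receives inf under min pooling.

```python
import jax.numpy as jnp


def stratified_min_pooling_sketch(distances, prototype_labels, n_classes):
    """Per-class minimum distance via a one-hot mask, no boolean indexing (sketch)."""
    mask = prototype_labels[:, None] == jnp.arange(n_classes)[None, :]    # (p, C)
    masked = jnp.where(mask[None, :, :], distances[:, :, None], jnp.inf)  # (n, p, C)
    return jnp.min(masked, axis=1)                                        # (n, C)


def stratified_sum_pooling_sketch(distances, prototype_labels, n_classes):
    """Per-class sum of distances as a matrix product with the mask (sketch)."""
    mask = prototype_labels[:, None] == jnp.arange(n_classes)[None, :]
    return distances @ mask.astype(distances.dtype)                       # (n, C)


distances = jnp.array([[1.0, 2.0, 3.0], [4.0, 0.5, 6.0]])
proto_labels = jnp.array([0, 1, 0])
mins = stratified_min_pooling_sketch(distances, proto_labels, n_classes=2)
sums = stratified_sum_pooling_sketch(distances, proto_labels, n_classes=2)
```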
Quantization¶
Shared mixins for model base classes.
- class prosemble.core.quantization.MetadataCollectorMixin[source]¶
Mixin that auto-collects _hyperparams and _fitted_array_names from MRO.
Any base class that declares per-class _hyperparams and _fitted_array_names tuples can inherit this mixin to get _all_hyperparams and _all_fitted_array_names aggregated automatically across the entire class hierarchy.
- class prosemble.core.quantization.QuantizationMixin[source]¶
Mixin for quantizing/dequantizing fitted model parameters.
Supports float16, bfloat16, and int8 (with per-tensor scale factors).
Subclasses override _get_quantizable_attrs to declare which fitted attributes are eligible for quantization.
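As an illustration of int8 quantization with a per-tensor scale, the sketch below uses a symmetric scheme that maps the largest absolute value to 127. The symmetric scheme and the helper names `quantize_int8`/`dequantize_int8` are assumptions, not prosemble's documented behaviour. The float16/bfloat16 options amount to dtype casts.

```python
import jax.numpy as jnp


def quantize_int8(x):
    # Symmetric per-tensor scheme: the largest |x| maps to 127
    scale = jnp.maximum(jnp.max(jnp.abs(x)) / 127.0, 1e-12)
    q = jnp.clip(jnp.round(x / scale), -127, 127).astype(jnp.int8)
    return q, scale


def dequantize_int8(q, scale):
    return q.astype(jnp.float32) * scale


prototypes = jnp.array([[-1.0, 0.25], [0.5, 1.0]])
q, scale = quantize_int8(prototypes)
restored = dequantize_int8(q, scale)
half_precision = prototypes.astype(jnp.bfloat16)  # half precision is a plain cast
```

Rounding bounds the reconstruction error by half a quantization step (scale / 2) per entry.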
Utilities¶
JAX utility functions for data preprocessing and manipulation.
Replaces common sklearn/numpy operations with JAX-native implementations.
- prosemble.core.utils.train_test_split_jax(X, y=None, test_size=0.2, random_seed=42)[source]¶
JAX-native train/test split.
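- Parameters:
X (array-like of shape (n_samples, n_features)) – Feature matrix
y (array-like of shape (n_samples,), optional) – Labels
test_size (float, default=0.2) – Fraction of samples assigned to the test split
random_seed (int, default=42) – Random seed
- Returns: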
Split arrays. If y is provided, returns 4 arrays, else 2.
- Return type:
X_train, X_test[, y_train, y_test]
Examples
>>> X = jnp.array([[1, 2], [3, 4], [5, 6], [7, 8]])
>>> y = jnp.array([0, 1, 0, 1])
>>> X_train, X_test, y_train, y_test = train_test_split_jax(X, y, test_size=0.5)
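A permutation-based split is one straightforward way to get this behaviour. The sketch below is hypothetical (`train_test_split_sketch`). In particular, the rule for rounding `n * test_size` is an assumption.

```python
import jax
import jax.numpy as jnp


def train_test_split_sketch(X, y=None, test_size=0.2, random_seed=42):
    n = X.shape[0]
    n_test = int(round(n * test_size))  # rounding rule is an assumption
    perm = jax.random.permutation(jax.random.PRNGKey(random_seed), n)
    test_idx, train_idx = perm[:n_test], perm[n_test:]
    if y is None:
        return X[train_idx], X[test_idx]
    return X[train_idx], X[test_idx], y[train_idx], y[test_idx]


X = jnp.array([[1, 2], [3, 4], [5, 6], [7, 8]])
y = jnp.array([0, 1, 0, 1])
X_train, X_test, y_train, y_test = train_test_split_sketch(X, y, test_size=0.5)
```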
- prosemble.core.utils.standardize(X, mean=None, std=None)[source]¶
Standardize features (zero mean, unit variance).
- Parameters:
X (array-like of shape (n_samples, n_features)) – Data to standardize
mean (array-like of shape (n_features,), optional) – Pre-computed mean (for test data)
std (array-like of shape (n_features,), optional) – Pre-computed std (for test data)
- Returns:
X_scaled (array-like) – Standardized data
mean (array-like) – Mean used for scaling
std (array-like) – Std used for scaling
- Return type:
tuple[Array | ndarray | bool | number, Array | ndarray | bool | number, Array | ndarray | bool | number]
Examples
>>> X_train = jnp.array([[1, 2], [3, 4], [5, 6]])
>>> X_scaled, mean, std = standardize(X_train)
>>> # For test data, reuse the training statistics
>>> X_test = jnp.array([[2, 3], [4, 5]])
>>> X_test_scaled, _, _ = standardize(X_test, mean=mean, std=std)
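The fit-on-train, apply-to-test pattern reduces to a few lines. This hypothetical `standardize_sketch` assumes population standard deviation (JAX's default `ddof=0`) and adds a small `eps` for constant features. Both choices may differ from prosemble.

```python
import jax.numpy as jnp


def standardize_sketch(X, mean=None, std=None, eps=1e-8):
    X = jnp.asarray(X, dtype=jnp.float32)
    if mean is None:
        mean = X.mean(axis=0)  # statistics come from X only when not supplied
    if std is None:
        std = X.std(axis=0)
    # eps avoids division by zero for constant features (an assumption)
    return (X - mean) / (std + eps), mean, std


X_train = jnp.array([[1, 2], [3, 4], [5, 6]])
X_scaled, mean, std = standardize_sketch(X_train)
X_test_scaled, _, _ = standardize_sketch(jnp.array([[3, 4]]), mean=mean, std=std)
```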
- prosemble.core.utils.min_max_scale(X, min_val=None, max_val=None)[source]¶
Scale features to [0, 1] range.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Data to scale
min_val (array-like of shape (n_features,), optional) – Pre-computed min (for test data)
max_val (array-like of shape (n_features,), optional) – Pre-computed max (for test data)
- Returns:
X_scaled (array-like) – Scaled data
min_val (array-like) – Min values used
max_val (array-like) – Max values used
- Return type:
tuple[Array | ndarray | bool | number, Array | ndarray | bool | number, Array | ndarray | bool | number]
- prosemble.core.utils.pca_jax(X, n_components=2)[source]¶
Principal Component Analysis using JAX.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Data matrix
n_components (int, default=2) – Number of principal components
- Returns:
X_transformed (array-like of shape (n_samples, n_components)) – Transformed data
components (array-like of shape (n_components, n_features)) – Principal components
- Return type:
tuple[Array | ndarray | bool | number, Array | ndarray | bool | number]
Examples
>>> X = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> X_pca, components = pca_jax(X, n_components=2)
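PCA can be computed from the SVD of the centered data: the rows of V^T are the principal axes, sorted by decreasing singular value. This `pca_sketch` is a hypothetical illustration. The sign of each component is arbitrary and may differ from `pca_jax`.

```python
import jax
import jax.numpy as jnp


def pca_sketch(X, n_components=2):
    X = jnp.asarray(X, dtype=jnp.float32)
    X_centered = X - X.mean(axis=0)
    # Rows of Vt are the principal axes, ordered by decreasing singular value
    _, _, Vt = jnp.linalg.svd(X_centered, full_matrices=False)
    components = Vt[:n_components]
    return X_centered @ components.T, components


X = jax.random.normal(jax.random.PRNGKey(0), (50, 4))
X_pca, components = pca_sketch(X, n_components=2)
```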
- prosemble.core.utils.accuracy_score_jax(y_true, y_pred)[source]¶
Compute classification accuracy.
- Parameters:
y_true (array-like of shape (n_samples,)) – True labels
y_pred (array-like of shape (n_samples,)) – Predicted labels
- Returns:
accuracy – Accuracy score
- Return type:
Examples
>>> y_true = jnp.array([0, 1, 1, 0])
>>> y_pred = jnp.array([0, 1, 0, 0])
>>> accuracy_score_jax(y_true, y_pred)
0.75
- prosemble.core.utils.confusion_matrix_jax(y_true, y_pred, n_classes)[source]¶
Compute confusion matrix.
- Parameters:
y_true (array-like of shape (n_samples,)) – True labels
y_pred (array-like of shape (n_samples,)) – Predicted labels
n_classes (int) – Number of classes
- Returns:
conf_matrix – Confusion matrix
- Return type:
array-like of shape (n_classes, n_classes)
Examples
>>> y_true = jnp.array([0, 1, 2, 0, 1, 2])
>>> y_pred = jnp.array([0, 2, 2, 0, 0, 1])
>>> confusion_matrix_jax(y_true, y_pred, n_classes=3)
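A confusion matrix is a scatter-add of ones at (true, predicted) index pairs, and accuracy equals its trace divided by its total. The sketch assumes rows index the true class and columns the predicted class (the scikit-learn convention). Whether `confusion_matrix_jax` uses the same orientation is an assumption.

```python
import jax.numpy as jnp


def confusion_matrix_sketch(y_true, y_pred, n_classes):
    # Rows index the true class, columns the predicted class (assumed convention)
    cm = jnp.zeros((n_classes, n_classes), dtype=jnp.int32)
    return cm.at[y_true, y_pred].add(1)  # duplicate index pairs accumulate


y_true = jnp.array([0, 1, 2, 0, 1, 2])
y_pred = jnp.array([0, 2, 2, 0, 0, 1])
cm = confusion_matrix_sketch(y_true, y_pred, n_classes=3)
accuracy = jnp.trace(cm) / cm.sum()  # same value as jnp.mean(y_true == y_pred)
```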
- prosemble.core.utils.shuffle_jax(X, y=None, random_seed=42)[source]¶
Shuffle arrays in unison.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Feature matrix
y (array-like of shape (n_samples,), optional) – Labels
random_seed (int, default=42) – Random seed
- Returns:
Shuffled arrays
- Return type:
X_shuffled[, y_shuffled]
Examples
>>> X = jnp.array([[1, 2], [3, 4], [5, 6]])
>>> y = jnp.array([0, 1, 0])
>>> X_shuf, y_shuf = shuffle_jax(X, y, random_seed=42)
- prosemble.core.utils.k_fold_split_jax(n_samples, n_folds=5, random_seed=42)[source]¶
Generate K-fold cross-validation indices.
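- Parameters:
n_samples (int) – Number of samples
n_folds (int, default=5) – Number of folds
random_seed (int, default=42) – Random seed
- Yields: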
train_indices, test_indices – Indices for each fold
Examples
>>> X = jnp.ones((100, 3))
>>> for train_idx, test_idx in k_fold_split_jax(100, n_folds=5):
...     X_train, X_test = X[train_idx], X[test_idx]
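A K-fold generator can shuffle once and then cut the permutation into nearly equal chunks. Each chunk serves as the test fold exactly once. The sketch below (`k_fold_split_sketch`, hypothetical) assumes this shuffle-then-cut strategy and its own rounding of fold boundaries.

```python
import jax
import jax.numpy as jnp


def k_fold_split_sketch(n_samples, n_folds=5, random_seed=42):
    # Shuffle once, then cut the permutation into n_folds nearly equal chunks
    perm = jax.random.permutation(jax.random.PRNGKey(random_seed), n_samples)
    bounds = [round(i * n_samples / n_folds) for i in range(n_folds + 1)]
    folds = [perm[bounds[i]:bounds[i + 1]] for i in range(n_folds)]
    for i in range(n_folds):
        train_idx = jnp.concatenate([f for j, f in enumerate(folds) if j != i])
        yield train_idx, folds[i]


splits = list(k_fold_split_sketch(103, n_folds=5))
```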
- prosemble.core.utils.orthogonalize(matrix)[source]¶
Orthogonalize a matrix via polar decomposition (SVD).
Given a matrix A of shape (d, s) with d ≥ s, computes the closest matrix Q (in Frobenius norm) with orthonormal columns, Q^T Q = I, using the polar factor of the SVD:
\(A = U S V^T\), \(Q = U V^T\)
Supports batched input via jax.vmap.
- Parameters:
matrix (array of shape (d, s) or (n, d, s)) – Matrix or batch of matrices to orthogonalize. For batched input, use jax.vmap(orthogonalize).
- Returns:
Q – Orthogonal matrix (columns are orthonormal).
- Return type:
array of same shape as input
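The polar factor takes only a few lines, and `jax.vmap` lifts it to batches. This `orthogonalize_sketch` is a hypothetical re-implementation of the formula above. The check confirms that the result has orthonormal columns.

```python
import jax
import jax.numpy as jnp


def orthogonalize_sketch(A):
    # Polar factor: with A = U S V^T, the nearest matrix with orthonormal
    # columns (in Frobenius norm) is U V^T.
    U, _, Vt = jnp.linalg.svd(A, full_matrices=False)
    return U @ Vt


key = jax.random.PRNGKey(0)
A = jax.random.normal(key, (5, 3))
Q = orthogonalize_sketch(A)
Q_batch = jax.vmap(orthogonalize_sketch)(jax.random.normal(key, (4, 5, 3)))
```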
- prosemble.core.utils.class_priors(labels, n_classes=None)[source]¶
Compute class prior probabilities from labels.
\(P(\text{class}=k) = n_k / n\), where \(n_k\) is the number of samples labeled k and n is the total number of samples.
- Parameters:
labels (array of shape (n_samples,)) – Integer class labels.
n_classes (int, optional) – Number of classes. Inferred from labels if not provided.
- Returns:
priors – Prior probability for each class, sums to 1.
- Return type:
array of shape (n_classes,)
- prosemble.core.utils.prototype_priors(prototype_labels, n_classes=None)[source]¶
Compute class priors from prototype label distribution.
Used in probabilistic LVQ models where the prior over prototypes is uniform (1/n_prototypes) and the class prior is the fraction of prototypes assigned to each class.
\(P(\text{class}=k) = |\{j : \text{label}_j = k\}| / n_{\text{prototypes}}\)
- Parameters:
prototype_labels (array of shape (n_prototypes,)) – Class label for each prototype.
n_classes (int, optional) – Number of classes. Inferred if not provided.
- Returns:
priors – Class prior probabilities, sums to 1.
- Return type:
array of shape (n_classes,)
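Both prior functions reduce to normalized label counts: class_priors counts sample labels, and prototype_priors counts prototype labels. The sketch below is hypothetical (`class_priors_sketch`). Inferring `n_classes` as `max(label) + 1` is an assumption.

```python
import jax.numpy as jnp


def class_priors_sketch(labels, n_classes=None):
    labels = jnp.asarray(labels)
    if n_classes is None:
        n_classes = int(labels.max()) + 1  # inference rule is an assumption
    counts = jnp.bincount(labels, length=n_classes)
    return counts / counts.sum()


sample_priors = class_priors_sketch(jnp.array([0, 0, 1, 2]))
# The same counting applied to prototype labels gives prototype-based priors
proto_priors = class_priors_sketch(jnp.array([0, 1, 1, 1]), n_classes=3)
```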
- prosemble.core.utils.uniform_prototype_prior(n_prototypes)[source]¶
Uniform prior over prototypes: \(P(\text{prototype}_j) = 1/n_{\text{prototypes}}\).
This is the standard prior used in probabilistic LVQ (SLVQ/RSLVQ).
- Parameters:
n_prototypes (int) – Number of prototypes.
- Returns:
prior – Uniform probability vector, sums to 1.
- Return type:
array of shape (n_prototypes,)
Pipeline¶
- class prosemble.core.pipeline.NotFittedError[source]¶
Raised when calling transform/predict on an unfitted estimator.
- class prosemble.core.pipeline.StandardScaler[source]¶
Standardize features to zero mean and unit variance.
Wraps prosemble.core.utils.standardize().
Examples
>>> scaler = StandardScaler()
>>> X_scaled = scaler.fit_transform(X_train)
>>> X_test_scaled = scaler.transform(X_test)
- class prosemble.core.pipeline.MinMaxScaler[source]¶
Scale features to [0, 1] range.
Wraps prosemble.core.utils.min_max_scale().
Examples
>>> scaler = MinMaxScaler()
>>> X_scaled = scaler.fit_transform(X_train)
- class prosemble.core.pipeline.PCA(n_components=2)[source]¶
Principal Component Analysis.
Wraps prosemble.core.utils.pca_jax().
- Parameters:
n_components (int, default=2) – Number of principal components to keep.
Examples
>>> pca = PCA(n_components=2)
>>> X_reduced = pca.fit_transform(X)
- class prosemble.core.pipeline.Pipeline(steps)[source]¶
Chain transformers with a final estimator.
All operations use pure JAX arrays, with no NumPy round-trips.
- Parameters:
steps (list of (name, estimator) tuples) – Sequence of transforms with a final estimator. All but the last must implement fit() and transform(). The last step can be any estimator or transformer.
Examples
>>> pipe = Pipeline([
...     ('scaler', StandardScaler()),
...     ('pca', PCA(n_components=2)),
...     ('model', GLVQ(n_prototypes_per_class=1, max_iter=50)),
... ])
>>> pipe.fit(X_train, y_train)
>>> preds = pipe.predict(X_test)
- fit(X, y=None, **fit_params)[source]¶
Fit all steps.
Calls fit_transform on intermediate steps, fit on the final step.
- Parameters:
X (array of shape (n_samples, n_features))
y (array of shape (n_samples,), optional)
**fit_params – Forwarded to the final estimator's fit().
- Return type:
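The fit/predict flow described above (fit_transform on intermediate steps, fit on the final step, transform-only at prediction time) is sketched below. The toy `MeanCenterer` and `NearestCentroid` classes and `PipelineSketch` are hypothetical. Returning `self` from `fit()` follows the scikit-learn convention and is an assumption here.

```python
import jax.numpy as jnp


class MeanCenterer:
    """Hypothetical transformer: subtracts the training mean."""

    def fit(self, X, y=None):
        self.mean_ = X.mean(axis=0)
        return self

    def transform(self, X):
        return X - self.mean_

    def fit_transform(self, X, y=None):
        return self.fit(X, y).transform(X)


class NearestCentroid:
    """Hypothetical final estimator: one prototype (the class mean) per class."""

    def fit(self, X, y):
        self.classes_ = jnp.unique(y)
        self.centroids_ = jnp.stack([X[y == c].mean(axis=0) for c in self.classes_])
        return self

    def predict(self, X):
        d = ((X[:, None, :] - self.centroids_[None, :, :]) ** 2).sum(axis=-1)
        return self.classes_[jnp.argmin(d, axis=1)]


class PipelineSketch:
    def __init__(self, steps):
        self.steps = steps

    def fit(self, X, y=None, **fit_params):
        # fit_transform every intermediate step, then fit the final estimator
        for _, step in self.steps[:-1]:
            X = step.fit_transform(X, y)
        self.steps[-1][1].fit(X, y, **fit_params)
        return self

    def predict(self, X):
        # Reuse statistics learned in fit(); only transform() here
        for _, step in self.steps[:-1]:
            X = step.transform(X)
        return self.steps[-1][1].predict(X)


X = jnp.array([[0.0, 0.0], [0.0, 1.0], [5.0, 5.0], [5.0, 6.0]])
y = jnp.array([0, 0, 1, 1])
pipe = PipelineSketch([("center", MeanCenterer()), ("model", NearestCentroid())])
preds = pipe.fit(X, y).predict(X)
```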
Model Selection¶
- prosemble.core.model_selection.clone(estimator, **override_params)[source]¶
Create a fresh (unfitted) copy of an estimator.
- Parameters:
estimator (object) – Estimator with get_params() protocol.
**override_params – Parameters to override in the clone.
- Returns:
new_estimator – Fresh, unfitted instance.
- Return type:
same type as estimator
Examples
>>> model = GLVQ(n_prototypes_per_class=2, lr=0.01)
>>> model2 = clone(model, lr=0.05)
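Cloning through the get_params() protocol amounts to re-running the constructor with the stored hyperparameters plus any overrides, so no fitted state carries over. `ToyEstimator` and `clone_sketch` are hypothetical names used only to illustrate this idea.

```python
class ToyEstimator:
    """Hypothetical estimator exposing the get_params() protocol."""

    def __init__(self, n_prototypes_per_class=1, lr=0.01):
        self.n_prototypes_per_class = n_prototypes_per_class
        self.lr = lr

    def get_params(self):
        return {"n_prototypes_per_class": self.n_prototypes_per_class, "lr": self.lr}


def clone_sketch(estimator, **override_params):
    # Re-run the constructor with the original hyperparameters, so attributes
    # set later (e.g. by fit()) are not carried over.
    params = {**estimator.get_params(), **override_params}
    return type(estimator)(**params)


model = ToyEstimator(n_prototypes_per_class=2, lr=0.01)
model.prototypes_ = "fitted state"
model2 = clone_sketch(model, lr=0.05)
```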
- prosemble.core.model_selection.cross_val_score(estimator, X, y=None, cv=5, scoring='accuracy', random_seed=42)[source]¶
Evaluate estimator with cross-validation.
- Parameters:
estimator (object) – Estimator with fit/predict and get_params.
X (array of shape (n_samples, n_features))
y (array of shape (n_samples,), optional) – Required for supervised estimators.
cv (int, default=5) – Number of cross-validation folds.
scoring (str or callable, default='accuracy') – ‘accuracy’ or a callable scorer(y_true, y_pred) -> float. For unsupervised estimators without y: scorer(estimator, X_test) -> float.
random_seed (int, default=42)
- Returns:
scores – Score for each fold.
- Return type:
jnp.ndarray of shape (cv,)
Examples
>>> scores = cross_val_score(GLVQ(max_iter=30), X, y, cv=5)
>>> print(f"Mean: {scores.mean():.3f} +/- {scores.std():.3f}")
- class prosemble.core.model_selection.GridSearchCV(estimator, param_grid, cv=5, scoring='accuracy', random_seed=42, refit=True, verbose=0)[source]¶
Exhaustive search over a parameter grid with cross-validation.
- Parameters:
estimator (object) – Base estimator with fit/predict and get_params/set_params. Can be a Pipeline or any prosemble model.
param_grid (dict) – Maps parameter names to lists of values to try. For Pipeline steps, use step_name__param notation.
cv (int, default=5) – Number of cross-validation folds.
scoring (str or callable, default='accuracy') – ‘accuracy’ or a callable scorer(y_true, y_pred) -> float.
random_seed (int, default=42)
refit (bool, default=True) – If True, refit the best model on the full dataset after search.
verbose (int, default=0) – 0=silent, 1=per-combo summary, 2=per-fold detail.
Examples
>>> gs = GridSearchCV(
...     GLVQ(max_iter=30),
...     {'n_prototypes_per_class': [1, 2], 'lr': [0.01, 0.05]},
...     cv=3,
... )
>>> gs.fit(X, y)
>>> print(gs.best_params_, gs.best_score_)
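An exhaustive search visits the Cartesian product of all value lists, and step_name__param keys route a value to one Pipeline step. The helpers below (`expand_param_grid`, `split_step_param`) are hypothetical and only illustrate those two mechanics, not GridSearchCV's internals.

```python
import itertools


def expand_param_grid(param_grid):
    # Cartesian product over all value lists, one dict per combination
    keys = list(param_grid)
    for values in itertools.product(*(param_grid[k] for k in keys)):
        yield dict(zip(keys, values))


def split_step_param(name):
    # 'pca__n_components' -> ('pca', 'n_components'); plain names -> (None, name)
    step, sep, param = name.partition("__")
    return (step, param) if sep else (None, name)


combos = list(expand_param_grid({"n_prototypes_per_class": [1, 2], "lr": [0.01, 0.05]}))
```

The example grid from the docstring yields 2 × 2 = 4 combinations, each evaluated with cv folds.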
Datasets¶
JAX-compatible dataset module.
Provides dataset loaders that return JAX arrays instead of NumPy arrays, optimized for use with JAX-based clustering models.
- class prosemble.datasets.dataset.DATASET_JAX(input_data, labels)[source]
JAX-compatible dataset container.
- input_data
Feature matrix as JAX array
- Type:
jax.numpy.ndarray
- labels
Labels as JAX array
- Type:
jax.numpy.ndarray
- input_data: Array
- labels: Array
- to_numpy()[source]
Convert JAX arrays back to NumPy arrays.
- prosemble.datasets.dataset.iris_dataset_jax(dtype=None)[source]
Load Iris dataset as JAX arrays.
- Parameters:
dtype (jax.numpy.dtype, default=jnp.float32) – Data type for features
- Returns:
Dataset with JAX arrays (150 samples, 4 features, 3 classes)
- Return type:
DATASET_JAX
Examples
>>> from prosemble.datasets import iris_dataset_jax
>>> dataset = iris_dataset_jax()
>>> dataset.input_data.shape
(150, 4)
- prosemble.datasets.dataset.breast_cancer_dataset_jax(dtype=None)[source]
Load Wisconsin Breast Cancer dataset as JAX arrays.
- Parameters:
dtype (jax.numpy.dtype, default=jnp.float32) – Data type for features
- Returns:
Dataset with JAX arrays (569 samples, 30 features, 2 classes)
- Return type:
DATASET_JAX
Examples
>>> from prosemble.datasets import breast_cancer_dataset_jax
>>> dataset = breast_cancer_dataset_jax()
>>> dataset.input_data.shape
(569, 30)
- prosemble.datasets.dataset.moons_dataset_jax(n_samples=150, noise=None, random_state=None, dtype=None)[source]
Generate two interleaving half circles (moons) as JAX arrays.
- Parameters:
n_samples (int, default=150) – Total number of points generated
noise (float, optional) – Standard deviation of Gaussian noise added to data
random_state (int, optional) – Random seed for reproducibility
dtype (jax.numpy.dtype, default=jnp.float32) – Data type for features
- Returns:
Dataset with JAX arrays (n_samples, 2 features, 2 classes)
- Return type:
DATASET_JAX
Examples
>>> from prosemble.datasets import moons_dataset_jax
>>> dataset = moons_dataset_jax(n_samples=200, noise=0.1, random_state=42)
>>> dataset.input_data.shape
(200, 2)
- prosemble.datasets.dataset.blobs_dataset_jax(n_samples=None, centers=None, cluster_std=None, random_state=None, dtype=None)[source]
Generate isotropic Gaussian blobs as JAX arrays.
- Parameters:
n_samples (list, default=[120, 80]) – Number of samples per cluster
centers (list, default=[[0.0, 0.0], [2.0, 2.0]]) – Centers of the clusters
cluster_std (list, default=[1.2, 0.5]) – Standard deviation of clusters
random_state (int, optional) – Random seed for reproducibility
dtype (jax.numpy.dtype, default=jnp.float32) – Data type for features
- Returns:
Dataset with JAX arrays (sum(n_samples), 2 features)
- Return type:
DATASET_JAX
Examples
>>> from prosemble.datasets import blobs_dataset_jax
>>> dataset = blobs_dataset_jax(random_state=42)
>>> dataset.input_data.shape
(200, 2)
- class prosemble.datasets.dataset.DATA_JAX(random=4, sample_size=1000)[source]
JAX dataset collection with convenient property access.
- Parameters:
random (int, default=4) – Random seed for dataset generation
sample_size (int, default=1000) – Default sample size (not currently used)
dtype (jax.numpy.dtype, default=jnp.float32) – Data type for features
Examples
>>> from prosemble.datasets import DATA_JAX
>>> data = DATA_JAX(random=42)
>>> moons = data.S_1  # Moons dataset
>>> blobs = data.S_2  # Blobs dataset
>>> cancer = data.breast_cancer  # Breast cancer dataset
- random: int = 4
- sample_size: int = 1000
- dtype
alias of float32
- property S_1: DATASET_JAX
Moons dataset.
- property S_2: DATASET_JAX
Blobs dataset.
- property iris: DATASET_JAX
Iris dataset.
- property breast_cancer: DATASET_JAX
Breast cancer dataset.
- property moons: DATASET_JAX
Alias for S_1 (moons dataset).
- property blobs: DATASET_JAX
Alias for S_2 (blobs dataset).
- prosemble.datasets.dataset.load_iris_jax(dtype=None)[source]
Quick loader for iris dataset.
- Return type:
DATASET_JAX
- prosemble.datasets.dataset.load_breast_cancer_jax(dtype=None)[source]
Quick loader for breast cancer dataset.
- Return type:
DATASET_JAX
- prosemble.datasets.dataset.load_moons_jax(n_samples=150, noise=None, random_state=None, dtype=None)[source]
Quick loader for moons dataset.
- prosemble.datasets.dataset.load_blobs_jax(random_state=None, dtype=None)[source]
Quick loader for blobs dataset.
- Parameters:
random_state (int | None)
- Return type:
DATASET_JAX
- prosemble.datasets.dataset.DATA
alias of DATA_JAX
ONNX Export¶
ONNX export for prosemble prototype-based models.
Converts a fitted model’s predict function into an ONNX graph.
Only supports models whose distance function can be expressed with
standard ONNX operators. Unsupported models raise
NotImplementedError with a clear message.
Supported distance functions:
squared_euclidean_distance_matrix
euclidean_distance_matrix
manhattan_distance_matrix
omega_distance_matrix (global projection matrix)
lomega_distance_matrix (per-prototype local matrices)
tangent_distance_matrix (per-prototype tangent subspace)
relevance_weighted (per-feature relevance weighting)
Supported decision patterns:
WTAC (supervised classification)
ArgMin (unsupervised clustering)
One-class hard nearest (OCGLVQ family)
One-class Gaussian soft (OCRSLVQ family)
One-class Gaussian+NG soft (OCRSLVQ_NG family)
SVQ-OCC response model (SVQOCC family)
CBC reasoning (CBC, ImageCBC)
PLVQ Gaussian mixture soft assignment
Supported encoder models:
MLP encoder (SiameseGLVQ, SiameseGMLVQ, SiameseGTLVQ, LVQMLN, PLVQ)
CNN encoder (ImageGLVQ, ImageGMLVQ, ImageGTLVQ, ImageCBC)
Not supported:
gaussian_kernel_matrix, polynomial_kernel_matrix
Riemannian manifold distances (logm, expm have no ONNX equivalent)
- prosemble.core.onnx_export.export_onnx(model, batch_size=1, opset_version=17, path=None)[source]¶
Export a fitted model’s predict function to ONNX format.
Builds an ONNX graph that reproduces the model’s predict() output. Supports supervised (WTAC), unsupervised (ArgMin), one-class (threshold-based), SVQ-OCC (response model), CBC (reasoning matrices), PLVQ (Gaussian mixture), encoder models (MLP/CNN backbone), and Riemannian models on SO(n)/Grassmannian manifolds (73 of 87 models total).
- Parameters:
model (SupervisedPrototypeModel or UnsupervisedPrototypeModel) – A fitted prosemble model.
batch_size (int) – Fixed batch dimension for the input. Use -1 for dynamic batch size (ONNX symbolic dimension).
opset_version (int) – ONNX opset version. Default: 17.
path (str, optional) – If provided, save the ONNX model to this file path.
- Returns:
The exported ONNX model.
- Return type:
onnx.ModelProto
- Raises:
NotImplementedError – If the model’s distance function or decision pattern is not supported.
ImportError – If the onnx package is not installed.
Manifolds¶
Riemannian manifold primitives for prototype learning.
Provides geodesic distance, exponential map, and logarithmic map for manifolds commonly arising in machine learning:
SO(n): Special orthogonal group (rotation matrices)
SPD(n): Symmetric positive definite matrices
Grassmannian Gr(n, k): k-dimensional subspaces of R^n
All operations are JIT-compilable and vmap-compatible.
References
Schwarz, Psenickova, Villmann, Röhrbein (2026). Topology-Preserving Prototype Learning on Riemannian Manifolds. ESANN 2026.
- prosemble.core.manifolds.logm_safe(A)[source]¶
Matrix logarithm via funm, with numerical safety.
Uses JAX’s funm to compute logm. For complex eigenvalues (e.g. rotation matrices), it operates in the complex domain and returns the real part.
- Parameters:
A (array of shape (..., n, n))
- Returns:
logA
- Return type:
array of shape (…, n, n)
- prosemble.core.manifolds.sqrt_spd(A)[source]¶
Matrix square root for symmetric positive definite matrices.
Uses eigendecomposition: \(A^{1/2} = V \operatorname{diag}(\sqrt{\lambda}) V^T\).
- Parameters:
A (array of shape (n, n), symmetric positive definite)
- Returns:
A_sqrt
- Return type:
array of shape (n, n)
- prosemble.core.manifolds.inv_sqrt_spd(A)[source]¶
Inverse matrix square root for symmetric positive definite matrices.
Uses eigendecomposition: \(A^{-1/2} = V \operatorname{diag}(1/\sqrt{\lambda}) V^T\).
- Parameters:
A (array of shape (n, n), symmetric positive definite)
- Returns:
A_inv_sqrt
- Return type:
array of shape (n, n)
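Both SPD square roots follow directly from the eigendecomposition formulas above. The hypothetical sketch below does not clip small eigenvalues, which a production implementation might do for numerical safety.

```python
import jax.numpy as jnp


def sqrt_spd_sketch(A):
    # A = V diag(w) V^T  =>  A^{1/2} = V diag(sqrt(w)) V^T
    w, V = jnp.linalg.eigh(A)
    return (V * jnp.sqrt(w)) @ V.T  # V * sqrt(w) scales column j by sqrt(w_j)


def inv_sqrt_spd_sketch(A):
    w, V = jnp.linalg.eigh(A)
    return (V / jnp.sqrt(w)) @ V.T


A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
S = sqrt_spd_sketch(A)
S_inv = inv_sqrt_spd_sketch(A)
```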
- class prosemble.core.manifolds.SO(n)[source]¶
Special orthogonal group SO(n): rotation matrices.
Points are \(n \times n\) orthogonal matrices with \(\det = +1\). Geodesic distance uses the bi-invariant metric.
- Parameters:
n (int) – Dimension of the rotation group.
- log_map(R, S)[source]¶
Logarithmic map: Log_R(S) = R @ logm(R^T @ S).
Maps point S on the manifold to a tangent vector at R.
- exp_map(R, V)[source]¶
Exponential map: Exp_R(V) = R @ expm(R^T @ V).
Maps tangent vector V at R back to the manifold.
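For SO(2), the exponential map at the identity turns the skew-symmetric tangent vector [[0, -a], [a, 0]] into a rotation by angle a. This makes the formula Exp_R(V) = R expm(R^T V) easy to check numerically. `so_exp_map_sketch` is a hypothetical helper, not the library's `SO.exp_map`.

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm


def so_exp_map_sketch(R, V):
    # Exp_R(V) = R @ expm(R^T @ V); R^T V is skew-symmetric for tangent V
    return R @ expm(R.T @ V)


angle = 0.3
R = jnp.eye(2)
V = jnp.array([[0.0, -angle], [angle, 0.0]])  # tangent vector at the identity
R_new = so_exp_map_sketch(R, V)
```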
- class prosemble.core.manifolds.SPD(n)[source]¶
Manifold of \(n \times n\) symmetric positive definite matrices.
Uses the affine-invariant Riemannian metric.
- Parameters:
n (int) – Matrix dimension.
- class prosemble.core.manifolds.Grassmannian(n, k)[source]¶
Grassmannian manifold Gr(n, k): k-dimensional subspaces of R^n.
Points are represented as orthonormal bases Q of shape (n, k) with Q^T Q = I_k.
- distance(Q1, Q2)[source]¶
Geodesic distance via principal angles.
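\(d(Q_1, Q_2) = \|\theta\|_2\), where \(\theta_i = \arccos \sigma_i(Q_1^T Q_2)\) are the principal angles and \(\sigma_i\) are the singular values of \(Q_1^T Q_2\).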
- log_map(Q1, Q2)[source]¶
Logarithmic map on the Grassmannian.
Computes the tangent vector at Q1 pointing toward Q2 using aligned principal angle decomposition (Edelman et al. 1998).
The key insight is to align both subspaces via the SVD of Q1^T Q2, ensuring the principal angles and directions correspond.
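The principal-angle distance can be checked on Gr(2, 1), where two lines through the origin at angle a have distance a. `grassmann_distance_sketch` is a hypothetical helper. The clip guards `arccos` against singular values that drift slightly above 1 in floating point.

```python
import jax.numpy as jnp


def grassmann_distance_sketch(Q1, Q2):
    # Singular values of Q1^T Q2 are the cosines of the principal angles
    cosines = jnp.linalg.svd(Q1.T @ Q2, compute_uv=False)
    theta = jnp.arccos(jnp.clip(cosines, -1.0, 1.0))
    return jnp.linalg.norm(theta)


angle = 0.3
line_a = jnp.array([[1.0], [0.0]])  # span{e1} in R^2, a point of Gr(2, 1)
line_b = jnp.array([[jnp.cos(angle)], [jnp.sin(angle)]])
d = grassmann_distance_sketch(line_a, line_b)
```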
Protocols¶
Structural typing protocols for prosemble interfaces.
Defines typing.Protocol contracts for duck-typed interfaces used
across the library. These enable static type checking (mypy / pyright)
and IDE auto-completion without requiring inheritance.
Note
Runtime-checkable protocols (Manifold, CallbackLike) support
isinstance() checks. Type aliases (DistanceMatrixFn, etc.)
are for annotation only.
- class prosemble.core.protocols.Manifold(*args, **kwargs)[source]¶
Protocol for Riemannian manifold implementations.
Any object exposing the methods below can be used wherever a manifold is expected (e.g. RiemannianNeuralGas, RiemannianSRNG).
The concrete implementations SO, SPD, and Grassmannian all satisfy this protocol structurally; no explicit subclassing is required.
- distance(p, q)[source]¶
Geodesic distance between two points.
- Parameters:
p (array of shape point_shape)
q (array of shape point_shape)
- Return type:
scalar
- distance_squared(p, q)[source]¶
Squared geodesic distance between two points.
- Parameters:
p (array of shape point_shape)
q (array of shape point_shape)
- Return type:
scalar
- log_map(base, target)[source]¶
Logarithmic map: tangent vector at base pointing toward target.
- Parameters:
base (array of shape point_shape) – Base point on the manifold.
target (array of shape point_shape) – Target point on the manifold.
- Returns:
tangent – Tangent vector in \(T_{\text{base}} M\).
- Return type:
array of shape point_shape
- exp_map(base, tangent)[source]¶
Exponential map: move along tangent from base back to the manifold.
- Parameters:
base (array of shape point_shape)
tangent (array of shape point_shape)
- Returns:
point
- Return type:
array of shape point_shape
- random_point(key)[source]¶
Sample a random point on the manifold.
- Parameters:
key (JAX PRNG key)
- Returns:
point
- Return type:
array of shape point_shape
- belongs(point)[source]¶
Check whether point lies on the manifold.
- Parameters:
point (array of shape point_shape)
- Return type:
bool or bool-valued array
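Structural typing means a class needs only the right method names, not a common base class. The sketch defines a local, hypothetical copy of the protocol (`ManifoldLike`) so it runs without prosemble, and a flat `EuclideanSpace` whose log and exp maps reduce to subtraction and addition. Note that `isinstance` with a runtime-checkable protocol verifies only that the methods exist, not their signatures.

```python
from typing import Protocol, runtime_checkable

import jax
import jax.numpy as jnp


@runtime_checkable
class ManifoldLike(Protocol):
    """Local copy of the method set listed above (hypothetical name)."""

    def distance(self, p, q): ...
    def distance_squared(self, p, q): ...
    def log_map(self, base, target): ...
    def exp_map(self, base, tangent): ...
    def random_point(self, key): ...
    def belongs(self, point): ...


class EuclideanSpace:
    """Flat R^n: log/exp maps reduce to subtraction/addition. No subclassing."""

    def __init__(self, dim):
        self.dim = dim

    def distance(self, p, q):
        return jnp.linalg.norm(q - p)

    def distance_squared(self, p, q):
        return jnp.sum((q - p) ** 2)

    def log_map(self, base, target):
        return target - base

    def exp_map(self, base, tangent):
        return base + tangent

    def random_point(self, key):
        return jax.random.normal(key, (self.dim,))

    def belongs(self, point):
        return point.shape == (self.dim,)


space = EuclideanSpace(3)
p = space.random_point(jax.random.PRNGKey(0))
q = space.random_point(jax.random.PRNGKey(1))
```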
- class prosemble.core.protocols.CallbackLike(*args, **kwargs)[source]¶
Protocol for training callbacks.
Any object that implements the protocol's three training hook methods can be passed in the callbacks list of a model’s constructor. The existing Callback base class already satisfies this protocol.
- prosemble.core.protocols.DistanceMatrixFn¶
(X, Y) -> distances. X has shape (n_samples, n_features), Y has shape (n_prototypes, n_features), and the result has shape (n_samples, n_prototypes).
- Type:
Distance-matrix function
- prosemble.core.protocols.DistancePairwiseFn¶
(x, y) -> scalar.
- Type:
Pairwise distance function
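The two aliases are related: nesting `jax.vmap` twice lifts any pairwise (x, y) -> scalar function to an (X, Y) -> (n_samples, n_prototypes) matrix function. The helper names `manhattan_pairwise` and `to_distance_matrix` are hypothetical.

```python
import jax
import jax.numpy as jnp


def manhattan_pairwise(x, y):
    # DistancePairwiseFn shape: (x, y) -> scalar
    return jnp.sum(jnp.abs(x - y))


def to_distance_matrix(pairwise_fn):
    # Lift to a DistanceMatrixFn: map over rows of Y, then over rows of X
    over_y = jax.vmap(pairwise_fn, in_axes=(None, 0))
    return jax.vmap(over_y, in_axes=(0, None))


X = jnp.array([[0.0, 0.0], [1.0, 2.0], [3.0, 1.0]])  # (n_samples, n_features)
Y = jnp.array([[0.0, 0.0], [2.0, 2.0]])              # (n_prototypes, n_features)
D = to_distance_matrix(manhattan_pairwise)(X, Y)      # (n_samples, n_prototypes)
```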
Distributed Training¶
Distributed training utilities for multi-device data parallelism.
Provides functions for sharding data across devices and replicating model parameters, enabling data-parallel training on multi-GPU/TPU setups.
- prosemble.core.distributed.create_mesh(devices=None)[source]¶
Create a 1D device mesh for data parallelism.
- Parameters:
devices (list of jax.Device or None) – Devices to use. If None, returns None (single-device mode).
- Return type:
Mesh or None
- prosemble.core.distributed.shard_data(X, y, mesh)[source]¶
Shard data arrays along the batch dimension across devices.
- Parameters:
X (jnp.ndarray of shape (n_samples, ...)) – Input data.
y (jnp.ndarray of shape (n_samples,)) – Labels.
mesh (Mesh) – Device mesh from create_mesh().
- Returns:
X_sharded, y_sharded
- Return type:
tuple of sharded arrays
- prosemble.core.distributed.replicate_params(params, mesh)[source]¶
Replicate params across all devices (no partitioning).
- Parameters:
params (dict (pytree)) – Model parameters.
mesh (Mesh) – Device mesh from create_mesh().
- Return type:
params replicated across mesh
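The three helpers map naturally onto JAX's `jax.sharding` API: a 1D `Mesh`, a `NamedSharding` that splits the batch axis, and an empty `PartitionSpec` for full replication. The sketch below is one plausible way to write them and not necessarily prosemble's implementation; the `*_sketch` names and the axis name "data" are assumptions. It also runs on a single CPU device, where the mesh has size one.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def create_mesh_sketch(devices):
    # 1D mesh whose single axis ("data") spans all given devices
    return Mesh(np.array(devices), axis_names=("data",))


def shard_data_sketch(X, y, mesh):
    # Split the leading (batch) axis across the "data" axis of the mesh
    batch_sharding = NamedSharding(mesh, PartitionSpec("data"))
    return jax.device_put(X, batch_sharding), jax.device_put(y, batch_sharding)


def replicate_params_sketch(params, mesh):
    # An empty PartitionSpec means every device holds a full copy
    replicated = NamedSharding(mesh, PartitionSpec())
    return jax.tree_util.tree_map(lambda p: jax.device_put(p, replicated), params)


n_dev = len(jax.devices())
mesh = create_mesh_sketch(jax.devices())
X = jnp.ones((8 * n_dev, 3))  # batch size divisible by the device count
y = jnp.arange(8 * n_dev)
X_s, y_s = shard_data_sketch(X, y, mesh)
params = replicate_params_sketch({"prototypes": jnp.zeros((4, 3))}, mesh)
```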
Data Utilities¶
Data loading and batching utilities for prosemble.
Provides composable primitives for mini-batch iteration, suitable for
custom training loops. The built-in fit() methods already handle
batching internally — these utilities are for advanced users who need
explicit control over data iteration.
- prosemble.core.data.shuffle_arrays(key, *arrays)[source]
Shuffle multiple arrays with the same random permutation.
- Parameters:
key (JAX PRNG key) – Random key for generating the permutation.
*arrays (jnp.ndarray) – Arrays to shuffle. All must have the same length along axis 0.
- Returns:
Shuffled arrays in the same order as the inputs.
- Return type:
tuple of jnp.ndarray
Examples
>>> key = jax.random.PRNGKey(0)
>>> X = jnp.arange(12).reshape(4, 3)
>>> y = jnp.array([0, 1, 0, 1])
>>> X_s, y_s = shuffle_arrays(key, X, y)
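The key property is that a single permutation is applied to every array, so rows stay aligned. In the hypothetical `shuffle_arrays_sketch`, the labels equal the first feature, which makes alignment easy to verify.

```python
import jax
import jax.numpy as jnp


def shuffle_arrays_sketch(key, *arrays):
    # One permutation, applied to every array, keeps rows aligned
    perm = jax.random.permutation(key, arrays[0].shape[0])
    return tuple(a[perm] for a in arrays)


key = jax.random.PRNGKey(0)
X = jnp.arange(12).reshape(4, 3)
y = X[:, 0]  # label = first feature, so alignment is easy to check
X_s, y_s = shuffle_arrays_sketch(key, X, y)
```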
- prosemble.core.data.padded_batches(X, y=None, batch_size=32, key=None)[source]
Split data into static-shaped batches, padding the last batch.
Returns arrays of shape (n_batches, batch_size, ...) suitable for jax.lax.scan. If the data length is not divisible by batch_size, the last batch is padded by repeating initial samples.
- Parameters:
X (jnp.ndarray of shape (n_samples, ...)) – Feature array.
y (jnp.ndarray of shape (n_samples,), optional) – Label array.
batch_size (int) – Number of samples per batch.
key (JAX PRNG key, optional) – If provided, data is shuffled before batching.
- Returns:
X_batches (jnp.ndarray of shape (n_batches, batch_size, …))
y_batches (jnp.ndarray of shape (n_batches, batch_size) or None)
- Return type:
Examples
>>> X = jnp.ones((10, 3))
>>> y = jnp.arange(10)
>>> X_b, y_b = padded_batches(X, y, batch_size=4)
>>> X_b.shape  # 10 samples padded to 12
(3, 4, 3)
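The sketch below pads by cycling through the first samples, as described, and then consumes the static-shaped batches with `jax.lax.scan`. `padded_batches_sketch` is hypothetical. The scan sums labels to show that padded samples (here 0 and 1) are counted twice, so a real training loop may want to mask them.

```python
import jax
import jax.numpy as jnp


def padded_batches_sketch(X, y=None, batch_size=32, key=None):
    n = X.shape[0]
    if key is not None:
        perm = jax.random.permutation(key, n)
        X = X[perm]
        y = None if y is None else y[perm]
    n_batches = -(-n // batch_size)  # ceiling division
    n_pad = n_batches * batch_size - n
    # Pad by cycling through the first samples again
    idx = jnp.concatenate([jnp.arange(n), jnp.arange(n_pad) % n])
    X_b = X[idx].reshape((n_batches, batch_size) + X.shape[1:])
    y_b = None if y is None else y[idx].reshape(n_batches, batch_size)
    return X_b, y_b


X = jnp.ones((10, 3))
y = jnp.arange(10)
X_b, y_b = padded_batches_sketch(X, y, batch_size=4)


def step(total, batch):
    X_batch, y_batch = batch
    return total + y_batch.sum(), None


total, _ = jax.lax.scan(step, jnp.zeros((), y_b.dtype), (X_b, y_b))
```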
- prosemble.core.data.batched_iterator(X, y=None, batch_size=32, shuffle=True, key=None, drop_last=False)[source]
Yield mini-batches from data arrays.
For use in custom Python training loops.
- Parameters:
X (jnp.ndarray of shape (n_samples, ...)) – Feature array.
y (jnp.ndarray of shape (n_samples,), optional) – Label array.
batch_size (int) – Number of samples per batch.
shuffle (bool) – If True and key is provided, shuffle before iterating.
key (JAX PRNG key, optional) – Required if shuffle=True.
drop_last (bool) – If True, drop the last batch when it is smaller than batch_size. If False (default), the last batch may be smaller.
- Yields:
X_batch (jnp.ndarray of shape (batch_size, …) or smaller)
y_batch (jnp.ndarray of shape (batch_size,) or None)
Examples
>>> key = jax.random.PRNGKey(0)
>>> X = jnp.ones((10, 3))
>>> for X_b, y_b in batched_iterator(X, batch_size=4, key=key):
...     print(X_b.shape)
(4, 3)
(4, 3)
(2, 3)
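A Python generator with an optional shuffle and a `drop_last` cutoff reproduces the behaviour in the example above. `batched_iterator_sketch` is a hypothetical illustration. Unlike `padded_batches`, its last batch can have a different shape, which would trigger recompilation inside a jitted step.

```python
import jax
import jax.numpy as jnp


def batched_iterator_sketch(X, y=None, batch_size=32, shuffle=True, key=None, drop_last=False):
    n = X.shape[0]
    order = jax.random.permutation(key, n) if (shuffle and key is not None) else jnp.arange(n)
    # With drop_last, stop before the incomplete trailing batch
    stop = n - n % batch_size if drop_last else n
    for start in range(0, stop, batch_size):
        idx = order[start:start + batch_size]
        yield X[idx], (None if y is None else y[idx])


key = jax.random.PRNGKey(0)
X = jnp.ones((10, 3))
shapes = [xb.shape for xb, _ in batched_iterator_sketch(X, batch_size=4, key=key)]
shapes_drop = [xb.shape for xb, _ in batched_iterator_sketch(X, batch_size=4, key=key, drop_last=True)]
```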