Core API Reference

Core modules provide the building blocks used by all models: distance functions, loss functions, initializers, and utilities.

Distance Functions

JAX-based distance functions for Prosemble.

This module provides GPU-accelerated, vectorized distance computations using JAX. All functions are JIT-compiled for maximum performance.

Mathematical Background

Distance metrics are fundamental to prototype-based learning algorithms. This implementation focuses on:

  1. Batch/matrix operations (no Python loops)

  2. GPU compatibility

  3. JIT compilation for speed

  4. Numerical stability

Author: Prosemble Contributors

License: MIT

prosemble.core.distance.euclidean_distance_matrix(X, Y)[source]

Compute pairwise Euclidean distances between rows of X and Y.

Uses the expansion trick: \(\|x - y\|^2 = \|x\|^2 + \|y\|^2 - 2 x^T y\).

Formula: \(D_{ij} = \|X_i - Y_j\| = \sqrt{\sum_k (X_{ik} - Y_{jk})^2}\)

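Parameters:
  • X (array of shape (n, d)) – First set of points, one per row.

  • Y (array of shape (m, d)) – Second set of points, one per row.

Returns:

D – Array of shape (n, m) where \(D_{ij} = \|X_i - Y_j\|\).

Return type:

array of shape (n, m)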

Complexity:

Time: O(nmd), dominated by a single matrix multiplication
Space: O(nm) for the output matrix

Example

>>> import jax.numpy as jnp
>>> from prosemble.core.distance import euclidean_distance_matrix
>>> X = jnp.array([[0, 0], [1, 1], [2, 2]])
>>> Y = jnp.array([[0, 0], [3, 3]])
>>> D = euclidean_distance_matrix(X, Y)
>>> D.shape
(3, 2)
>>> float(D[0, 0])  # Distance from X[0] to Y[0]
0.0
>>> float(D[2, 1])  # Distance from X[2] to Y[1]
1.414...

Notes

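  • Numerically stable: clamps the squared distances with maximum(D_sq, 0) before the sqrt, because floating-point cancellation in the expansion trick can produce small negative values

  • GPU-compatible: All operations are JAX primitives

  • JIT-compiled: the first call for a given input shape and dtype triggers compilation; later calls with the same shapes and dtypes reuse the compiled function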
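A minimal JAX sketch of the expansion trick and the clamp described above, for illustration only (it is not Prosemble's source). The names euclidean_via_expansion and euclidean_via_broadcast are hypothetical; the second is a naive reference that builds the full (n, m, d) difference tensor, which is exactly what the expansion trick avoids.

```python
import jax
import jax.numpy as jnp


@jax.jit
def euclidean_via_expansion(X, Y):
    # ||x||^2 for every row of X, shape (n, 1)
    x_sq = jnp.sum(X ** 2, axis=1, keepdims=True)
    # ||y||^2 for every row of Y, shape (1, m)
    y_sq = jnp.sum(Y ** 2, axis=1)[None, :]
    # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x^T y; the cross term is one matmul
    d_sq = x_sq + y_sq - 2.0 * X @ Y.T
    # Clamp tiny negative values caused by cancellation before taking sqrt
    return jnp.sqrt(jnp.maximum(d_sq, 0.0))


def euclidean_via_broadcast(X, Y):
    # Reference version: materializes an (n, m, d) difference tensor
    diff = X[:, None, :] - Y[None, :, :]
    return jnp.sqrt(jnp.sum(diff ** 2, axis=-1))


X = jnp.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
Y = jnp.array([[0.0, 0.0], [3.0, 3.0]])
D = euclidean_via_expansion(X, Y)
```

The expansion version needs one (n, d) by (d, m) matrix product and O(nm) extra memory, whereas the broadcast reference allocates O(nmd).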

prosemble.core.distance.squared_euclidean_distance_matrix(X, Y)[source]

Compute pairwise squared Euclidean distances.

More efficient than euclidean_distance_matrix(X, Y)**2 because it avoids the sqrt operation entirely.

Formula: \(D^2_{ij} = \|X_i - Y_j\|^2 = \sum_k (X_{ik} - Y_{jk})^2\)

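Parameters:
  • X (array of shape (n, d)) – First set of points, one per row.

  • Y (array of shape (m, d)) – Second set of points, one per row.

Returns:

D – Array of shape (n, m) where \(D^2_{ij} = \|X_i - Y_j\|^2\).

Return type:

array of shape (n, m)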

Complexity:

Time: O(nmd)
Space: O(nm)

Example

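>>> import jax.numpy as jnp
>>> from prosemble.core.distance import squared_euclidean_distance_matrix
>>> X = jnp.array([[0, 0], [1, 1]])
>>> Y = jnp.array([[0, 0], [2, 2]])
>>> D_sq = squared_euclidean_distance_matrix(X, Y)
>>> float(D_sq[1, 1])  # Squared distance from [1,1] to [2,2]
2.0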

Notes

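  • Prefer this over euclidean_distance_matrix when squared distances suffice; squaring is monotonic for non-negative values, so nearest-prototype rankings are unchanged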

  • Many algorithms (FCM, PCM) use squared distances directly

  • More numerically stable than squaring euclidean distances

prosemble.core.distance.manhattan_distance_matrix(X, Y)[source]

Compute pairwise Manhattan (L1) distances.

Formula: \(D_{ij} = \|X_i - Y_j\|_1 = \sum_k |X_{ik} - Y_{jk}|\)

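Parameters:
  • X (array of shape (n, d)) – First set of points, one per row.

  • Y (array of shape (m, d)) – Second set of points, one per row.

Returns:

D – Array of shape (n, m) where D[i, j] is the Manhattan distance between X[i] and Y[j].

Return type:

array of shape (n, m)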

Complexity:

Time: O(nmd)
Space: O(nmd) for the intermediate (n, m, d) broadcast tensor

Example

>>> import jax.numpy as jnp
>>> from prosemble.core.distance import manhattan_distance_matrix
>>> X = jnp.array([[0, 0], [1, 1]])
>>> Y = jnp.array([[0, 0], [2, 2]])
>>> D = manhattan_distance_matrix(X, Y)
>>> float(D[1, 1])  # Manhattan distance from [1,1] to [2,2]
2.0

Implementation:

Uses broadcasting: X[:, None, :] - Y[None, :, :] creates an (n, m, d) difference tensor; the absolute differences are then summed along the feature dimension.

Notes

  • Also known as “taxicab” or “city block” distance

  • Less sensitive to outliers than Euclidean distance, because large coordinate differences enter linearly rather than squared

  • Natural for sparse/binary features

prosemble.core.distance.lpnorm_distance_matrix(X, Y, p)[source]

Compute pairwise L-p norm distances.

Formula: \(D_{ij} = \|X_i - Y_j\|_p = (\sum_k |X_{ik} - Y_{jk}|^p)^{1/p}\)

Special Cases:

  • p = 1: Manhattan distance

  • p = 2: Euclidean distance

  • p = \(\infty\): Chebyshev distance (maximum absolute difference)

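Parameters:
  • X (array of shape (n, d)) – First set of points, one per row.

  • Y (array of shape (m, d)) – Second set of points, one per row.

  • p (float) – Order of the norm; jnp.inf gives the Chebyshev distance.

Returns:

D – Array of shape (n, m) where D[i, j] is the L-p distance between X[i] and Y[j].

Return type:

array of shape (n, m)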

Complexity:

Time: O(nmd)
Space: O(nmd)

Example

>>> import jax.numpy as jnp
>>> from prosemble.core.distance import lpnorm_distance_matrix
>>> X = jnp.array([[0, 0], [1, 1]])
>>> Y = jnp.array([[0, 0], [3, 4]])
>>> D = lpnorm_distance_matrix(X, Y, p=2)  # Euclidean
>>> D = lpnorm_distance_matrix(X, Y, p=1)  # Manhattan
>>> D = lpnorm_distance_matrix(X, Y, p=jnp.inf)  # Chebyshev

Notes

  • For p=1, use manhattan_distance_matrix for better performance

  • For p=2, use euclidean_distance_matrix for better performance

  • For p=inf, computes max(|x - y|)
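The special cases above can be checked with a small broadcasting sketch. The helper lp_distance_matrix is hypothetical and not Prosemble's implementation; it branches on p in Python, so under jax.jit p would have to be a static argument.

```python
import jax.numpy as jnp


def lp_distance_matrix(X, Y, p):
    # Broadcast to an (n, m, d) tensor of absolute coordinate differences
    diff = jnp.abs(X[:, None, :] - Y[None, :, :])
    if p == jnp.inf:
        # Chebyshev: largest absolute coordinate difference
        return jnp.max(diff, axis=-1)
    # General case: (sum_k |x_k - y_k|^p)^(1/p)
    return jnp.sum(diff ** p, axis=-1) ** (1.0 / p)


X = jnp.array([[0.0, 0.0], [1.0, 1.0]])
Y = jnp.array([[0.0, 0.0], [3.0, 4.0]])
D1 = lp_distance_matrix(X, Y, 1)          # Manhattan
D2 = lp_distance_matrix(X, Y, 2)          # Euclidean
Dinf = lp_distance_matrix(X, Y, jnp.inf)  # Chebyshev
```

For the pair [0, 0] and [3, 4], the three distances are 7, 5, and 4 respectively.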

prosemble.core.distance.omega_distance_matrix(X, Y, omega)[source]

Compute distances in projected space using projection matrix Omega.

Formula: \(D_{ij} = \|X_i \Omega - Y_j \Omega\|^2\) where \(\Omega\) is a projection matrix that transforms the feature space.

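Parameters:
  • X (array of shape (n, d)) – First set of points, one per row.

  • Y (array of shape (m, d)) – Second set of points, one per row.

  • omega (array of shape (d, k)) – Projection matrix mapping d input features to k projected dimensions.

Returns:

D – Array of shape (n, m) with squared distances in the projected space.

Return type:

array of shape (n, m)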

Complexity:

Time: O(ndk + mdk + nmk) = O((n+m)dk + nmk)
Space: O(nk + mk + nm)

Use Cases:
  • Dimensionality reduction for distance computation

  • Learning relevance of features (omega learned from data)

  • Mahalanobis-like distances (when \(\Omega = L\) with \(\Sigma^{-1} = LL^T\), so that \(\Omega\Omega^T\) is the inverse covariance)

Example

>>> import jax.numpy as jnp
>>> from prosemble.core.distance import omega_distance_matrix
>>> X = jnp.array([[1, 2, 3], [4, 5, 6]])
>>> Y = jnp.array([[0, 0, 0], [1, 1, 1]])
>>> omega = jnp.array([[1, 0], [0, 1], [0, 0]])  # Project to first 2 dims
>>> D = omega_distance_matrix(X, Y, omega)
>>> D.shape
(2, 2)

Notes

  • When omega is identity, reduces to squared Euclidean distance

  • When omega is learned, enables adaptive distance metrics

  • Used in GMLVQ (Generalized Matrix Learning Vector Quantization)
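A sketch of the projected distance, assuming omega has shape (d, k) as in the example above. omega_distance_sketch is a hypothetical name; for clarity it broadcasts the projected points, which costs O(nmk) memory rather than the O(nk + mk + nm) listed under Complexity.

```python
import jax.numpy as jnp


def omega_distance_sketch(X, Y, omega):
    # Project both point sets with omega (d, k)
    XP = X @ omega   # (n, k)
    YP = Y @ omega   # (m, k)
    # Squared Euclidean distances between projected rows, shape (n, m)
    diff = XP[:, None, :] - YP[None, :, :]
    return jnp.sum(diff ** 2, axis=-1)


X = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
Y = jnp.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
omega = jnp.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])  # keep first 2 dims
D = omega_distance_sketch(X, Y, omega)
# With omega = identity the result equals plain squared Euclidean distance
D_identity = omega_distance_sketch(X, Y, jnp.eye(3))
```

Here D is [[5, 1], [41, 25]], since only the first two coordinates survive the projection.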

prosemble.core.distance.lomega_distance_matrix(X, Y, omegas)[source]

Compute distances using multiple projection matrices (Local Omega).

Formula: \(D_{ij} = \|X_i \Omega_j - Y_j \Omega_j\|^2\) where \(\Omega_j\) is the projection matrix belonging to prototype j (one matrix per prototype or cluster).

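Parameters:
  • X (array of shape (n, d)) – Data points.

  • Y (array of shape (m, d)) – Prototypes.

  • omegas (array of shape (m, d, k)) – One projection matrix per prototype.

Returns:

D – Array of shape (n, m) with the projected squared distances.

Return type:

array of shape (n, m)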

Complexity:

Time: O(nmdk)
Space: O(nmk)

Use Cases:
  • Local relevance learning (each prototype has its own metric)

  • Adaptive distance metrics in localized GMLVQ (LGMLVQ)

  • Cluster-specific feature weighting

Example

>>> import jax
>>> from prosemble.core.distance import lomega_distance_matrix
>>> n, m, d, k = 10, 3, 5, 2
>>> X = jax.random.normal(jax.random.PRNGKey(0), (n, d))
>>> Y = jax.random.normal(jax.random.PRNGKey(1), (m, d))
>>> omegas = jax.random.normal(jax.random.PRNGKey(2), (m, d, k))
>>> D = lomega_distance_matrix(X, Y, omegas)
>>> D.shape
(10, 3)

Implementation:

Uses einsum for efficient tensor contraction:

  1. Project X through each omega: X @ omegas[j] for all j

  2. Extract the diagonal for the Y projections (each Y[j] uses omegas[j])

  3. Compute the squared differences and sum them

Notes

  • Generalizes omega_distance to local (per-prototype) metrics

  • More flexible than a single global omega, but more expensive to compute

  • Enables learning which features matter for each cluster
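The implementation steps can be sketched with jnp.einsum, assuming omegas has shape (m, d, k) with omegas[j] belonging to Y[j]. lomega_sketch is a hypothetical name. Instead of projecting every Y[j] through every omega and extracting a diagonal, the second einsum pairs each Y[j] with omegas[j] directly, which yields the same values.

```python
import jax
import jax.numpy as jnp


def lomega_sketch(X, Y, omegas):
    # X: (n, d), Y: (m, d), omegas: (m, d, k), one projection per prototype
    # 1. Project every sample through every prototype's omega: (n, m, k)
    XP = jnp.einsum("nd,mdk->nmk", X, omegas)
    # 2. Project each prototype through its own omega only: (m, k)
    YP = jnp.einsum("md,mdk->mk", Y, omegas)
    # 3. Squared differences in each local space, summed over k: (n, m)
    return jnp.sum((XP - YP[None, :, :]) ** 2, axis=-1)


n, m, d, k = 10, 3, 5, 2
X = jax.random.normal(jax.random.PRNGKey(0), (n, d))
Y = jax.random.normal(jax.random.PRNGKey(1), (m, d))
omegas = jax.random.normal(jax.random.PRNGKey(2), (m, d, k))
D = lomega_sketch(X, Y, omegas)
```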

prosemble.core.distance.tangent_distance_matrix(X, Y, omegas)[source]

Compute pairwise localized tangent distances.

Each prototype j has an orthogonal subspace basis \(\Omega_j\) of shape (d, s). The tangent distance projects out the subspace directions:

\[d(x, w_j) = \|(I - \Omega_j \Omega_j^T)(x - w_j)\|^2\]

Parameters:
  • X (array of shape (n, d)) – Data points.

  • Y (array of shape (m, d)) – Prototypes.

  • omegas (array of shape (m, d, s)) – Orthogonal subspace bases per prototype, where s is the subspace dimension.

Returns:

D – Squared tangent distances.

Return type:

array of shape (n, m)

Notes

Based on Saralajew, S., & Villmann, T. (2016). Adaptive tangent distances in generalized learning vector quantization.
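A sketch of the formula above for all sample/prototype pairs at once, assuming each omegas[j] has orthonormal columns (only then is \(I - \Omega_j \Omega_j^T\) the orthogonal projection onto the complement of the subspace). tangent_distance_sketch is a hypothetical name, not the library function.

```python
import jax.numpy as jnp


def tangent_distance_sketch(X, Y, omegas):
    # X: (n, d) data, Y: (m, d) prototypes, omegas: (m, d, s) orthonormal bases
    diff = X[:, None, :] - Y[None, :, :]                    # (n, m, d)
    # Coordinates of each difference in prototype j's subspace: Omega_j^T (x - w_j)
    coords = jnp.einsum("nmd,mds->nms", diff, omegas)      # (n, m, s)
    # Component inside the subspace: Omega_j Omega_j^T (x - w_j)
    in_subspace = jnp.einsum("nms,mds->nmd", coords, omegas)
    # Residual (I - Omega_j Omega_j^T)(x - w_j), then its squared norm
    residual = diff - in_subspace
    return jnp.sum(residual ** 2, axis=-1)                  # (n, m)


d, s = 3, 2
basis = jnp.eye(d)[:, :s]            # first two coordinate axes, orthonormal
omegas = jnp.stack([basis, basis])   # (m=2, d=3, s=2)
X = jnp.array([[1.0, 2.0, 3.0]])
Y = jnp.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
D = tangent_distance_sketch(X, Y, omegas)
```

With both bases spanning the first two axes, only the third coordinate contributes, so D evaluates to [[9, 4]].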

prosemble.core.distance.gaussian_kernel_matrix(X, Y, sigma)[source]

Compute Gaussian (RBF) kernel matrix.

Formula: \(K_{ij} = \exp(-\|X_i - Y_j\|^2 / (2\sigma^2))\)

The Gaussian kernel implicitly maps data into an infinite-dimensional Hilbert space, enabling non-linear clustering and classification.

Parameters:
  • X (array of shape (n, d))

  • Y (array of shape (m, d))

  • sigma (float) – Bandwidth parameter (\(\sigma > 0\)).

Returns:

K – Array of shape (n, m) where \(K_{ij} \in [0, 1]\): \(K_{ij} = 1\) when \(X_i = Y_j\), and \(K_{ij} \to 0\) as \(\|X_i - Y_j\| \to \infty\).

Complexity:

Time: O(nmd)
Space: O(nm)

Properties:
  • K is positive semi-definite (valid kernel)

  • K is symmetric if X = Y

  • K[i,i] = 1 when X = Y (self-similarity)

Example:
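>>> import jax.numpy as jnp
>>> from prosemble.core.distance import gaussian_kernel_matrix
>>> X = jnp.array([[0, 0], [1, 1]])
>>> Y = jnp.array([[0, 0], [2, 2]])
>>> K = gaussian_kernel_matrix(X, Y, sigma=1.0)
>>> float(K[0, 0])  # X[0] == Y[0], so K = exp(0)
1.0
>>> bool(K[0, 1] < K[0, 0])  # Decreases with distance
True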
Kernel Trick:

For a feature map \(\phi\) into an infinite-dimensional Hilbert space, \(K(x, y) = \langle \phi(x), \phi(y) \rangle\). The kernel distance is

\(\|\phi(x) - \phi(y)\|^2 = K(x,x) - 2K(x,y) + K(y,y) = 2 - 2K(x,y)\) for a normalized kernel (\(K(x, x) = 1\)), such as the Gaussian kernel.

Use Cases:
  • Kernel Fuzzy C-Means (KFCM)

  • Kernel Possibilistic C-Means (KPCM)

  • Support Vector Machines (SVM)

  • Gaussian Processes

Hyperparameter Tuning:
  • Small \(\sigma\): Tight clusters, high sensitivity to noise

  • Large \(\sigma\): Smooth clusters, may underfit

  • Rule of thumb: \(\sigma \approx \text{median}(\text{pairwise\_distances}) / \sqrt{2 \cdot n_\text{clusters}}\)

Notes:
  • sigma is bandwidth, NOT variance (variance = \(\sigma^2\))

  • For numerical stability, squared distances are clamped with maximum(·, 0) so they are never negative

  • JIT-compiled for GPU acceleration

Return type:

Array | ndarray | bool | number
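A short sketch of the kernel formula and of the kernel distance from the Kernel Trick paragraph. gaussian_kernel_sketch is a hypothetical name; computing the squared distances with the clamped expansion trick is an assumption about the implementation, not something the library documents for this function.

```python
import jax.numpy as jnp


def gaussian_kernel_sketch(X, Y, sigma):
    # Squared Euclidean distances via the expansion trick, clamped at zero
    d_sq = (jnp.sum(X ** 2, axis=1)[:, None] + jnp.sum(Y ** 2, axis=1)[None, :]
            - 2.0 * X @ Y.T)
    d_sq = jnp.maximum(d_sq, 0.0)
    # K_ij = exp(-||X_i - Y_j||^2 / (2 sigma^2))
    return jnp.exp(-d_sq / (2.0 * sigma ** 2))


X = jnp.array([[0.0, 0.0], [1.0, 1.0]])
Y = jnp.array([[0.0, 0.0], [2.0, 2.0]])
K = gaussian_kernel_sketch(X, Y, sigma=1.0)
# Kernel-induced squared distance: K(x,x) + K(y,y) - 2K(x,y) = 2 - 2K(x,y)
D_feat = 2.0 - 2.0 * K
```

Here K[0, 1] = exp(-8 / 2) = exp(-4) and K[1, 1] = exp(-1), and D_feat is zero exactly where the points coincide.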

prosemble.core.distance.polynomial_kernel_matrix(X, Y, degree=3, coef0=1.0)[source]

Compute polynomial kernel matrix.

Formula: \(K_{ij} = (X_i^T Y_j + c)^d\) where d is degree and c is coef0.

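Parameters:
  • X (array of shape (n, d))

  • Y (array of shape (m, d))

  • degree (int) – Polynomial degree (d in the formula above).

  • coef0 (float) – Constant term c.

Returns:

K – Array of shape (n, m) with polynomial kernel values.

Return type:

array of shape (n, m)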

Example

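>>> import jax.numpy as jnp
>>> from prosemble.core.distance import polynomial_kernel_matrix
>>> X = jnp.array([[1, 2], [3, 4]])
>>> Y = jnp.array([[1, 0], [0, 1]])
>>> K = polynomial_kernel_matrix(X, Y, degree=2, coef0=1.0)
>>> K.shape
(2, 2)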

Notes

  • degree=1, coef0=0: Linear kernel (dot product)

  • Higher degree: More complex decision boundaries

  • coef0: Influences importance of lower vs higher order terms

prosemble.core.distance.euclidean_distance(x, y)[source]

Euclidean distance between two vectors.

Parameters:
  • x (array of shape (d,)) – First vector.

  • y (array of shape (d,)) – Second vector.

Returns:

Scalar distance

Return type:

Array | ndarray | bool | number

Example

>>> import jax.numpy as jnp
>>> from prosemble.core.distance import euclidean_distance
>>> x = jnp.array([0, 0, 0])
>>> y = jnp.array([1, 1, 1])
>>> d = euclidean_distance(x, y)
>>> d
Array(1.732..., dtype=float32)

prosemble.core.distance.squared_euclidean_distance(x, y)[source]

Squared Euclidean distance between two vectors.

Parameters:
  • x (array of shape (d,)) – First vector.

  • y (array of shape (d,)) – Second vector.

Returns:

Scalar squared distance

Return type:

Array | ndarray | bool | number

prosemble.core.distance.manhattan_distance(x, y)[source]

Manhattan (L1) distance between two vectors.

Parameters:
  • x (array of shape (d,)) – First vector.

  • y (array of shape (d,)) – Second vector.

Returns:

Scalar distance

Return type:

Array | ndarray | bool | number

prosemble.core.distance.lpnorm_distance(x, y, p=2)[source]

Lp-norm distance between two vectors.

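Parameters:
  • x (array of shape (d,)) – First vector.

  • y (array of shape (d,)) – Second vector.

  • p (float) – Order of the norm.
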
Returns:

Scalar distance

Return type:

Array | ndarray | bool | number

prosemble.core.distance.omega_distance(x, y, omega)[source]

Omega (projection-based) distance between two vectors.

Computes \(\|\text{diff} \cdot \Omega\|^2\) where \(\text{diff} = x - y\).

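Parameters:
  • x (array of shape (d,)) – First vector.

  • y (array of shape (d,)) – Second vector.

  • omega (array of shape (d, k)) – Projection matrix.
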
Returns:

Scalar squared distance in projected space

Return type:

Array | ndarray | bool | number

prosemble.core.distance.lomega_distance(X, Y, omegas)[source]

Local omega distance with per-prototype projection matrices.

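Parameters:
  • X (array of shape (n, d)) – Data points.

  • Y (array of shape (m, d)) – Prototypes.

  • omegas (array of shape (m, d, k)) – One projection matrix per prototype.
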
Returns:

Distance matrix of shape (n, m)

Return type:

Array | ndarray | bool | number

prosemble.core.distance.estimate_sigma(X, percentile=50.0)[source]

Estimate sigma for Gaussian kernel using pairwise distances.

Strategy: Use median (or other percentile) of pairwise distances.

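Parameters:
  • X (array of shape (n, d)) – Data points.

  • percentile (float) – Percentile of the pairwise distances to use; 50 gives the median.

Returns:

sigma – Estimated bandwidth parameter.

Return type:

scalar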

Example

>>> import jax
>>> from prosemble.core.distance import estimate_sigma
>>> X = jax.random.normal(jax.random.PRNGKey(0), (100, 10))
>>> sigma = estimate_sigma(X, percentile=50)

Notes

  • Heuristic when the number of clusters is known: \(\sigma = \text{median\_distance} / \sqrt{2 \cdot n_\text{clusters}}\)

  • For large datasets, compute the estimate on a random subsample to avoid the O(n²) cost of all pairwise distances
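A minimal sketch of the percentile strategy, including the subsampling suggested in the notes. estimate_sigma_sketch and its max_samples and key arguments are hypothetical and not part of estimate_sigma's documented signature. Each unordered pair is counted once and self-distances are excluded; whether estimate_sigma does the same is not stated above.

```python
import jax
import jax.numpy as jnp


def estimate_sigma_sketch(X, percentile=50.0, max_samples=None, key=None):
    # Optionally subsample rows to bound the O(n^2) pairwise cost
    if max_samples is not None and X.shape[0] > max_samples:
        idx = jax.random.choice(key, X.shape[0], (max_samples,), replace=False)
        X = X[idx]
    # Pairwise squared distances via the expansion trick, clamped at zero
    sq = jnp.sum(X ** 2, axis=1)
    d_sq = jnp.maximum(sq[:, None] + sq[None, :] - 2.0 * X @ X.T, 0.0)
    # Strict upper triangle: each unordered pair once, no self-distances
    iu = jnp.triu_indices(X.shape[0], k=1)
    dists = jnp.sqrt(d_sq[iu])
    return jnp.percentile(dists, percentile)


X = jnp.array([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]])
sigma = estimate_sigma_sketch(X)  # pairwise distances are 5, 10, 5
```

For these three collinear points the median pairwise distance, and therefore sigma, is 5.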

prosemble.core.distance.safe_divide(numerator, denominator, epsilon=1e-10)[source]

Safe division avoiding division by zero.

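Parameters:
  • numerator (array)

  • denominator (array)

  • epsilon (float) – Small constant added to the denominator.
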
Returns:

numerator / (denominator + epsilon)

Return type:

Array | ndarray | bool | number

Example

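>>> import jax.numpy as jnp
>>> from prosemble.core.distance import safe_divide
>>> x = jnp.array([1.0, 2.0, 3.0])
>>> y = jnp.array([2.0, 0.0, 1.0])
>>> out = safe_divide(x, y)
>>> float(out[0])
0.5
>>> bool(jnp.all(jnp.isfinite(out)))  # 2.0 / (0.0 + 1e-10) is about 2e10: large but finite
True
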
prosemble.core.distance.batch_squared_euclidean(X, Y)

Same behavior and documentation as squared_euclidean_distance_matrix(): pairwise squared Euclidean distances between the rows of X and Y, returned as an array of shape (n, m). See that entry for the formula, complexity, example, and notes.

prosemble.core.distance.batch_euclidean(X, Y)

Same behavior and documentation as euclidean_distance_matrix(): pairwise Euclidean distances between the rows of X and Y, returned as an array of shape (n, m). See that entry for the formula, complexity, example, and notes.

Similarities

Similarity functions for prototype-based learning.

Similarities are the dual of distances: higher values indicate closer/more similar points.

prosemble.core.similarities.gaussian_similarity(distances_sq, variance=1.0)[source]

Convert squared distances to Gaussian similarities.

\(s(d) = \exp(-d^2 / (2 \cdot \text{variance}))\), where \(d^2\) is the squared distance passed in distances_sq.

Parameters:
  • distances_sq (array) – Squared distances.

  • variance (float) – Variance (sigma^2) of the Gaussian.

Returns:

Similarity values in (0, 1].

Return type:

array

prosemble.core.similarities.cosine_similarity_matrix(X, Y)[source]

Pairwise cosine similarity between rows of X and Y.

\(\cos(x, y) = \frac{x \cdot y}{\|x\| \cdot \|y\|}\)

Parameters:
  • X (array of shape (n, d))

  • Y (array of shape (m, d))

Returns:

Cosine similarities in [-1, 1].

Return type:

array of shape (n, m)

prosemble.core.similarities.euclidean_similarity(X, Y, variance=1.0)[source]

Pairwise Euclidean similarity (Gaussian of Euclidean distance).

Parameters:
  • X (array of shape (n, d))

  • Y (array of shape (m, d))

  • variance (float) – Variance of the Gaussian kernel.

Returns:

Similarity values in (0, 1].

Return type:

array of shape (n, m)

prosemble.core.similarities.rank_scaled_gaussian(distances, lambd=1.0)[source]

Rank-scaled Gaussian similarity.

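Combines distance magnitude with rank ordering: each distance is scaled by the rank weight \(\exp(-r / \lambda)\), which decays exponentially with rank, so the closest prototype's distance enters at full strength while the distances of farther-ranked prototypes are progressively down-weighted.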

\[s(d, r) = \exp(-\exp(-r / \lambda) \cdot d)\]

where r is the rank of each distance (0 = closest).

Parameters:
  • distances (array of shape (n, m)) – Distance matrix (non-negative).

  • lambd (float) – Rank decay parameter. Larger values give more uniform weighting across ranks; smaller values concentrate on nearest neighbours.

Returns:

Rank-scaled similarity values in (0, 1].

Return type:

array of shape (n, m)

Notes

Used in Probabilistic LVQ (PLVQ) as a conditional distribution P(x|prototype).
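The rank computation can be sketched with a double argsort: the first argsort orders each row, the second turns that ordering into per-element ranks. rank_scaled_gaussian_sketch is a hypothetical name; ties are broken by argsort order here, which may differ from the library.

```python
import jax.numpy as jnp


def rank_scaled_gaussian_sketch(distances, lambd=1.0):
    # Rank of each distance within its row (0 = closest) via double argsort
    ranks = jnp.argsort(jnp.argsort(distances, axis=1), axis=1)
    # s(d, r) = exp(-exp(-r / lambda) * d)
    return jnp.exp(-jnp.exp(-ranks / lambd) * distances)


dist = jnp.array([[1.0, 3.0, 2.0]])   # ranks in this row are [0, 2, 1]
S = rank_scaled_gaussian_sketch(dist, lambd=1.0)
```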

Kernel Functions

JAX-based kernel functions for kernel clustering methods.

This module provides GPU-accelerated kernel computations using JAX.

prosemble.core.kernel.gaussian_kernel(x, y, sigma)[source]

Compute Gaussian (RBF) kernel between two vectors.

\(K(x, y) = \exp(-\|x - y\|^2 / (2\sigma^2))\)

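Parameters:
  • x (array of shape (d,)) – First vector.

  • y (array of shape (d,)) – Second vector.

  • sigma (float) – Bandwidth (\(\sigma > 0\)).
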
Returns:

Kernel value (scalar)

Return type:

Array | ndarray | bool | number

prosemble.core.kernel.batch_gaussian_kernel(X, Y, sigma)[source]

Compute Gaussian kernel between two sets of vectors.

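Parameters:
  • X (array of shape (n_samples, n_features))

  • Y (array of shape (m_samples, n_features))

  • sigma (float) – Bandwidth (\(\sigma > 0\)).

Returns:

Kernel matrix of shape (n_samples, m_samples), where K[i, j] = K(X[i], Y[j]).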

Return type:

Array | ndarray | bool | number

prosemble.core.kernel.kernel_distance_squared(X, Y, sigma)[source]

Compute squared distance in feature space.

For Gaussian kernel: \(\|\phi(x) - \phi(y)\|^2 = K(x,x) + K(y,y) - 2K(x,y) = 2(1 - K(x,y))\), since \(K(x,x) = 1\).

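Parameters:
  • X (array of shape (n_samples, n_features))

  • Y (array of shape (m_samples, n_features))

  • sigma (float) – Bandwidth (\(\sigma > 0\)).
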
Returns:

Squared distances in feature space, shape (n_samples, m_samples)

Return type:

Array | ndarray | bool | number

prosemble.core.kernel.kernel_distance(X, Y, sigma)[source]

Compute distance in feature space.

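Parameters:
  • X (array of shape (n_samples, n_features))

  • Y (array of shape (m_samples, n_features))

  • sigma (float) – Bandwidth (\(\sigma > 0\)).
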
Returns:

Distances in feature space, shape (n_samples, m_samples)

Return type:

Array | ndarray | bool | number

prosemble.core.kernel.kernel_distance_squared_per_proto(X, W, sigmas)[source]

Squared kernel distance with per-prototype bandwidth.

\[d_\kappa^2(x, w_k) = 2\left(1 - \exp\left(-\frac{\|x - w_k\|^2}{2\sigma_k^2}\right)\right)\]

Each prototype \(w_k\) has its own bandwidth \(\sigma_k\).

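Parameters:
  • X (array of shape (n_samples, n_features)) – Data matrix.

  • W (array of shape (n_prototypes, n_features)) – Prototype matrix.

  • sigmas (array of shape (n_prototypes,)) – Bandwidth \(\sigma_k\) of each prototype.
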
Returns:

Squared distances in feature space, shape (n_samples, n_prototypes).

Return type:

Array | ndarray | bool | number

prosemble.core.kernel.kernel_distance_squared_relevance(X, W, sigmas, relevances)[source]

Squared kernel distance with relevance weighting and per-prototype bandwidth.

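\[d_\kappa^2(x, w_k) = 2\left(1 - \exp\left(-\frac{\sum_j \lambda_j (x_j - w_{kj})^2}{2\sigma_k^2}\right)\right)\]

Parameters:
  • X (array of shape (n_samples, n_features)) – Data matrix.

  • W (array of shape (n_prototypes, n_features)) – Prototype matrix.

  • sigmas (array of shape (n_prototypes,)) – Bandwidth \(\sigma_k\) of each prototype.

  • relevances (array of shape (n_features,)) – Feature relevance weights \(\lambda_j\).
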
Returns:

Squared distances in feature space, shape (n_samples, n_prototypes).

Return type:

Array | ndarray | bool | number
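Both bandwidth variants above reduce to the same \(2(1 - K)\) form because the Gaussian kernel satisfies \(K(x, x) = 1\). The sketch below assumes sigmas has shape (n_prototypes,) and relevances has shape (n_features,); both function names are hypothetical.

```python
import jax.numpy as jnp


def kernel_distance_sq_per_proto_sketch(X, W, sigmas):
    # Squared input-space distances, shape (n_samples, n_prototypes)
    d_sq = jnp.sum((X[:, None, :] - W[None, :, :]) ** 2, axis=-1)
    # Column k uses its own bandwidth sigma_k
    K = jnp.exp(-d_sq / (2.0 * sigmas[None, :] ** 2))
    # ||phi(x) - phi(w_k)||^2 = 2 (1 - K) since K(x, x) = 1
    return 2.0 * (1.0 - K)


def kernel_distance_sq_relevance_sketch(X, W, sigmas, relevances):
    # Relevance-weighted squared distance: sum_j lambda_j (x_j - w_kj)^2
    d_sq = jnp.sum(relevances * (X[:, None, :] - W[None, :, :]) ** 2, axis=-1)
    return 2.0 * (1.0 - jnp.exp(-d_sq / (2.0 * sigmas[None, :] ** 2)))


X = jnp.array([[0.0, 0.0]])
W = jnp.array([[1.0, 0.0], [0.0, 2.0]])
sigmas = jnp.array([1.0, 2.0])
D = kernel_distance_sq_per_proto_sketch(X, W, sigmas)
Dr = kernel_distance_sq_relevance_sketch(X, W, sigmas, jnp.array([1.0, 0.0]))
```

Both prototypes in D give 2(1 - exp(-0.5)), because each one's squared distance is exactly twice its own \(\sigma_k^2\). In Dr the relevance vector switches off the second feature, so the second prototype's distance drops to zero.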

prosemble.core.kernel.exponential_kernel_distance_squared(X, W, omega_hat)[source]

Squared distance in exponential kernel feature space.

Uses the exponential kernel \(\kappa_{\exp}(v, w, \hat\Lambda) = \exp(v^T \hat\Lambda w)\) where \(\hat\Lambda = \hat\Omega \hat\Omega^T\).

\[d_\kappa^2(x, w) = \exp(x^T \hat\Lambda x) + \exp(w^T \hat\Lambda w) - 2 \exp(x^T \hat\Lambda w)\]

Note: \(\kappa(v, v) \neq 1\) for the exponential kernel, so the full three-term formula is required (not the 2(1-K) simplification).

Parameters:
  • X (Array | ndarray | bool | number) – Data matrix, shape (n_samples, n_features).

  • W (Array | ndarray | bool | number) – Prototype matrix, shape (n_prototypes, n_features).

  • omega_hat (Array | ndarray | bool | number) – Transformation matrix, shape (n_features, latent_dim). The kernel matrix is \(\hat\Lambda = \hat\Omega \hat\Omega^T\).

Returns:

Squared distances in feature space, shape (n_samples, n_prototypes).

Return type:

Array | ndarray | bool | number
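A sketch of the three-term formula, assuming omega_hat has shape (n_features, latent_dim) as documented. exp_kernel_distance_sq_sketch is a hypothetical name. Because exp grows quickly, inputs with large \(x^T \hat\Lambda x\) can overflow in float32; the example keeps values small.

```python
import jax.numpy as jnp


def exp_kernel_distance_sq_sketch(X, W, omega_hat):
    Lam = omega_hat @ omega_hat.T                         # Lambda_hat, (d, d)
    # kappa(x, x) = exp(x^T Lam x) for every sample, shape (n,)
    kxx = jnp.exp(jnp.einsum("nd,de,ne->n", X, Lam, X))
    # kappa(w, w) for every prototype, shape (p,)
    kww = jnp.exp(jnp.einsum("pd,de,pe->p", W, Lam, W))
    # Cross term kappa(x, w) = exp(x^T Lam w), shape (n, p)
    kxw = jnp.exp(X @ Lam @ W.T)
    # Full three-term formula: kappa(v, v) != 1 here, so 2(1 - K) does not apply
    return kxx[:, None] + kww[None, :] - 2.0 * kxw


X = jnp.array([[0.1, 0.2], [0.3, -0.1]])
omega_hat = jnp.eye(2)
D = exp_kernel_distance_sq_sketch(X, X, omega_hat)
```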

Loss Functions

Loss functions for prototype-based learning.

All loss functions are differentiable by jax.grad and JIT-compatible. They use jnp.where masking (not boolean indexing) for d+/d- extraction.

prosemble.core.losses.glvq_loss(distances, target_labels, prototype_labels, margin=0.0)[source]

Generalized LVQ loss.

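\(\mu_i = \frac{d^+_i - d^-_i}{d^+_i + d^-_i}\), where \(d^+_i\) is the distance from sample i to the closest prototype with the correct label and \(d^-_i\) the distance to the closest prototype with a wrong label.

Parameters:
  • distances (array of shape (n, p))

  • target_labels (array of shape (n,))

  • prototype_labels (array of shape (p,))

  • margin (float) – Margin added to mu before transfer.

Returns:

Mean loss over samples.

Return type:

scalar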
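A sketch of how \(d^+\) and \(d^-\) can be extracted with jnp.where masking and combined into \(\mu_i\). The names glvq_mu and glvq_loss_sketch are hypothetical, and applying the margin as a plain shift under an identity transfer is an assumption based on the parameter description above. Masked entries are set to +inf, so a sample with no prototype of its class would produce nan.

```python
import jax
import jax.numpy as jnp


def glvq_mu(distances, target_labels, prototype_labels):
    # (n, p) mask: True where a prototype carries the sample's label
    same = target_labels[:, None] == prototype_labels[None, :]
    # jnp.where keeps shapes static (unlike boolean indexing), so jit and grad work
    d_plus = jnp.min(jnp.where(same, distances, jnp.inf), axis=1)   # closest correct
    d_minus = jnp.min(jnp.where(same, jnp.inf, distances), axis=1)  # closest wrong
    return (d_plus - d_minus) / (d_plus + d_minus)


def glvq_loss_sketch(distances, target_labels, prototype_labels, margin=0.0):
    # Identity transfer: mean of mu shifted by the margin
    return jnp.mean(glvq_mu(distances, target_labels, prototype_labels) + margin)


distances = jnp.array([[1.0, 3.0], [4.0, 2.0]])
prototype_labels = jnp.array([0, 1])
target_labels = jnp.array([0, 0])
loss = glvq_loss_sketch(distances, target_labels, prototype_labels)
grads = jax.grad(glvq_loss_sketch)(distances, target_labels, prototype_labels)
```

Sample 0 is correctly classified (\(\mu = -0.5\)), sample 1 is not (\(\mu = 1/3\)), so the loss is about -0.083.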

prosemble.core.losses.glvq_loss_with_transfer(distances, target_labels, prototype_labels, transfer_fn=<PjitFunction of <function identity>>, margin=0.0, beta=10.0)[source]

GLVQ loss with configurable transfer function.

\(\text{loss} = \text{mean}(f(\mu + \text{margin}, \beta))\)

Parameters:
  • distances (array of shape (n, p))

  • target_labels (array of shape (n,))

  • prototype_labels (array of shape (p,))

  • transfer_fn (callable) – Activation function (identity, sigmoid_beta, swish_beta).

  • margin (float)

  • beta (float) – Transfer function parameter.

Return type:

scalar

prosemble.core.losses.lvq1_loss(distances, target_labels, prototype_labels)[source]

LVQ1 loss: d+ when the closest prototype has the correct label (d+ < d-), -d- otherwise.

Parameters:
  • distances (array of shape (n, p))

  • target_labels (array of shape (n,))

  • prototype_labels (array of shape (p,))

Return type:

scalar

prosemble.core.losses.lvq21_loss(distances, target_labels, prototype_labels)[source]

LVQ2.1 loss: d+ - d- (unnormalized).

Parameters:
  • distances (array of shape (n, p))

  • target_labels (array of shape (n,))

  • prototype_labels (array of shape (p,))

Return type:

scalar

prosemble.core.losses.nllr_loss(distances, target_labels, prototype_labels, sigma=1.0)[source]

Negative Log-Likelihood Ratio loss (for SLVQ).

\(\text{loss} = -\log(P(\text{correct}) / P(\text{wrong}))\)

Parameters:
  • distances (array of shape (n, p))

  • target_labels (array of shape (n,))

  • prototype_labels (array of shape (p,))

  • sigma (float) – Bandwidth of Gaussian mixture.

Return type:

scalar

prosemble.core.losses.rslvq_loss(distances, target_labels, prototype_labels, sigma=1.0)[source]

Robust Soft LVQ loss (for RSLVQ).

\(\text{loss} = -\log(P(\text{correct}) / P(\text{all}))\)

Parameters:
  • distances (array of shape (n, p))

  • target_labels (array of shape (n,))

  • prototype_labels (array of shape (p,))

  • sigma (float)

Return type:

scalar
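A sketch of the two probabilistic losses above, assuming distances are squared distances, equal mixture weights, and a shared bandwidth sigma, so each prototype contributes a log-density of \(-d / (2\sigma^2)\) up to a common constant that cancels in the ratios. All function names are hypothetical; logsumexp keeps the sums over prototypes numerically stable.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def _log_masses(distances, target_labels, prototype_labels, sigma):
    # Per-prototype Gaussian log-densities up to a shared constant
    logits = -distances / (2.0 * sigma ** 2)
    same = target_labels[:, None] == prototype_labels[None, :]
    # Log of the summed mass over correct, wrong, and all prototypes
    log_correct = logsumexp(jnp.where(same, logits, -jnp.inf), axis=1)
    log_wrong = logsumexp(jnp.where(same, -jnp.inf, logits), axis=1)
    log_all = logsumexp(logits, axis=1)
    return log_correct, log_wrong, log_all


def nllr_loss_sketch(distances, target_labels, prototype_labels, sigma=1.0):
    # -log(P(correct) / P(wrong))
    lc, lw, _ = _log_masses(distances, target_labels, prototype_labels, sigma)
    return jnp.mean(lw - lc)


def rslvq_loss_sketch(distances, target_labels, prototype_labels, sigma=1.0):
    # -log(P(correct) / P(all))
    lc, _, la = _log_masses(distances, target_labels, prototype_labels, sigma)
    return jnp.mean(la - lc)


distances = jnp.array([[0.0, 2.0]])
targets = jnp.array([0])
proto_labels = jnp.array([0, 1])
nllr = nllr_loss_sketch(distances, targets, proto_labels)    # -1.0
rslvq = rslvq_loss_sketch(distances, targets, proto_labels)  # log(1 + e^-1)
```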

prosemble.core.losses.ng_rslvq_loss(distances, target_labels, prototype_labels, sigma=1.0, gamma=1.0)[source]

RSLVQ loss with Neural Gas rank-based neighborhood cooperation.

Combines Gaussian mixture prototype probabilities with NG rank weights to create a neighborhood-cooperative probabilistic assignment.

Gaussian: \(p(k|x) = \exp(-d_k / 2\sigma^2) / \sum_j \exp(-d_j / 2\sigma^2)\)

NG weights: \(h_k = \exp(-\text{rank}_k / \gamma) / \sum_j \exp(-\text{rank}_j / \gamma)\)

Combined: \(w_k = p(k|x) \cdot h_k / \sum_j p(j|x) \cdot h_j\)

The loss is \(-\log(\sum_{k \in \text{correct}} w_k)\).

Parameters:
  • distances (array of shape (n, p)) – Squared distances from samples to prototypes.

  • target_labels (array of shape (n,)) – True class labels for samples.

  • prototype_labels (array of shape (p,)) – Class labels assigned to prototypes.

  • sigma (float) – Bandwidth of Gaussian mixture.

  • gamma (float) – Neural Gas neighborhood range.

Returns:

Mean negative log-likelihood with NG cooperation.

Return type:

scalar

prosemble.core.losses.cross_entropy_lvq_loss(distances, target_labels, prototype_labels, n_classes)[source]

Cross-entropy LVQ loss (for CELVQ).

  1. Min distances per class via masking

  2. Negate to get logits (closer = higher)

  3. Cross-entropy against true labels

Parameters:
  • distances (array of shape (n, p))

  • target_labels (array of shape (n,))

  • prototype_labels (array of shape (p,))

  • n_classes (int)

Return type:

scalar
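The three steps listed above can be sketched as follows. celvq_loss_sketch is a hypothetical name; in this sketch a class without any prototype gets a logit of -inf and therefore zero probability.

```python
import jax
import jax.numpy as jnp


def celvq_loss_sketch(distances, target_labels, prototype_labels, n_classes):
    # (p, C) membership of each prototype in each class
    member = prototype_labels[:, None] == jnp.arange(n_classes)[None, :]
    # 1. Per-class minimum distance via masking, shape (n, C)
    masked = jnp.where(member[None, :, :], distances[:, :, None], jnp.inf)
    class_min = jnp.min(masked, axis=1)
    # 2. Negate so that closer classes get higher logits
    logits = -class_min
    # 3. Cross-entropy against the true labels
    log_probs = jax.nn.log_softmax(logits, axis=1)
    nll = -jnp.take_along_axis(log_probs, target_labels[:, None], axis=1)[:, 0]
    return jnp.mean(nll)


loss = celvq_loss_sketch(jnp.array([[1.0, 2.0]]), jnp.array([0]),
                         jnp.array([0, 1]), n_classes=2)
```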

prosemble.core.losses.margin_loss(y_pred, y_true_one_hot, margin=0.3)[source]

Margin loss for CBC.

\(\text{loss} = \text{ReLU}(\max(\text{wrong}) - \text{correct} + \text{margin})\)

Parameters:
  • y_pred (array of shape (n, n_classes)) – Predicted class probabilities.

  • y_true_one_hot (array of shape (n, n_classes)) – One-hot encoded true labels.

  • margin (float)

Return type:

scalar

prosemble.core.losses.neural_gas_energy(distances, lam)[source]

Neural Gas energy function.

\[E = \sum_k h(\text{rank}_k, \lambda) \cdot d(x, w_k), \quad h(\text{rank}, \lambda) = \exp(-\text{rank} / \lambda)\]

Parameters:
  • distances (array of shape (n, p))

  • lam (float) – Neighborhood range parameter.

Return type:

scalar

Activations

Transfer/activation functions for prototype-based learning.

These functions transform the GLVQ relative distance differences (mu values) before they are averaged into the loss, shaping the optimization landscape.

prosemble.core.activations.identity(x, beta=0.0)[source]

Identity activation (passthrough).

Parameters:
  • x (array) – Input values.

  • beta (float) – Ignored. Present for API consistency.

Returns:

Same as input.

Return type:

array

prosemble.core.activations.sigmoid_beta(x, beta=10.0)[source]

Sigmoid activation with steepness parameter.

f(x) = 1 / (1 + exp(-beta * x))

Parameters:
  • x (array) – Input values.

  • beta (float) – Steepness parameter. Higher values give sharper transition.

Returns:

Sigmoid-transformed values in (0, 1).

Return type:

array

prosemble.core.activations.swish_beta(x, beta=10.0)[source]

Swish activation with steepness parameter.

f(x) = x * sigmoid(beta * x)

Parameters:
  • x (array) – Input values.

  • beta (float) – Steepness parameter.

Returns:

Swish-transformed values.

Return type:

array
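Both transfer functions can be expressed with jax.nn.sigmoid, which computes 1 / (1 + exp(-x)). The function names below are hypothetical stand-ins for sigmoid_beta and swish_beta, written only to show the formulas.

```python
import jax
import jax.numpy as jnp


def sigmoid_beta_sketch(x, beta=10.0):
    # f(x) = 1 / (1 + exp(-beta * x))
    return jax.nn.sigmoid(beta * x)


def swish_beta_sketch(x, beta=10.0):
    # f(x) = x * sigmoid(beta * x)
    return x * jax.nn.sigmoid(beta * x)


mu = jnp.array([-0.5, 0.0, 0.5])
s = sigmoid_beta_sketch(mu)  # squashes mu into (0, 1); exactly 0.5 at mu = 0
w = swish_beta_sketch(mu)    # close to 0 for negative mu, close to mu for positive mu
```

Since \(\mu > 0\) means a sample is misclassified, a large beta makes the averaged sigmoid loss approach the fraction of misclassified samples.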

prosemble.core.activations.get_activation(name)[source]

Get activation function by name.

Parameters:

name (str or callable) – Name of activation (‘identity’, ‘sigmoid_beta’, ‘swish_beta’) or a callable.

Returns:

The activation function.

Return type:

callable

Competitions

Competition mechanisms for prototype-based classification.

These functions determine class predictions from distance matrices and prototype labels using different strategies.

prosemble.core.competitions.wtac(distances, prototype_labels)[source]

Winner-Takes-All Competition.

Assigns each sample the label of the closest prototype.

Parameters:
  • distances (array of shape (n_samples, n_prototypes)) – Distance matrix.

  • prototype_labels (array of shape (n_prototypes,)) – Class label for each prototype.

Returns:

Predicted labels.

Return type:

array of shape (n_samples,)

prosemble.core.competitions.knnc(distances, prototype_labels, k=1, n_classes=None)[source]

K-Nearest Neighbors Competition.

Assigns each sample the majority label among k closest prototypes.

Parameters:
  • distances (array of shape (n_samples, n_prototypes)) – Distance matrix.

  • prototype_labels (array of shape (n_prototypes,)) – Class label for each prototype.

  • k (int) – Number of neighbors.

  • n_classes (int or None) – Number of classes. If None, inferred from prototype_labels.

Returns:

Predicted labels.

Return type:

array of shape (n_samples,)
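A sketch of the two competition strategies. wtac_sketch and knnc_sketch are hypothetical names. The k-nearest-neighbour vote uses jax.lax.top_k on negated distances, and class ties go to the smallest class index because jnp.argmax returns the first maximum. The library's tie-breaking rule is not documented above.

```python
import jax
import jax.numpy as jnp


def wtac_sketch(distances, prototype_labels):
    # Label of the closest prototype for every sample
    return prototype_labels[jnp.argmin(distances, axis=1)]


def knnc_sketch(distances, prototype_labels, k, n_classes):
    # Indices of the k smallest distances (top_k of the negated distances)
    _, idx = jax.lax.top_k(-distances, k)                      # (n, k)
    votes = jax.nn.one_hot(prototype_labels[idx], n_classes)   # (n, k, C)
    # Majority vote over the k neighbours
    return jnp.argmax(jnp.sum(votes, axis=1), axis=1)


distances = jnp.array([[0.1, 0.2, 0.3], [0.9, 0.1, 0.2]])
labels = jnp.array([0, 1, 1])
pred_wta = wtac_sketch(distances, labels)                    # [0, 1]
pred_knn = knnc_sketch(distances, labels, k=3, n_classes=2)  # [1, 1]
```

With k = 3 the first sample's single closest prototype (class 0) is outvoted by the two class-1 prototypes, which is the behavioural difference from winner-takes-all.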

prosemble.core.competitions.cbcc(detections, reasonings)[source]

Classification-By-Components Competition.

Computes class probability distributions using component detections and reasoning matrices.

Parameters:
  • detections (array of shape (n_samples, n_components)) – Similarity/detection scores for each component.

  • reasonings (array of shape (n_components, n_classes, 2)) – Reasoning matrices. Last dim: [positive, negative_raw].

Returns:

Class probability distributions.

Return type:

array of shape (n_samples, n_classes)

Initializers

Prototype and parameter initializers for prototype-based learning.

These functions generate initial prototypes, labels, and transformation matrices for supervised and unsupervised models.

prosemble.core.initializers.stratified_selection_init(X, y, n_per_class, key)[source]

Initialize prototypes by randomly selecting samples per class.

Parameters:
  • X (array of shape (n_samples, n_features)) – Training data.

  • y (array of shape (n_samples,)) – Training labels.

  • n_per_class (int, list, or dict) – Number of prototypes per class: an int gives the same count for all classes; a list gives the count for class i at index i, e.g. [2, 2, 1]; a dict maps each class label to its count, e.g. {0: 2, 1: 2, 2: 1}.

  • key (jax.random.PRNGKey) – Random key.

Returns:

  • prototypes (array of shape (n_prototypes, n_features))

  • prototype_labels (array of shape (n_prototypes,))

prosemble.core.initializers.stratified_mean_init(X, y)[source]

Initialize prototypes at the mean of each class.

Parameters:
  • X (array of shape (n_samples, n_features)) – Training data.

  • y (array of shape (n_samples,)) – Training labels.

Returns:

  • prototypes (array of shape (n_classes, n_features))

  • prototype_labels (array of shape (n_classes,))

prosemble.core.initializers.random_normal_init(n_prototypes, n_features, key, mean=0.0, std=1.0)[source]

Initialize prototypes from a normal distribution.

Parameters:
  • n_prototypes (int)

  • n_features (int)

  • key (jax.random.PRNGKey)

  • mean (float)

  • std (float)

Return type:

array of shape (n_prototypes, n_features)

prosemble.core.initializers.identity_omega_init(n_features, n_dims=None)[source]

Initialize omega as an identity (or truncated identity) matrix.

Parameters:
  • n_features (int) – Input dimensionality.

  • n_dims (int, optional) – Projection dimensionality. Defaults to n_features (square).

Return type:

array of shape (n_features, n_dims)

prosemble.core.initializers.random_omega_init(n_features, n_dims, key)[source]

Initialize omega as a random orthogonal matrix via QR decomposition.

Parameters:
  • n_features (int) – Input dimensionality.

  • n_dims (int) – Projection dimensionality.

  • key (jax.random.PRNGKey)

Return type:

array of shape (n_features, n_dims)
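A sketch of QR-based orthogonal initialization. random_orthogonal_omega_sketch is a hypothetical name; it assumes n_dims <= n_features, so that the reduced QR factor Q has orthonormal columns (\(Q^T Q = I\)).

```python
import jax
import jax.numpy as jnp


def random_orthogonal_omega_sketch(n_features, n_dims, key):
    # Draw a Gaussian matrix, then orthonormalize its columns with reduced QR
    A = jax.random.normal(key, (n_features, n_dims))
    Q, _ = jnp.linalg.qr(A)  # Q: (n_features, n_dims)
    return Q


omega = random_orthogonal_omega_sketch(5, 3, jax.random.PRNGKey(0))
```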

prosemble.core.initializers.uniform_init(n_prototypes, n_features, key, low=0.0, high=1.0)[source]

Initialize prototypes from a uniform distribution.

Parameters:
  • n_prototypes (int)

  • n_features (int)

  • key (jax.random.PRNGKey)

  • low (float)

  • high (float)

Return type:

array of shape (n_prototypes, n_features)

prosemble.core.initializers.zeros_init(n_prototypes, n_features)[source]

Initialize prototypes as zeros.

Useful for checkpoint loading where shapes must be pre-allocated before restoring saved values.

Parameters:
  • n_prototypes (int)

  • n_features (int)

Return type:

array of shape (n_prototypes, n_features)

prosemble.core.initializers.ones_init(n_prototypes, n_features)[source]

Initialize prototypes as ones.

Parameters:
  • n_prototypes (int)

  • n_features (int)

Return type:

array of shape (n_prototypes, n_features)

prosemble.core.initializers.fill_value_init(n_prototypes, n_features, value=0.0)[source]

Initialize prototypes filled with a constant value.

Parameters:
  • n_prototypes (int)

  • n_features (int)

  • value (float) – Fill value.

Return type:

array of shape (n_prototypes, n_features)

prosemble.core.initializers.selection_init(X, n_prototypes, key)[source]

Initialize prototypes by uniformly sampling from data (classless).

Suitable for unsupervised models like Neural Gas, SOM, and fuzzy clustering where no labels are available.

Parameters:
  • X (array of shape (n_samples, n_features)) – Training data.

  • n_prototypes (int) – Number of prototypes to select.

  • key (jax.random.PRNGKey)

Return type:

array of shape (n_prototypes, n_features)

prosemble.core.initializers.mean_init(X, n_prototypes)[source]

Initialize all prototypes at the data mean (classless).

Suitable for unsupervised models. All prototypes start at the global mean and diverge during training.

Parameters:
  • X (array of shape (n_samples, n_features)) – Training data.

  • n_prototypes (int) – Number of prototypes.

Return type:

array of shape (n_prototypes, n_features)

prosemble.core.initializers.literal_init(values)[source]

Initialize prototypes from literal values.

Used for warm-starting from another model’s prototypes or from user-provided values.

Parameters:

values (array-like of shape (n_prototypes, n_features)) – Literal prototype values.

Return type:

array of shape (n_prototypes, n_features)

prosemble.core.initializers.stratified_noise_init(X, y, n_per_class, key, noise_std=0.1)[source]

Initialize prototypes by selecting samples per class and adding noise.

Combines stratified selection with Gaussian noise injection for diverse initial prototypes.

Parameters:
  • X (array of shape (n_samples, n_features))

  • y (array of shape (n_samples,))

  • n_per_class (int, list, or dict) – Number of prototypes per class (same formats as stratified_selection_init).

  • key (jax.random.PRNGKey)

  • noise_std (float) – Standard deviation of Gaussian noise to add.

Returns:

  • prototypes (array of shape (n_prototypes, n_features))

  • prototype_labels (array of shape (n_prototypes,))

prosemble.core.initializers.pca_omega_init(X, n_dims)[source]

Initialize omega using PCA directions from training data.

The top-n_dims principal components become the columns of omega, providing a data-driven initialization for metric learning models.

Parameters:
  • X (array of shape (n_samples, n_features)) – Training data.

  • n_dims (int) – Number of principal components (projection dimensionality).

Returns:

omega

Return type:

array of shape (n_features, n_dims)
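The following NumPy sketch illustrates the idea: omega is built from the top principal directions of the centered data. It is a hypothetical re-implementation, and the helper name pca_omega_sketch is invented. Two details are assumptions: that prosemble centers the data first, and that it uses an SVD rather than a covariance eigendecomposition.

```python
import numpy as np


def pca_omega_sketch(X, n_dims):
    """Hypothetical sketch: top-n_dims principal directions as the columns of omega."""
    Xc = X - X.mean(axis=0)  # center the data (assumption)
    # Rows of Vt are principal directions, ordered by decreasing singular value
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    return Vt[:n_dims].T  # shape (n_features, n_dims)


rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))
omega = pca_omega_sketch(X, n_dims=2)
print(omega.shape)                              # (5, 2)
print(np.allclose(omega.T @ omega, np.eye(2)))  # True: the columns are orthonormal
```

Because the columns are orthonormal, the initial omega neither stretches nor shrinks distances within the retained principal subspace.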

prosemble.core.initializers.class_conditional_mean_init(X, y, n_per_class)[source]

Initialize prototypes at class means, with n_per_class copies of each class mean.

When n_per_class > 1, each class gets multiple prototypes all initialized at the class mean (they will diverge during training).

Parameters:
  • X (array of shape (n_samples, n_features))

  • y (array of shape (n_samples,))

  • n_per_class (int, list, or dict) – Number of prototypes per class.

Returns:

  • prototypes (array of shape (n_prototypes, n_features))

  • prototype_labels (array of shape (n_prototypes,))
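Below is a minimal NumPy sketch of the integer n_per_class case. It assumes integer class labels, and the helper name is hypothetical. The list and dict forms of n_per_class are not covered.

```python
import numpy as np


def class_mean_init_sketch(X, y, n_per_class):
    """Hypothetical sketch: n_per_class identical copies of each class mean."""
    classes = np.unique(y)
    means = np.stack([X[y == c].mean(axis=0) for c in classes])
    prototypes = np.repeat(means, n_per_class, axis=0)  # copies are adjacent per class
    prototype_labels = np.repeat(classes, n_per_class)
    return prototypes, prototype_labels


X = np.array([[0.0, 0.0], [2.0, 0.0], [10.0, 10.0], [12.0, 10.0]])
y = np.array([0, 0, 1, 1])
protos, labels = class_mean_init_sketch(X, y, n_per_class=2)
# protos rows: [1, 0], [1, 0], [11, 10], [11, 10]; labels: 0, 0, 1, 1
```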

prosemble.core.initializers.random_reasonings_init(n_components, n_classes, key)[source]

Initialize CBC reasoning matrices randomly.

Parameters:
  • n_components (int)

  • n_classes (int)

  • key (jax.random.PRNGKey)

Returns:

reasonings

Return type:

array of shape (n_components, n_classes, 2)

prosemble.core.initializers.pure_positive_reasonings_init(n_components, n_classes, key=None)[source]

Initialize CBC reasoning matrices with pure positive evidence.

Each component maps to exactly one class with high positive evidence and low negative evidence.

Parameters:
  • n_components (int)

  • n_classes (int)

  • key (optional) – Ignored; included for API compatibility.

Returns:

reasonings

Return type:

array of shape (n_components, n_classes, 2)

Pooling

Stratified pooling operations for prototype-based learning.

These functions aggregate per-prototype distances into per-class distances, grouping by prototype label. Essential for GLVQ and CELVQ.

prosemble.core.pooling.stratified_min_pooling(distances, prototype_labels, n_classes)[source]

Per-class minimum distance pooling.

For each sample and each class, returns the minimum distance to any prototype of that class.

Parameters:
  • distances (array of shape (n_samples, n_prototypes)) – Distance matrix.

  • prototype_labels (array of shape (n_prototypes,)) – Class label for each prototype.

  • n_classes (int) – Number of classes.

Returns:

Minimum distance to each class for each sample.

Return type:

array of shape (n_samples, n_classes)
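One vectorized way to express per-class minimum pooling is to mask out prototypes of other classes with +inf and then reduce. The NumPy sketch below does this; it is an illustration, not prosemble's implementation, and the helper name is hypothetical.

```python
import numpy as np


def stratified_min_pooling_sketch(distances, prototype_labels, n_classes):
    """Hypothetical sketch: per-class minimum via +inf masking."""
    # mask[k, j] is True when prototype j belongs to class k
    mask = prototype_labels[None, :] == np.arange(n_classes)[:, None]  # (n_classes, n_prototypes)
    # Broadcast to (n_samples, n_classes, n_prototypes); other classes become +inf
    masked = np.where(mask[None, :, :], distances[:, None, :], np.inf)
    return masked.min(axis=-1)  # (n_samples, n_classes)


d = np.array([[0.5, 2.0, 1.0],
              [3.0, 0.1, 4.0]])
proto_labels = np.array([0, 1, 0])
pooled = stratified_min_pooling_sketch(d, proto_labels, n_classes=2)
# sample 0: class 0 -> min(0.5, 1.0) = 0.5, class 1 -> 2.0
# sample 1: class 0 -> min(3.0, 4.0) = 3.0, class 1 -> 0.1
```

The same masking pattern gives sum pooling (fill with 0, reduce with sum) and max pooling (fill with -inf, reduce with max). In this sketch, a class with no prototypes yields +inf.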

prosemble.core.pooling.stratified_sum_pooling(distances, prototype_labels, n_classes)[source]

Per-class sum distance pooling.

Parameters:
  • distances (array of shape (n_samples, n_prototypes))

  • prototype_labels (array of shape (n_prototypes,))

  • n_classes (int)

Returns:

Sum of distances to each class for each sample.

Return type:

array of shape (n_samples, n_classes)

prosemble.core.pooling.stratified_max_pooling(distances, prototype_labels, n_classes)[source]

Per-class maximum distance pooling.

Parameters:
  • distances (array of shape (n_samples, n_prototypes))

  • prototype_labels (array of shape (n_prototypes,))

  • n_classes (int)

Returns:

Maximum distance to each class for each sample.

Return type:

array of shape (n_samples, n_classes)

prosemble.core.pooling.stratified_prod_pooling(distances, prototype_labels, n_classes)[source]

Per-class product distance pooling.

Uses log-sum-exp for numerical stability.

Parameters:
  • distances (array of shape (n_samples, n_prototypes))

  • prototype_labels (array of shape (n_prototypes,))

  • n_classes (int)

Returns:

Product of distances to each class for each sample.

Return type:

array of shape (n_samples, n_classes)
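The docstring does not show the exact log-domain formulation. As an assumption, the sketch below computes each per-class product as exp(sum(log d)), which avoids overflow and underflow when many factors are multiplied. The helper name is hypothetical.

```python
import numpy as np


def stratified_prod_pooling_sketch(distances, prototype_labels, n_classes):
    """Hypothetical sketch: per-class product computed in the log domain.

    Assumes strictly positive distances (log(0) is -inf).
    """
    mask = prototype_labels[None, :] == np.arange(n_classes)[:, None]  # (n_classes, n_prototypes)
    log_d = np.log(distances)                                          # (n_samples, n_prototypes)
    # Sum the log-distances of each class's prototypes; other classes contribute 0
    log_prod = np.where(mask[None, :, :], log_d[:, None, :], 0.0).sum(axis=-1)
    return np.exp(log_prod)  # (n_samples, n_classes)


d = np.array([[2.0, 3.0, 4.0]])
proto_labels = np.array([0, 0, 1])
pooled = stratified_prod_pooling_sketch(d, proto_labels, n_classes=2)
# class 0: 2 * 3 = 6, class 1: 4
```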

Quantization

Shared mixins for model base classes.

class prosemble.core.quantization.MetadataCollectorMixin[source]

Mixin that auto-collects _hyperparams and _fitted_array_names from the MRO.

Any base class that declares per-class _hyperparams and _fitted_array_names tuples can inherit this mixin to get _all_hyperparams and _all_fitted_array_names aggregated automatically across the entire class hierarchy.
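A minimal sketch of how per-class tuples can be aggregated across the MRO follows. The class names are hypothetical, and exposing the aggregates as properties is an assumption; only the attribute names _hyperparams, _fitted_array_names, _all_hyperparams and _all_fitted_array_names come from the description above.

```python
class MetadataCollectorSketch:
    """Hypothetical sketch of MRO-wide aggregation of per-class tuples."""

    @classmethod
    def _collect(cls, attr_name):
        seen = []
        # Walk from the most basic class to the most derived one, reading only
        # each class's own declaration (klass.__dict__, not inherited lookup)
        for klass in reversed(cls.__mro__):
            for name in klass.__dict__.get(attr_name, ()):
                if name not in seen:
                    seen.append(name)
        return tuple(seen)

    @property
    def _all_hyperparams(self):
        return self._collect("_hyperparams")

    @property
    def _all_fitted_array_names(self):
        return self._collect("_fitted_array_names")


class BaseModel(MetadataCollectorSketch):
    _hyperparams = ("max_iter", "lr")
    _fitted_array_names = ("prototypes_",)


class MetricModel(BaseModel):
    _hyperparams = ("latent_dim",)
    _fitted_array_names = ("omega_",)


m = MetricModel()
print(m._all_hyperparams)         # ('max_iter', 'lr', 'latent_dim')
print(m._all_fitted_array_names)  # ('prototypes_', 'omega_')
```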

get_params(deep=True)[source]

Get parameters for this estimator.

Follows the sklearn estimator protocol by inspecting __init__ signatures across the MRO.

Parameters:

deep (bool, default=True) – Ignored (present for sklearn compatibility).

Returns:

Parameter names mapped to their values.

Return type:

dict

set_params(**params)[source]

Set parameters on this estimator.

Parameters:

**params – Estimator parameters to set.

Return type:

self

class prosemble.core.quantization.QuantizationMixin[source]

Mixin for quantizing/dequantizing fitted model parameters.

Supports float16, bfloat16, and int8 (with per-tensor scale factors).

Subclasses override _get_quantizable_attrs to declare which fitted attributes are eligible for quantization.

quantize(dtype='float16')[source]

Quantize model parameters to lower precision.

Post-training quantization for smaller model size and faster inference.

Parameters:

dtype (str) – Target precision: ‘float16’, ‘bfloat16’, or ‘int8’.

Return type:

self
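The docstring only says that int8 uses per-tensor scale factors. The sketch below assumes a common symmetric scheme: one float scale per tensor maps the largest magnitude to 127. The helper names are hypothetical.

```python
import numpy as np


def quantize_int8_sketch(w):
    """Hypothetical symmetric per-tensor int8 quantization."""
    max_abs = np.max(np.abs(w))
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, np.float32(scale)


def dequantize_int8_sketch(q, scale):
    """Approximate float32 reconstruction from int8 values and the scale."""
    return q.astype(np.float32) * scale


w = np.array([[0.4, -1.0], [0.2, 0.0]], dtype=np.float32)
q, scale = quantize_int8_sketch(w)
w_restored = dequantize_int8_sketch(q, scale)
print(q.tolist())                       # [[51, -127], [25, 0]]
print(np.max(np.abs(w - w_restored)))  # rounding error, at most scale / 2
```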

dequantize()[source]

Restore model parameters to float32.

Return type:

self

property is_quantized: bool

Whether model parameters are currently quantized.

property quantized_dtype: str | None

Current quantization dtype, or None if not quantized.

Utilities

JAX utility functions for data preprocessing and manipulation.

Replaces common sklearn/numpy operations with JAX-native implementations.

prosemble.core.utils.train_test_split_jax(X, y=None, test_size=0.2, random_seed=42)[source]

JAX-native train/test split.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Feature matrix

  • y (array-like of shape (n_samples,), optional) – Labels

  • test_size (float, default=0.2) – Proportion of dataset for test set (0.0 to 1.0)

  • random_seed (int, default=42) – Random seed for reproducibility

Returns:

Split arrays. If y is provided, returns 4 arrays, else 2.

Return type:

X_train, X_test[, y_train, y_test]

Examples

>>> X = jnp.array([[1, 2], [3, 4], [5, 6], [7, 8]])
>>> y = jnp.array([0, 1, 0, 1])
>>> X_train, X_test, y_train, y_test = train_test_split_jax(X, y, test_size=0.5)

prosemble.core.utils.standardize(X, mean=None, std=None)[source]

Standardize features (zero mean, unit variance).

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Data to standardize

  • mean (array-like of shape (n_features,), optional) – Pre-computed mean (for test data)

  • std (array-like of shape (n_features,), optional) – Pre-computed std (for test data)

Returns:

  • X_scaled (array-like) – Standardized data

  • mean (array-like) – Mean used for scaling

  • std (array-like) – Std used for scaling

Return type:

tuple[Array | ndarray | bool | number, Array | ndarray | bool | number, Array | ndarray | bool | number]

Examples

>>> X_train = jnp.array([[1, 2], [3, 4], [5, 6]])
>>> X_scaled, mean, std = standardize(X_train)
>>> # For test data
>>> X_test_scaled, _, _ = standardize(X_test, mean=mean, std=std)

prosemble.core.utils.min_max_scale(X, min_val=None, max_val=None)[source]

Scale features to [0, 1] range.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Data to scale

  • min_val (array-like of shape (n_features,), optional) – Pre-computed min (for test data)

  • max_val (array-like of shape (n_features,), optional) – Pre-computed max (for test data)

Returns:

  • X_scaled (array-like) – Scaled data

  • min_val (array-like) – Min values used

  • max_val (array-like) – Max values used

Return type:

tuple[Array | ndarray | bool | number, Array | ndarray | bool | number, Array | ndarray | bool | number]

prosemble.core.utils.pca_jax(X, n_components=2)[source]

Principal Component Analysis using JAX.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Data matrix

  • n_components (int, default=2) – Number of principal components

Returns:

  • X_transformed (array-like of shape (n_samples, n_components)) – Transformed data

  • components (array-like of shape (n_components, n_features)) – Principal components

Return type:

tuple[Array | ndarray | bool | number, Array | ndarray | bool | number]

Examples

>>> X = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> X_pca, components = pca_jax(X, n_components=2)

prosemble.core.utils.accuracy_score_jax(y_true, y_pred)[source]

Compute classification accuracy.

Parameters:
  • y_true (array-like of shape (n_samples,)) – True labels

  • y_pred (array-like of shape (n_samples,)) – Predicted labels

Returns:

accuracy – Accuracy score

Return type:

float

Examples

>>> y_true = jnp.array([0, 1, 1, 0])
>>> y_pred = jnp.array([0, 1, 0, 0])
>>> accuracy_score_jax(y_true, y_pred)
0.75

prosemble.core.utils.confusion_matrix_jax(y_true, y_pred, n_classes)[source]

Compute confusion matrix.

Parameters:
  • y_true (array-like of shape (n_samples,)) – True labels

  • y_pred (array-like of shape (n_samples,)) – Predicted labels

  • n_classes (int) – Number of classes

Returns:

conf_matrix – Confusion matrix

Return type:

array-like of shape (n_classes, n_classes)

Examples

>>> y_true = jnp.array([0, 1, 2, 0, 1, 2])
>>> y_pred = jnp.array([0, 2, 2, 0, 0, 1])
>>> confusion_matrix_jax(y_true, y_pred, n_classes=3)

prosemble.core.utils.shuffle_jax(X, y=None, random_seed=42)[source]

Shuffle arrays in unison.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Feature matrix

  • y (array-like of shape (n_samples,), optional) – Labels

  • random_seed (int, default=42) – Random seed

Returns:

Shuffled arrays

Return type:

X_shuffled[, y_shuffled]

Examples

>>> X = jnp.array([[1, 2], [3, 4], [5, 6]])
>>> y = jnp.array([0, 1, 0])
>>> X_shuf, y_shuf = shuffle_jax(X, y, random_seed=42)

prosemble.core.utils.k_fold_split_jax(n_samples, n_folds=5, random_seed=42)[source]

Generate K-fold cross-validation indices.

Parameters:
  • n_samples (int) – Number of samples

  • n_folds (int, default=5) – Number of folds

  • random_seed (int, default=42) – Random seed

Yields:

train_indices, test_indices – Indices for each fold

Examples

>>> for train_idx, test_idx in k_fold_split_jax(100, n_folds=5):
...     X_train, X_test = X[train_idx], X[test_idx]

prosemble.core.utils.orthogonalize(matrix)[source]

Orthogonalize a matrix via polar decomposition (SVD).

Given a matrix A of shape (d, s), computes the closest matrix Q with orthonormal columns (\(Q^T Q = I\)), using the polar factor:

\(U, S, V^T = \operatorname{SVD}(A)\), \(Q = U V^T\)

Supports batched input via jax.vmap.

Parameters:

matrix (array of shape (d, s) or (n, d, s)) – Matrix or batch of matrices to orthogonalize. For batched input, use jax.vmap(orthogonalize).

Returns:

Q – Orthogonal matrix (columns are orthonormal).

Return type:

array of same shape as input
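The polar factor can be sketched in a few lines of NumPy. The thin SVD (full_matrices=False) matters when d differs from s, because it keeps U at shape (d, s). Because np.linalg.svd and the @ operator both broadcast over leading dimensions, this sketch also accepts an (n, d, s) stack. The helper name is hypothetical.

```python
import numpy as np


def orthogonalize_sketch(A):
    """Polar factor U @ V^T: the closest matrix with orthonormal columns (Frobenius norm)."""
    U, _, Vt = np.linalg.svd(A, full_matrices=False)  # thin SVD: U (..., d, s), Vt (..., s, s)
    return U @ Vt


rng = np.random.default_rng(0)
A = rng.normal(size=(5, 3))
Q = orthogonalize_sketch(A)
print(Q.shape)                          # (5, 3)
print(np.allclose(Q.T @ Q, np.eye(3)))  # True
```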

prosemble.core.utils.class_priors(labels, n_classes=None)[source]

Compute class prior probabilities from labels.

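\(P(\text{class}=k) = n_k / n\), where \(n_k\) is the number of samples with label k and n is the total number of samples.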

Parameters:
  • labels (array of shape (n_samples,)) – Integer class labels.

  • n_classes (int, optional) – Number of classes. Inferred from labels if not provided.

Returns:

priors – Prior probability for each class, sums to 1.

Return type:

array of shape (n_classes,)

prosemble.core.utils.prototype_priors(prototype_labels, n_classes=None)[source]

Compute class priors from prototype label distribution.

Used in probabilistic LVQ models where the prior over prototypes is uniform (1/n_prototypes) and the class prior is the fraction of prototypes assigned to each class.

\(P(\text{class}=k) = |\{j : \text{label}_j = k\}| / n_{\text{prototypes}}\)

Parameters:
  • prototype_labels (array of shape (n_prototypes,)) – Class label for each prototype.

  • n_classes (int, optional) – Number of classes. Inferred if not provided.

Returns:

priors – Class prior probabilities, sums to 1.

Return type:

array of shape (n_classes,)
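Both priors reduce to normalized label counts. The sketch below applies one counting helper to sample labels and then to prototype labels. The helper name is hypothetical, and inferring n_classes as max(label) + 1 is an assumption; the docstrings only say it is inferred.

```python
import numpy as np


def class_priors_sketch(labels, n_classes=None):
    """P(class=k) = n_k / n from integer labels."""
    labels = np.asarray(labels)
    if n_classes is None:
        n_classes = int(labels.max()) + 1  # assumption about how inference works
    counts = np.bincount(labels, minlength=n_classes)
    return counts / counts.sum()


y = np.array([0, 0, 1, 2, 2, 2])
print(class_priors_sketch(y))  # approx [0.333, 0.167, 0.5]

# Counting prototype labels instead gives the prototype-based class prior
proto_labels = np.array([0, 0, 1, 1])
print(class_priors_sketch(proto_labels, n_classes=3))  # [0.5, 0.5, 0.0]
```

A class with no prototypes receives a prior of zero under the prototype-based definition, even if it has training samples.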

prosemble.core.utils.uniform_prototype_prior(n_prototypes)[source]

Uniform prior over prototypes: \(P(\text{prototype}_j) = 1/n_{\text{prototypes}}\).

This is the standard prior used in probabilistic LVQ (SLVQ/RSLVQ).

Parameters:

n_prototypes (int) – Number of prototypes.

Returns:

prior – Uniform probability vector, sums to 1.

Return type:

array of shape (n_prototypes,)

Pipeline

class prosemble.core.pipeline.NotFittedError[source]

Raised when calling transform/predict on an unfitted estimator.

class prosemble.core.pipeline.TransformerMixin[source]

Mixin providing fit_transform.

fit_transform(X, y=None)[source]

Fit and transform in one step.

Return type:

Array | ndarray | bool | number

class prosemble.core.pipeline.StandardScaler[source]

Standardize features to zero mean and unit variance.

Wraps prosemble.core.utils.standardize().

Examples

>>> scaler = StandardScaler()
>>> X_scaled = scaler.fit_transform(X_train)
>>> X_test_scaled = scaler.transform(X_test)

fit(X, y=None)[source]

Compute mean and std from training data.

Return type:

Self

transform(X)[source]

Standardize using stored mean and std.

Return type:

Array | ndarray | bool | number

get_params(deep=True)[source]
set_params(**params)[source]
fit_transform(X, y=None)

Fit and transform in one step.

Return type:

Array | ndarray | bool | number

class prosemble.core.pipeline.MinMaxScaler[source]

Scale features to [0, 1] range.

Wraps prosemble.core.utils.min_max_scale().

Examples

>>> scaler = MinMaxScaler()
>>> X_scaled = scaler.fit_transform(X_train)

fit(X, y=None)[source]

Compute min and max from training data.

Return type:

Self

transform(X)[source]

Scale using stored min and max.

Return type:

Array | ndarray | bool | number

get_params(deep=True)[source]
set_params(**params)[source]
fit_transform(X, y=None)

Fit and transform in one step.

Return type:

Array | ndarray | bool | number

class prosemble.core.pipeline.PCA(n_components=2)[source]

Principal Component Analysis.

Wraps prosemble.core.utils.pca_jax().

Parameters:

n_components (int, default=2) – Number of principal components to keep.

Examples

>>> pca = PCA(n_components=2)
>>> X_reduced = pca.fit_transform(X)

fit(X, y=None)[source]

Compute principal components from training data.

Return type:

Self

transform(X)[source]

Project data onto principal components.

Return type:

Array | ndarray | bool | number

get_params(deep=True)[source]
set_params(**params)[source]
fit_transform(X, y=None)

Fit and transform in one step.

Return type:

Array | ndarray | bool | number

class prosemble.core.pipeline.Pipeline(steps)[source]

Chain transformers with a final estimator.

All operations use pure JAX arrays, with no NumPy round trips.

Parameters:

steps (list of (name, estimator) tuples) – Sequence of transforms with a final estimator. All but the last must implement fit() and transform(). The last step can be any estimator or transformer.

Examples

>>> pipe = Pipeline([
...     ('scaler', StandardScaler()),
...     ('pca', PCA(n_components=2)),
...     ('model', GLVQ(n_prototypes_per_class=1, max_iter=50)),
... ])
>>> pipe.fit(X_train, y_train)
>>> preds = pipe.predict(X_test)

fit(X, y=None, **fit_params)[source]

Fit all steps.

Calls fit_transform on intermediate steps, fit on the final step.

Parameters:
  • X (array of shape (n_samples, n_features))

  • y (array of shape (n_samples,), optional)

  • **fit_params – Forwarded to the final estimator's fit().

Return type:

Self
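The documented control flow is: fit_transform on every intermediate step, then fit on the final step; predict transforms through the intermediate steps before calling the final predict. The sketch below shows that flow with small hypothetical stand-in classes (CenterSketch, NearestMeanSketch, PipelineSketch) in place of prosemble's estimators.

```python
import numpy as np


class CenterSketch:
    """Hypothetical transformer: subtracts the training mean."""
    def fit(self, X, y=None):
        self.mean_ = X.mean(axis=0)
        return self

    def transform(self, X):
        return X - self.mean_

    def fit_transform(self, X, y=None):
        return self.fit(X, y).transform(X)


class NearestMeanSketch:
    """Hypothetical final estimator: predicts the class of the nearest class mean."""
    def fit(self, X, y):
        self.classes_ = np.unique(y)
        self.means_ = np.stack([X[y == c].mean(axis=0) for c in self.classes_])
        return self

    def predict(self, X):
        d = ((X[:, None, :] - self.means_[None, :, :]) ** 2).sum(axis=-1)
        return self.classes_[d.argmin(axis=1)]


class PipelineSketch:
    """fit_transform on intermediate steps, fit on the last step."""
    def __init__(self, steps):
        self.steps = steps

    def fit(self, X, y=None, **fit_params):
        for _, step in self.steps[:-1]:
            X = step.fit_transform(X, y)
        self.steps[-1][1].fit(X, y, **fit_params)
        return self

    def predict(self, X):
        for _, step in self.steps[:-1]:
            X = step.transform(X)  # reuse statistics learned during fit
        return self.steps[-1][1].predict(X)


X = np.array([[0.0, 0.0], [1.0, 0.0], [10.0, 10.0], [11.0, 10.0]])
y = np.array([0, 0, 1, 1])
pipe = PipelineSketch([("center", CenterSketch()), ("model", NearestMeanSketch())]).fit(X, y)
print(pipe.predict(np.array([[0.5, 0.2], [10.5, 9.0]])))  # [0 1]
```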

predict(X)[source]

Transform X through pipeline, then predict.

predict_proba(X)[source]

Transform X through pipeline, then predict_proba.

transform(X)[source]

Transform X through all steps (including last if it has transform).

fit_transform(X, y=None, **fit_params)[source]

Fit and transform.

get_params(deep=True)[source]

Get pipeline parameters.

If deep=True, includes nested estimator params as name__param.

set_params(**params)[source]

Set pipeline parameters.

Supports nested params: model__lr=0.05 sets lr on step ‘model’.
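Routing a key such as model__lr splits it at the first double underscore; the remainder is then passed to the named step. The sketch below illustrates this with hypothetical names. Raising ValueError for an unknown step is this sketch's choice, not documented library behavior.

```python
def set_nested_params_sketch(steps, **params):
    """Hypothetical routing of 'step__param=value' keys to named steps.

    `steps` is a list of (name, estimator) pairs whose estimators expose set_params().
    """
    named = dict(steps)
    for key, value in params.items():
        step_name, sep, param = key.partition("__")  # split at the first '__' only
        if not sep or step_name not in named:
            raise ValueError(f"Invalid parameter {key!r}")
        named[step_name].set_params(**{param: value})


class Toy:
    def __init__(self, lr=0.01):
        self.lr = lr

    def set_params(self, **params):
        for k, v in params.items():
            setattr(self, k, v)
        return self


model = Toy()
set_nested_params_sketch([("model", model)], model__lr=0.05)
print(model.lr)  # 0.05
```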

Model Selection

prosemble.core.model_selection.clone(estimator, **override_params)[source]

Create a fresh (unfitted) copy of an estimator.

Parameters:
  • estimator (object) – Estimator with get_params() protocol.

  • **override_params – Parameters to override in the clone.

Returns:

new_estimator – Fresh, unfitted instance.

Return type:

same type as estimator

Examples

>>> model = GLVQ(n_prototypes_per_class=2, lr=0.01)
>>> model2 = clone(model, lr=0.05)

prosemble.core.model_selection.cross_val_score(estimator, X, y=None, cv=5, scoring='accuracy', random_seed=42)[source]

Evaluate estimator with cross-validation.

Parameters:
  • estimator (object) – Estimator with fit/predict and get_params.

  • X (array of shape (n_samples, n_features))

  • y (array of shape (n_samples,), optional) – Required for supervised estimators.

  • cv (int, default=5) – Number of cross-validation folds.

  • scoring (str or callable, default='accuracy') – ‘accuracy’ or callable scorer(y_true, y_pred) -> float. For unsupervised estimators fitted without y, the callable signature is scorer(estimator, X_test) -> float.

  • random_seed (int, default=42)

Returns:

scores – Score for each fold.

Return type:

jnp.ndarray of shape (cv,)

Examples

>>> scores = cross_val_score(GLVQ(max_iter=30), X, y, cv=5)
>>> print(f"Mean: {scores.mean():.3f} +/- {scores.std():.3f}")

class prosemble.core.model_selection.GridSearchCV(estimator, param_grid, cv=5, scoring='accuracy', random_seed=42, refit=True, verbose=0)[source]

Exhaustive search over a parameter grid with cross-validation.

Parameters:
  • estimator (object) – Base estimator with fit/predict and get_params/set_params. Can be a Pipeline or any prosemble model.

  • param_grid (dict) – Maps parameter names to lists of values to try. For Pipeline steps, use step_name__param notation.

  • cv (int, default=5) – Number of cross-validation folds.

  • scoring (str or callable, default='accuracy') – ‘accuracy’ or callable scorer(y_true, y_pred) -> float.

  • random_seed (int, default=42)

  • refit (bool, default=True) – If True, refit the best model on the full dataset after search.

  • verbose (int, default=0) – 0=silent, 1=per-combo summary, 2=per-fold detail.

best_params_

Parameters of the best model.

Type:

dict

best_score_

Mean CV score of the best model.

Type:

float

best_estimator_

Fitted estimator with best params (only if refit=True).

Type:

object

cv_results_

Keys: ‘params’, ‘mean_score’, ‘std_score’, ‘fold_scores’, ‘rank’.

Type:

dict

Examples

>>> gs = GridSearchCV(
...     GLVQ(max_iter=30),
...     {'n_prototypes_per_class': [1, 2], 'lr': [0.01, 0.05]},
...     cv=3,
... )
>>> gs.fit(X, y)
>>> print(gs.best_params_, gs.best_score_)

fit(X, y=None)[source]

Run grid search with cross-validation.

Parameters:
  • X (array of shape (n_samples, n_features))

  • y (array of shape (n_samples,), optional)

Return type:

self
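An exhaustive search visits every combination in param_grid. The stdlib sketch below expands a grid into candidate parameter dicts. The helper name is hypothetical, and sorting the parameter names for a fixed order is an assumption about iteration order.

```python
import itertools


def expand_grid_sketch(param_grid):
    """Hypothetical: every combination of the values in param_grid, as dicts."""
    names = sorted(param_grid)  # fixed order for reproducible iteration
    return [dict(zip(names, values))
            for values in itertools.product(*(param_grid[n] for n in names))]


grid = {"n_prototypes_per_class": [1, 2], "lr": [0.01, 0.05]}
combos = expand_grid_sketch(grid)
print(len(combos))  # 4
print(combos[0])    # {'lr': 0.01, 'n_prototypes_per_class': 1}
```

For the GridSearchCV example above, four combinations with cv=3 mean 12 cross-validation fits, plus one refit on the full dataset when refit=True.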

predict(X)[source]

Predict using best estimator.

predict_proba(X)[source]

Predict probabilities using best estimator.

Datasets

JAX-compatible dataset module.

Provides dataset loaders that return JAX arrays instead of NumPy arrays, optimized for use with JAX-based clustering models.

class prosemble.datasets.dataset.DATASET_JAX(input_data, labels)[source]

JAX-compatible dataset container.

input_data

Feature matrix as JAX array

Type:

jax.numpy.ndarray

labels

Labels as JAX array

Type:

jax.numpy.ndarray

input_data: Array
labels: Array
to_numpy()[source]

Convert JAX arrays back to NumPy arrays.

prosemble.datasets.dataset.iris_dataset_jax(dtype=None)[source]

Load Iris dataset as JAX arrays.

Parameters:

dtype (jax.numpy.dtype, default=jnp.float32) – Data type for features

Returns:

Dataset with JAX arrays (150 samples, 4 features, 3 classes)

Return type:

DATASET_JAX

Examples

>>> from prosemble.datasets import iris_dataset_jax
>>> dataset = iris_dataset_jax()
>>> dataset.input_data.shape
(150, 4)

prosemble.datasets.dataset.breast_cancer_dataset_jax(dtype=None)[source]

Load Wisconsin Breast Cancer dataset as JAX arrays.

Parameters:

dtype (jax.numpy.dtype, default=jnp.float32) – Data type for features

Returns:

Dataset with JAX arrays (569 samples, 30 features, 2 classes)

Return type:

DATASET_JAX

Examples

>>> from prosemble.datasets import breast_cancer_dataset_jax
>>> dataset = breast_cancer_dataset_jax()
>>> dataset.input_data.shape
(569, 30)

prosemble.datasets.dataset.moons_dataset_jax(n_samples=150, noise=None, random_state=None, dtype=None)[source]

Generate two interleaving half circles (moons) as JAX arrays.

Parameters:
  • n_samples (int, default=150) – Total number of points generated

  • noise (float, optional) – Standard deviation of Gaussian noise added to data

  • random_state (int, optional) – Random seed for reproducibility

  • dtype (jax.numpy.dtype, default=jnp.float32) – Data type for features

Returns:

Dataset with JAX arrays (n_samples, 2 features, 2 classes)

Return type:

DATASET_JAX

Examples

>>> from prosemble.datasets import moons_dataset_jax
>>> dataset = moons_dataset_jax(n_samples=200, noise=0.1, random_state=42)
>>> dataset.input_data.shape
(200, 2)

prosemble.datasets.dataset.blobs_dataset_jax(n_samples=None, centers=None, cluster_std=None, random_state=None, dtype=None)[source]

Generate isotropic Gaussian blobs as JAX arrays.

Parameters:
  • n_samples (list, default=[120, 80]) – Number of samples per cluster

  • centers (list, default=[[0.0, 0.0], [2.0, 2.0]]) – Centers of the clusters

  • cluster_std (list, default=[1.2, 0.5]) – Standard deviation of clusters

  • random_state (int, optional) – Random seed for reproducibility

  • dtype (jax.numpy.dtype, default=jnp.float32) – Data type for features

Returns:

Dataset with JAX arrays (sum(n_samples), 2 features)

Return type:

DATASET_JAX

Examples

>>> from prosemble.datasets import blobs_dataset_jax
>>> dataset = blobs_dataset_jax(random_state=42)
>>> dataset.input_data.shape
(200, 2)

class prosemble.datasets.dataset.DATA_JAX(random=4, sample_size=1000)[source]

JAX dataset collection with convenient property access.

Parameters:
  • random (int, default=4) – Random seed for dataset generation

  • sample_size (int, default=1000) – Default sample size (not currently used)

  • dtype (jax.numpy.dtype, default=jnp.float32) – Data type for features

Examples

>>> from prosemble.datasets import DATA_JAX
>>> data = DATA_JAX(random=42)
>>> moons = data.S_1  # Moons dataset
>>> blobs = data.S_2  # Blobs dataset
>>> cancer = data.breast_cancer  # Breast cancer dataset

random: int = 4
sample_size: int = 1000
dtype

alias of float32

property S_1: DATASET_JAX

Moons dataset.

property S_2: DATASET_JAX

Blobs dataset.

property iris: DATASET_JAX

Iris dataset.

property breast_cancer: DATASET_JAX

Breast cancer dataset.

property moons: DATASET_JAX

Alias for S_1 (moons dataset).

property blobs: DATASET_JAX

Alias for S_2 (blobs dataset).

prosemble.datasets.dataset.load_iris_jax(dtype=None)[source]

Quick loader for iris dataset.

Return type:

DATASET_JAX

prosemble.datasets.dataset.load_breast_cancer_jax(dtype=None)[source]

Quick loader for breast cancer dataset.

Return type:

DATASET_JAX

prosemble.datasets.dataset.load_moons_jax(n_samples=150, noise=None, random_state=None, dtype=None)[source]

Quick loader for moons dataset.

Parameters:
  • n_samples (int)

  • noise (float | None)

  • random_state (int | None)

Return type:

DATASET_JAX

prosemble.datasets.dataset.load_blobs_jax(random_state=None, dtype=None)[source]

Quick loader for blobs dataset.

Parameters:

random_state (int | None)

Return type:

DATASET_JAX

prosemble.datasets.dataset.DATA

alias of DATA_JAX

ONNX Export

ONNX export for prosemble prototype-based models.

Converts a fitted model’s predict function into an ONNX graph. Only supports models whose distance function can be expressed with standard ONNX operators. Unsupported models raise NotImplementedError with a clear message.

Supported distance functions:

  • squared_euclidean_distance_matrix

  • euclidean_distance_matrix

  • manhattan_distance_matrix

  • omega_distance_matrix (global projection matrix)

  • lomega_distance_matrix (per-prototype local matrices)

  • tangent_distance_matrix (per-prototype tangent subspace)

  • relevance_weighted (per-feature relevance weighting)

Supported decision patterns:

  • WTAC (supervised classification)

  • ArgMin (unsupervised clustering)

  • One-class hard nearest (OCGLVQ family)

  • One-class Gaussian soft (OCRSLVQ family)

  • One-class Gaussian+NG soft (OCRSLVQ_NG family)

  • SVQ-OCC response model (SVQOCC family)

  • CBC reasoning (CBC, ImageCBC)

  • PLVQ Gaussian mixture soft assignment

Supported encoder models:

  • MLP encoder (SiameseGLVQ, SiameseGMLVQ, SiameseGTLVQ, LVQMLN, PLVQ)

  • CNN encoder (ImageGLVQ, ImageGMLVQ, ImageGTLVQ, ImageCBC)

Not supported:

  • gaussian_kernel_matrix, polynomial_kernel_matrix

  • Riemannian manifold distances (logm, expm have no ONNX equivalent)

prosemble.core.onnx_export.export_onnx(model, batch_size=1, opset_version=17, path=None)[source]

Export a fitted model’s predict function to ONNX format.

Builds an ONNX graph that reproduces the model’s predict() output. Supports supervised (WTAC), unsupervised (ArgMin), one-class (threshold-based), SVQ-OCC (response model), CBC (reasoning matrices), PLVQ (Gaussian mixture), encoder models (MLP/CNN backbone), and Riemannian models on SO(n)/Grassmannian manifolds (73 of 87 models total).

Parameters:
  • model (SupervisedPrototypeModel or UnsupervisedPrototypeModel) – A fitted prosemble model.

  • batch_size (int) – Fixed batch dimension for the input. Use -1 for dynamic batch size (ONNX symbolic dimension).

  • opset_version (int) – ONNX opset version. Default: 17.

  • path (str, optional) – If provided, save the ONNX model to this file path.

Returns:

The exported ONNX model.

Return type:

onnx.ModelProto

Raises:
  • NotImplementedError – If the model’s distance function or decision pattern is not supported.

  • ImportError – If the onnx package is not installed.

Manifolds

Riemannian manifold primitives for prototype learning.

Provides geodesic distance, exponential map, and logarithmic map for manifolds commonly arising in machine learning:

  • SO(n): Special orthogonal group (rotation matrices)

  • SPD(n): Symmetric positive definite matrices

  • Grassmannian Gr(n, k): k-dimensional subspaces of R^n

All operations are JIT-compilable and vmap-compatible.

References

Schwarz, Psenickova, Villmann, Röhrbein (2026). Topology-Preserving Prototype Learning on Riemannian Manifolds. ESANN 2026.

prosemble.core.manifolds.logm_safe(A)[source]

Matrix logarithm via funm, with numerical safety.

Uses JAX’s funm to compute logm. For complex eigenvalues (e.g. rotation matrices), operates in complex domain and returns the real part.

Parameters:

A (array of shape (..., n, n))

Returns:

logA

Return type:

array of shape (…, n, n)

prosemble.core.manifolds.sqrt_spd(A)[source]

Matrix square root for symmetric positive definite matrices.

Uses eigendecomposition: \(A^{1/2} = V \operatorname{diag}(\sqrt{\lambda}) V^T\).

Parameters:

A (array of shape (n, n), symmetric positive definite)

Returns:

A_sqrt

Return type:

array of shape (n, n)

prosemble.core.manifolds.inv_sqrt_spd(A)[source]

Inverse matrix square root for symmetric positive definite matrices.

Uses eigendecomposition: \(A^{-1/2} = V \operatorname{diag}(1/\sqrt{\lambda}) V^T\).

Parameters:

A (array of shape (n, n), symmetric positive definite)

Returns:

A_inv_sqrt

Return type:

array of shape (n, n)
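Both eigendecomposition formulas above fit in a few lines of NumPy. This is a sketch of the formulas, not prosemble's code, and the helper names are hypothetical. Scaling the columns of V by a vector is equivalent to multiplying by a diagonal matrix.

```python
import numpy as np


def sqrt_spd_sketch(A):
    """A^{1/2} = V diag(sqrt(lam)) V^T via the symmetric eigendecomposition."""
    lam, V = np.linalg.eigh(A)
    return (V * np.sqrt(lam)) @ V.T  # V * v scales column j of V by v[j]


def inv_sqrt_spd_sketch(A):
    """A^{-1/2} = V diag(1 / sqrt(lam)) V^T."""
    lam, V = np.linalg.eigh(A)
    return (V / np.sqrt(lam)) @ V.T


A = np.array([[4.0, 1.0], [1.0, 3.0]])  # symmetric with positive eigenvalues
S = sqrt_spd_sketch(A)
S_inv = inv_sqrt_spd_sketch(A)
print(np.allclose(S @ S, A))                      # True
print(np.allclose(S_inv @ A @ S_inv, np.eye(2)))  # True
```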

class prosemble.core.manifolds.SO(n)[source]

Special orthogonal group SO(n): rotation matrices.

Points are \(n \times n\) orthogonal matrices with det = +1. Geodesic distance uses the bi-invariant metric.

Parameters:

n (int) – Dimension of the rotation group.

distance(R, S)[source]

Geodesic distance: \(d(R, S) = \|\operatorname{logm}(R^T S)\|_F / \sqrt{2}\).

distance_squared(R, S)[source]

Squared geodesic distance.
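A worked case with n = 2 shows why the distance is divided by \(\sqrt{2}\). Take R = I and let S be a rotation by an angle θ with |θ| < π, and assume logm returns the principal logarithm:

```latex
\begin{aligned}
R^T S &= \begin{pmatrix}\cos\theta & -\sin\theta\\ \sin\theta & \cos\theta\end{pmatrix},
&\operatorname{logm}(R^T S) &= \begin{pmatrix}0 & -\theta\\ \theta & 0\end{pmatrix},\\
\|\operatorname{logm}(R^T S)\|_F &= \sqrt{\theta^2 + \theta^2} = \sqrt{2}\,|\theta|,
&d(R, S) &= \frac{\sqrt{2}\,|\theta|}{\sqrt{2}} = |\theta|.
\end{aligned}
```

With this normalization the geodesic distance equals the rotation angle, and its upper bound π matches the injectivity radius given below.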

log_map(R, S)[source]

Logarithmic map: \(\operatorname{Log}_R(S) = R \operatorname{logm}(R^T S)\).

Maps point S on the manifold to a tangent vector at R.

exp_map(R, V)[source]

Exponential map: \(\operatorname{Exp}_R(V) = R \operatorname{expm}(R^T V)\).

Maps tangent vector V at R back to the manifold.

random_point(key)[source]

Generate a random rotation matrix via QR decomposition.

Flips only the last column to ensure det = +1, preserving the Haar-uniform distribution on SO(n).

belongs(R)[source]

Check if R is in SO(n): \(R^T R \approx I\) and \(\det(R) \approx +1\).

project(R)[source]

Project to nearest rotation matrix via polar decomposition.

injectivity_radius(R)[source]

Injectivity radius of SO(n) is \(\pi\).

class prosemble.core.manifolds.SPD(n)[source]

Manifold of \(n \times n\) symmetric positive definite matrices.

Uses the affine-invariant Riemannian metric.

Parameters:

n (int) – Matrix dimension.

distance(A, B)[source]

Geodesic distance: \(d(A, B) = \|\operatorname{logm}(A^{-1/2} B A^{-1/2})\|_F\).

distance_squared(A, B)[source]

Squared geodesic distance.
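The middle matrix \(M = A^{-1/2} B A^{-1/2}\) is itself SPD. Its matrix logarithm therefore has eigenvalues log λ, and the Frobenius norm reduces to \(\sqrt{\sum_i \log^2 \lambda_i}\). The NumPy sketch below uses that shortcut instead of a general logm; it is an illustration under these assumptions, and the helper name is hypothetical.

```python
import numpy as np


def spd_distance_sketch(A, B):
    """Affine-invariant distance ||logm(A^{-1/2} B A^{-1/2})||_F via eigenvalues."""
    lam_a, V = np.linalg.eigh(A)
    A_inv_sqrt = (V / np.sqrt(lam_a)) @ V.T
    M = A_inv_sqrt @ B @ A_inv_sqrt
    lam_m = np.linalg.eigvalsh((M + M.T) / 2)  # symmetrize against round-off
    return np.sqrt(np.sum(np.log(lam_m) ** 2))


A = np.eye(2)
B = np.diag([np.e, 1.0])
print(spd_distance_sketch(A, B))  # approx 1.0, since logm(B) = diag(1, 0)
```

Affine invariance means that d(G A Gᵀ, G B Gᵀ) = d(A, B) for any invertible G, which is a convenient numerical check for an implementation.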

log_map(A, B)[source]

Log map: \(\operatorname{Log}_A(B) = A^{1/2} \operatorname{logm}(A^{-1/2} B A^{-1/2}) A^{1/2}\).

exp_map(A, V)[source]

Exp map: \(\operatorname{Exp}_A(V) = A^{1/2} \operatorname{expm}(A^{-1/2} V A^{-1/2}) A^{1/2}\).

random_point(key)[source]

Generate random SPD matrix: \(A = L L^T + \epsilon I\).

belongs(A)[source]

Check if A is SPD: symmetric and all eigenvalues > 0.

project(A)[source]

Project to nearest SPD: symmetrize and clamp eigenvalues.

injectivity_radius(A)[source]

SPD manifold has infinite injectivity radius.

class prosemble.core.manifolds.Grassmannian(n, k)[source]

Grassmannian manifold Gr(n, k): k-dimensional subspaces of R^n.

Points are represented as orthonormal bases Q of shape (n, k) with Q^T Q = I_k.

Parameters:
  • n (int) – Ambient dimension.

  • k (int) – Subspace dimension.

distance(Q1, Q2)[source]

Geodesic distance via principal angles.

\(d(Q_1, Q_2) = \|\theta\|_2\), where the principal angles are \(\theta_i = \arccos \sigma_i\) and \(\sigma_i\) are the singular values of \(Q_1^T Q_2\).

distance_squared(Q1, Q2)[source]

Squared geodesic distance.
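The principal-angle formula can be checked on two lines (k = 1) in \(\mathbb{R}^2\). The NumPy sketch below is illustrative, and the helper name is hypothetical. Clipping guards arccos against singular values that round slightly above 1.

```python
import numpy as np


def grassmann_distance_sketch(Q1, Q2):
    """||theta||_2 with theta = arccos(singular values of Q1^T Q2)."""
    sigma = np.linalg.svd(Q1.T @ Q2, compute_uv=False)
    theta = np.arccos(np.clip(sigma, 0.0, 1.0))
    return np.linalg.norm(theta)


t = np.pi / 6  # the lines are 30 degrees apart
Q1 = np.array([[1.0], [0.0]])
Q2 = np.array([[np.cos(t)], [np.sin(t)]])
print(grassmann_distance_sketch(Q1, Q2))   # approx 0.5236 (pi / 6)
print(grassmann_distance_sketch(Q1, -Q2))  # same value: the subspace ignores the basis sign
```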

log_map(Q1, Q2)[source]

Logarithmic map on the Grassmannian.

Computes the tangent vector at Q1 pointing toward Q2 using aligned principal angle decomposition (Edelman et al. 1998).

The key insight is to align both subspaces via the SVD of Q1^T Q2, ensuring the principal angles and directions correspond.

exp_map(Q, V)[source]

Exponential map on the Grassmannian.

Maps tangent vector V at Q back to the manifold.

random_point(key)[source]

Generate a random point on Gr(n, k) via QR decomposition.

belongs(Q)[source]

Check if Q represents a point on Gr(n, k): \(Q^T Q \approx I_k\).

project(Q)[source]

Project to nearest orthonormal basis via QR.

injectivity_radius(Q)[source]

Injectivity radius of Gr(n,k) is \(\pi/2\).

Protocols

Structural typing protocols for prosemble interfaces.

Defines typing.Protocol contracts for duck-typed interfaces used across the library. These enable static type checking (mypy / pyright) and IDE auto-completion without requiring inheritance.

Note

Runtime-checkable protocols (Manifold, CallbackLike) support isinstance() checks. Type aliases (DistanceMatrixFn, etc.) are for annotation only.

class prosemble.core.protocols.Manifold(*args, **kwargs)[source]

Protocol for Riemannian manifold implementations.

Any object exposing the methods below can be used wherever a manifold is expected (e.g. RiemannianNeuralGas, RiemannianSRNG).

The concrete implementations SO, SPD, and Grassmannian all satisfy this protocol structurally — no explicit subclassing is required.

property point_shape: Tuple[int, ...]

Shape of a single point on the manifold.

distance(p, q)[source]

Geodesic distance between two points.

Parameters:
  • p (array of shape point_shape)

  • q (array of shape point_shape)

Return type:

scalar

distance_squared(p, q)[source]

Squared geodesic distance between two points.

Parameters:
  • p (array of shape point_shape)

  • q (array of shape point_shape)

Return type:

scalar

log_map(base, target)[source]

Logarithmic map: tangent vector at base pointing toward target.

Parameters:
  • base (array of shape point_shape) – Base point on the manifold.

  • target (array of shape point_shape) – Target point on the manifold.

Returns:

tangent – Tangent vector in \(T_{\text{base}} M\).

Return type:

array of shape point_shape

exp_map(base, tangent)[source]

Exponential map: move along tangent from base back to the manifold.

Parameters:
  • base (array of shape point_shape)

  • tangent (array of shape point_shape)

Returns:

point

Return type:

array of shape point_shape

random_point(key)[source]

Sample a random point on the manifold.

Parameters:

key (JAX PRNG key)

Returns:

point

Return type:

array of shape point_shape

belongs(point)[source]

Check whether point lies on the manifold.

Parameters:

point (array of shape point_shape)

Return type:

bool or bool-valued array

project(point)[source]

Project an off-manifold point to the nearest point on the manifold.

Parameters:

point (array of shape point_shape)

Returns:

projected

Return type:

array of shape point_shape

injectivity_radius(point)[source]

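Injectivity radius at point.

The largest radius r such that the exponential map at point is injective on the ball of radius r in the tangent space. Within this geodesic distance, the logarithmic map is uniquely defined.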

Parameters:

point (array of shape point_shape)

Returns:

radius

Return type:

float or scalar array
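Here is a sketch of what "satisfies the protocol structurally" means in practice. The class below implements every documented member for flat Euclidean space and never subclasses anything. To keep the example runnable without prosemble installed, it checks against `ManifoldLike`, a local stand-in written to mirror the documented members, rather than importing `prosemble.core.protocols.Manifold`. Both `ManifoldLike` and `Euclidean` are hypothetical names.

```python
from typing import Protocol, Tuple, runtime_checkable

import jax
import jax.numpy as jnp


@runtime_checkable
class ManifoldLike(Protocol):
    """Local stand-in mirroring the documented Manifold members."""

    @property
    def point_shape(self) -> Tuple[int, ...]: ...
    def distance(self, p, q): ...
    def distance_squared(self, p, q): ...
    def log_map(self, base, target): ...
    def exp_map(self, base, tangent): ...
    def random_point(self, key): ...
    def belongs(self, point): ...
    def project(self, point): ...
    def injectivity_radius(self, point): ...


class Euclidean:
    """Flat R^d: geodesics are straight lines. No base class is needed."""

    def __init__(self, dim: int):
        self.dim = dim

    @property
    def point_shape(self) -> Tuple[int, ...]:
        return (self.dim,)

    def distance(self, p, q):
        return jnp.linalg.norm(q - p)

    def distance_squared(self, p, q):
        return jnp.sum((q - p) ** 2)

    def log_map(self, base, target):
        # Tangent vector at base pointing toward target
        return target - base

    def exp_map(self, base, tangent):
        return base + tangent

    def random_point(self, key):
        return jax.random.normal(key, self.point_shape)

    def belongs(self, point):
        return point.shape == self.point_shape

    def project(self, point):
        return point  # every point of R^d already lies on the manifold

    def injectivity_radius(self, point):
        return jnp.inf  # exp_map is injective on the whole tangent space


M = Euclidean(3)
p = M.random_point(jax.random.PRNGKey(0))
q = M.random_point(jax.random.PRNGKey(1))
print(isinstance(M, ManifoldLike))  # True
```

Note that a `runtime_checkable` `isinstance` check only verifies that the member names exist. It does not check signatures or behaviour, so the round trip `exp_map(p, log_map(p, q)) == q` is a more meaningful sanity test.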

class prosemble.core.protocols.CallbackLike(*args, **kwargs)[source]

Protocol for training callbacks.

Any object with the three hook methods below can be passed in the callbacks list of a model’s constructor. The existing Callback base class already satisfies this protocol.

on_fit_start(model, X)[source]

Called once before training begins.

Parameters:
  • model

  • X

Return type:

None

on_iteration_end(model, info)[source]

Called after each training iteration / epoch.

Parameters:
  • model

  • info

Return type:

None

on_fit_end(model, info)[source]

Called once after training ends.

Parameters:
  • model

  • info

Return type:

None
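The next sketch shows a callback that satisfies the three-hook contract without inheriting from `Callback`. The protocol is mirrored locally as `CallbackLikeStandIn`, a hypothetical name, so the example runs on its own. The contents of `info` are not documented here, so the sketch treats it as opaque. The driver loop at the bottom is a stand-in for a model's `fit()`, not prosemble code.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class CallbackLikeStandIn(Protocol):
    # Local mirror of the three documented hooks
    def on_fit_start(self, model: Any, X: Any) -> None: ...
    def on_iteration_end(self, model: Any, info: Any) -> None: ...
    def on_fit_end(self, model: Any, info: Any) -> None: ...


class IterationCounter:
    """Counts iterations. It satisfies the protocol without subclassing it."""

    def __init__(self):
        self.n_iterations = 0
        self.started = False
        self.finished = False

    def on_fit_start(self, model, X):
        self.started = True
        self.n_iterations = 0

    def on_iteration_end(self, model, info):
        # info is treated as opaque; only the call itself is counted
        self.n_iterations += 1

    def on_fit_end(self, model, info):
        self.finished = True


cb = IterationCounter()
# Hypothetical driver standing in for a model's training loop
cb.on_fit_start(model=None, X=None)
for i in range(5):
    cb.on_iteration_end(model=None, info={"iteration": i})
cb.on_fit_end(model=None, info={})
print(isinstance(cb, CallbackLikeStandIn), cb.n_iterations)  # True 5
```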

prosemble.core.protocols.DistanceMatrixFn

(X, Y) -> distances. X has shape (n_samples, n_features), Y has shape (n_prototypes, n_features), result has shape (n_samples, n_prototypes).

Type:

Distance-matrix function

alias of Callable[[Array, Array], Array]

prosemble.core.protocols.DistancePairwiseFn

(x, y) -> scalar.

Type:

Pairwise distance function

alias of Callable[[Array, Array], Array]

prosemble.core.protocols.SupervisedInitFn

Supervised prototype initializer: (X, y, n_per_class, key) -> (prototypes, prototype_labels).

alias of Callable[[Array, Array, int, Array], Tuple[Array, Array]]

prosemble.core.protocols.UnsupervisedInitFn

Unsupervised prototype initializer: (X, n_prototypes, key) -> prototypes.

alias of Callable[[Array, int, Array], Array]
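Because these aliases are plain `Callable` types, any function with a matching signature qualifies. The sketch below re-declares three of them locally (an assumption, so it runs without prosemble). It uses `jax.vmap` to lift a pairwise function into a distance-matrix function, then defines a simple unsupervised initializer. The vmap lift is a straightforward alternative to the expansion trick used by `euclidean_distance_matrix`, not the library's implementation. `sq_euclidean`, `lift_pairwise`, and `random_sample_init` are hypothetical names.

```python
from typing import Callable

import jax
import jax.numpy as jnp
from jax import Array

# Local aliases mirroring the documented signatures
DistancePairwiseFn = Callable[[Array, Array], Array]
DistanceMatrixFn = Callable[[Array, Array], Array]
UnsupervisedInitFn = Callable[[Array, int, Array], Array]


def sq_euclidean(x: Array, y: Array) -> Array:
    # (n_features,), (n_features,) -> scalar
    return jnp.sum((x - y) ** 2)


def lift_pairwise(fn: DistancePairwiseFn) -> DistanceMatrixFn:
    # Inner vmap runs over prototypes, outer vmap over samples -> (n_samples, n_prototypes)
    return jax.vmap(jax.vmap(fn, in_axes=(None, 0)), in_axes=(0, None))


def random_sample_init(X: Array, n_prototypes: int, key: Array) -> Array:
    # Use n_prototypes distinct rows of X as the initial prototypes
    idx = jax.random.choice(key, X.shape[0], shape=(n_prototypes,), replace=False)
    return X[idx]


X = jnp.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
dist_fn: DistanceMatrixFn = lift_pairwise(sq_euclidean)
D = dist_fn(X, X[:2])  # shape (3, 2)
init_fn: UnsupervisedInitFn = random_sample_init
W = init_fn(X, 2, jax.random.PRNGKey(0))  # shape (2, 2)
```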

Distributed Training

Distributed training utilities for multi-device data parallelism.

Provides functions for sharding data across devices and replicating model parameters, enabling data-parallel training on multi-GPU/TPU setups.

prosemble.core.distributed.create_mesh(devices=None)[source]

Create a 1D device mesh for data parallelism.

Parameters:

devices (list of jax.Device or None) – Devices to use. If None, returns None (single-device mode).

Return type:

Mesh or None
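In JAX, a 1D data-parallel mesh is a `jax.sharding.Mesh` whose devices lie along a single named axis. The sketch below mirrors the documented `None` behaviour. The axis name `"data"` is an assumption, since the entry does not say which name prosemble uses, and `make_data_mesh` is a hypothetical helper.

```python
import jax
import numpy as np
from jax.sharding import Mesh


def make_data_mesh(devices=None, axis_name="data"):
    """Hypothetical sketch of a 1D data-parallel mesh builder."""
    if devices is None:
        return None  # single-device mode, as documented for create_mesh
    # Lay the devices out along one named axis
    return Mesh(np.array(devices), axis_names=(axis_name,))


mesh = make_data_mesh(jax.devices())
print(mesh.axis_names)  # ('data',)
```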

prosemble.core.distributed.shard_data(X, y, mesh)[source]

Shard data arrays along the batch dimension across devices.

Parameters:
  • X (jnp.ndarray of shape (n_samples, ...)) – Input data.

  • y (jnp.ndarray of shape (n_samples,)) – Labels.

  • mesh (Mesh) – Device mesh from create_mesh().

Returns:

X_sharded, y_sharded

Return type:

tuple of sharded arrays
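Sharding along the batch dimension can be expressed with a `NamedSharding` whose `PartitionSpec` splits only the leading axis. The sketch below assumes a mesh axis named `"data"`. It also requires `n_samples` to be divisible by the device count, because the entry does not say how `shard_data` handles a remainder. `shard_batch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))  # axis name is an assumption


def shard_batch(X, y, mesh, axis_name="data"):
    # Split axis 0 (samples) across the mesh; trailing axes are not partitioned
    batch_sharding = NamedSharding(mesh, P(axis_name))
    return jax.device_put(X, batch_sharding), jax.device_put(y, batch_sharding)


n_dev = len(jax.devices())
X = jnp.arange(8 * n_dev * 3, dtype=jnp.float32).reshape(8 * n_dev, 3)
y = jnp.arange(8 * n_dev) % 2
X_s, y_s = shard_batch(X, y, mesh)  # each device holds 8 rows
```

The values are unchanged; only their placement differs, so downstream `jit`-compiled code sees the same global array.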

prosemble.core.distributed.replicate_params(params, mesh)[source]

Replicate params across all devices (no partitioning).

Parameters:
  • params (dict (pytree)) – Model parameters.

  • mesh (Mesh) – Device mesh from create_mesh().

Return type:

params replicated across mesh

prosemble.core.distributed.replicate_opt_state(opt_state, mesh)[source]

Replicate optimizer state across all devices.

Parameters:
  • opt_state (pytree) – Optax optimizer state.

  • mesh (Mesh) – Device mesh from create_mesh().

Return type:

opt_state replicated across mesh
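Replicating parameters and optimizer state uses the same mechanism. An empty `PartitionSpec()` means no axis is split, so every device holds a full copy of each leaf. The sketch below applies this leaf-by-leaf with `jax.tree_util.tree_map`. The parameter names are hypothetical, and a plain dict stands in for an Optax state so the example runs without Optax.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))


def replicate_tree(tree, mesh):
    # Empty PartitionSpec: no axis is partitioned, so each device gets a full copy
    replicated = NamedSharding(mesh, P())
    return jax.tree_util.tree_map(lambda leaf: jax.device_put(leaf, replicated), tree)


params = {"prototypes": jnp.zeros((4, 3)), "omega": jnp.eye(3)}  # hypothetical names
# Plain pytree standing in for an optimizer state (buffers shaped like params)
opt_state = {"mu": jax.tree_util.tree_map(jnp.zeros_like, params)}

params_r = replicate_tree(params, mesh)
opt_state_r = replicate_tree(opt_state, mesh)
print(params_r["prototypes"].sharding.is_fully_replicated)  # True
```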

prosemble.core.distributed.unshard_params(params)[source]

Bring params back to a single device.

Used after training to store results as plain arrays for predict/export operations.

Parameters:

params (pytree) – Potentially sharded parameters.

Return type:

params on default device
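The reverse step can be sketched by moving every leaf onto one device, so that later predict or export code sees ordinary single-device arrays. The "trained" parameters below are simulated by replicating a small dict. `gather_to_default_device` is a hypothetical helper, and using `jax.devices()[0]` as the target is an assumption about what "default device" means here.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
params = {"prototypes": jnp.arange(6.0).reshape(2, 3)}
# Simulate the output of a data-parallel run: leaves replicated on every device
sharded = jax.tree_util.tree_map(lambda a: jax.device_put(a, NamedSharding(mesh, P())), params)


def gather_to_default_device(tree):
    # Place each leaf on the first device only
    target = jax.devices()[0]
    return jax.tree_util.tree_map(lambda a: jax.device_put(a, target), tree)


single = gather_to_default_device(sharded)
print(single["prototypes"].devices())  # a set containing a single device
```

If host NumPy arrays are preferred (for example, before pickling), `jax.device_get(tree)` is an alternative.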

Data Utilities

Data loading and batching utilities for prosemble.

Provides composable primitives for mini-batch iteration, suitable for custom training loops. The built-in fit() methods already handle batching internally — these utilities are for advanced users who need explicit control over data iteration.

prosemble.core.data.shuffle_arrays(key, *arrays)[source]

Shuffle multiple arrays with the same random permutation.

Parameters:
  • key (JAX PRNG key) – Random key for generating the permutation.

  • *arrays (jnp.ndarray) – Arrays to shuffle. All must have the same length along axis 0.

Returns:

Shuffled arrays in the same order as the inputs.

Return type:

tuple of jnp.ndarray

Examples

>>> key = jax.random.PRNGKey(0)
>>> X = jnp.arange(12).reshape(4, 3)
>>> y = jnp.array([0, 1, 0, 1])
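>>> X_s, y_s = shuffle_arrays(key, X, y)
>>> X_s.shape
(4, 3)
>>> bool(jnp.all(y_s == (X_s[:, 0] // 3) % 2))  # rows and labels stay paired
True

prosemble.core.data.padded_batches(X, y=None, batch_size=32, key=None)[source]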

Split data into static-shaped batches, padding the last batch.

Returns arrays of shape (n_batches, batch_size, ...) suitable for jax.lax.scan. If the data length is not divisible by batch_size, the last batch is padded by repeating initial samples.
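The point of static shapes is that `jax.lax.scan` can loop over the leading `n_batches` axis, which requires every batch to have the same shape. The sketch below re-implements the described padding rule locally as `pad_and_batch` (a hypothetical helper, filling the tail by cycling from the first sample) and scans over the result.

```python
import jax
import jax.numpy as jnp


def pad_and_batch(X, batch_size):
    """Hypothetical re-implementation of the documented padding rule."""
    n = X.shape[0]
    n_batches = -(-n // batch_size)  # ceiling division
    # Indices 0..n-1, then 0, 1, ... again: the tail repeats the first samples
    idx = jnp.arange(n_batches * batch_size) % n
    return X[idx].reshape((n_batches, batch_size) + X.shape[1:])


X = jnp.arange(10.0).reshape(10, 1)
X_b = pad_and_batch(X, batch_size=4)  # shape (3, 4, 1); last batch is [8, 9, 0, 1]


def step(total, batch):
    # Carry a running sum and emit each batch's mean
    return total + batch.sum(), batch.mean()


total, batch_means = jax.lax.scan(step, jnp.zeros((), dtype=X.dtype), X_b)
print(total, batch_means)  # 46.0 [1.5 5.5 4.5]
```

The sum over all rows is 46, not 45, because samples 0 and 1 are counted twice. Any reduction over padded batches gives the repeated samples extra weight unless it is masked.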

Parameters:
  • X (jnp.ndarray of shape (n_samples, ...)) – Feature array.

  • y (jnp.ndarray of shape (n_samples,), optional) – Label array.

  • batch_size (int) – Number of samples per batch.

  • key (JAX PRNG key, optional) – If provided, data is shuffled before batching.

Returns:

  • X_batches (jnp.ndarray of shape (n_batches, batch_size, …))

  • y_batches (jnp.ndarray of shape (n_batches, batch_size) or None)

Return type:

tuple[Array, Array | None]

Examples

>>> X = jnp.ones((10, 3))
>>> y = jnp.arange(10)
>>> X_b, y_b = padded_batches(X, y, batch_size=4)
>>> X_b.shape  # (3, 4, 3) — 10 samples padded to 12
(3, 4, 3)

prosemble.core.data.batched_iterator(X, y=None, batch_size=32, shuffle=True, key=None, drop_last=False)[source]

Yield mini-batches from data arrays.

For use in custom Python training loops.

Parameters:
  • X (jnp.ndarray of shape (n_samples, ...)) – Feature array.

  • y (jnp.ndarray of shape (n_samples,), optional) – Label array.

  • batch_size (int) – Number of samples per batch.

  • shuffle (bool) – If True and key is provided, shuffle before iterating.

  • key (JAX PRNG key, optional) – Required if shuffle=True.

  • drop_last (bool) – If True, drop the last batch when it is smaller than batch_size. If False (default), the last batch may be smaller.

Yields:
  • X_batch (jnp.ndarray of shape (batch_size, …) or smaller)

  • y_batch (jnp.ndarray of shape (batch_size,) or smaller, or None)

Examples

>>> key = jax.random.PRNGKey(0)
>>> X = jnp.ones((10, 3))
>>> for X_b, y_b in batched_iterator(X, batch_size=4, key=key):
...     print(X_b.shape)
(4, 3)
(4, 3)
(2, 3)
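The next sketch shows a custom training loop of the kind this iterator targets. It splits a fresh key each epoch to get a new shuffle order and applies a `jax.jit`-compiled update per mini-batch. To stay runnable without prosemble, `minibatches` is a hypothetical stand-in for `batched_iterator`. With the library you would instead iterate over `batched_iterator(X, batch_size=16, key=subkey)` as documented above. The single-prototype objective is illustrative only.

```python
import jax
import jax.numpy as jnp


def minibatches(X, batch_size, key, drop_last=False):
    """Hypothetical stand-in for batched_iterator: shuffle once, then slice."""
    X = X[jax.random.permutation(key, X.shape[0])]
    n = X.shape[0]
    stop = n - (n % batch_size) if drop_last else n
    for start in range(0, stop, batch_size):
        yield X[start:start + batch_size]


@jax.jit
def update(w, X_batch, lr):
    # One gradient step on the mean squared distance between the batch and prototype w
    def loss_fn(w):
        return jnp.mean(jnp.sum((X_batch - w) ** 2, axis=1))
    return w - lr * jax.grad(loss_fn)(w)


data_key, key = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(data_key, (100, 2)) + jnp.array([3.0, -1.0])
w = jnp.zeros(2)
for epoch in range(20):
    key, subkey = jax.random.split(key)  # new subkey each epoch, so a new shuffle order
    for X_batch in minibatches(X, batch_size=16, key=subkey):
        w = update(w, X_batch, 0.1)
# w ends up near the data mean, which minimises the mean squared distance
```

With 100 samples and `batch_size=16`, the final batch has 4 rows. Because `jit` compiles once per input shape, that batch triggers a second compilation. Passing `drop_last=True`, or using `padded_batches` with `jax.lax.scan`, keeps a single shape.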