Models API Reference
All models are importable from prosemble.models.
Supervised LVQ
- class prosemble.models.GLVQ(beta=10.0, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)
Generalized Learning Vector Quantization.
Loss:
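\[\mu = \frac{d^+ - d^-}{d^+ + d^-}\]
where \(d^+\) is the distance to the nearest prototype of the correct class and \(d^-\) is the distance to the nearest prototype of any other class. An optional transfer function is applied to \(\mu\).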
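The relative distance difference above can be sketched in plain NumPy. This is an illustrative sketch, not the library's implementation: the helper name `glvq_mu` is hypothetical, squared Euclidean distance is assumed (the documented default), and no transfer function or margin is applied.

```python
import numpy as np

def glvq_mu(X, prototypes, proto_labels, y):
    """Return mu = (d+ - d-) / (d+ + d-) for every sample (hypothetical helper)."""
    # Squared Euclidean distances, shape (n_samples, n_prototypes).
    d = ((X[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)
    same = proto_labels[None, :] == y[:, None]
    # d+: nearest prototype with the correct label; d-: nearest with a wrong label.
    d_plus = np.where(same, d, np.inf).min(axis=1)
    d_minus = np.where(same, np.inf, d).min(axis=1)
    return (d_plus - d_minus) / (d_plus + d_minus)

X = np.array([[0.0, 0.0], [1.0, 1.0]])
prototypes = np.array([[0.0, 0.0], [2.0, 2.0]])
proto_labels = np.array([0, 1])
y = np.array([0, 1])
mu = glvq_mu(X, prototypes, proto_labels, y)
```

Values of \(\mu\) lie in [-1, 1]; a negative value means the sample is closer to a prototype of its own class than to any other prototype.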
- Parameters:
beta (float) – Parameter \(\beta\) for transfer function (e.g., sigmoid steepness).
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6: sync every k steps) and ‘slow_step_size’ (float, default 0.5: interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
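The prototypes_initializer parameter accepts a callable with the signature (X, y, n_per_class, key) -> (protos, labels). A minimal sketch of such a callable follows. The function name is hypothetical, the `key` argument is assumed to be a random key that this deterministic example ignores, and whether the library expects NumPy or JAX arrays as the return value is not stated in this reference.

```python
import numpy as np

def class_mean_initializer(X, y, n_per_class, key=None):
    """Hypothetical initializer: place every prototype at its class mean."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y)
    classes = np.unique(y)
    # One mean vector per class, shape (n_classes, n_features).
    means = np.stack([X[y == c].mean(axis=0) for c in classes])
    # Repeat each class mean n_per_class times, keeping labels aligned.
    protos = np.repeat(means, n_per_class, axis=0)
    labels = np.repeat(classes, n_per_class)
    return protos, labels

X_demo = np.array([[0.0, 0.0], [2.0, 2.0], [4.0, 4.0], [6.0, 6.0]])
y_demo = np.array([0, 0, 1, 1])
protos, labels = class_mean_initializer(X_demo, y_demo, n_per_class=2)
```

Identical prototypes within a class would normally be perturbed with noise drawn from `key`; that step is left out here to keep the sketch deterministic.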
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial prototypes differs from what n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- classmethod load(path)
Load a fitted model from an NPZ file.
- Parameters:
path (str) – Path to the .npz file.
- Returns:
Reconstructed fitted model.
- Return type:
model
- predict(X)
Predict class labels via winner-takes-all competition.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
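Winner-takes-all prediction assigns each sample the label of its nearest prototype. The following NumPy sketch illustrates the idea under the assumption of squared Euclidean distance; `predict_wta` is a hypothetical helper, not a library function.

```python
import numpy as np

def predict_wta(X, prototypes, proto_labels):
    """Label each sample with the class of its closest prototype (hypothetical helper)."""
    # Squared Euclidean distances, shape (n_samples, n_prototypes).
    d = ((X[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)
    # The winning prototype is the one with the smallest distance.
    return proto_labels[d.argmin(axis=1)]

prototypes = np.array([[0.0, 0.0], [2.0, 2.0]])
proto_labels = np.array([0, 1])
X_new = np.array([[0.1, 0.0], [1.9, 2.0]])
pred = predict_wta(X_new, prototypes, proto_labels)
```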
- predict_proba(X)
Predict class probabilities via softmin of distances.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
proba
- Return type:
array of shape (n_samples, n_classes)
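A softmin over distances can be sketched as a softmax over negated distances. How the library aggregates several prototypes of the same class is not specified in this reference; the sketch below assumes the minimum distance per class, and `softmin_proba` is a hypothetical name.

```python
import numpy as np

def softmin_proba(X, prototypes, proto_labels):
    """Class probabilities from a softmin over per-class distances (assumed aggregation)."""
    d = ((X[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)
    classes = np.unique(proto_labels)
    # Assumption: each class is represented by its closest prototype.
    class_d = np.stack([d[:, proto_labels == c].min(axis=1) for c in classes], axis=1)
    # Softmin(d) equals softmax(-d); subtract the row maximum for numerical stability.
    logits = -class_d
    logits = logits - logits.max(axis=1, keepdims=True)
    p = np.exp(logits)
    return p / p.sum(axis=1, keepdims=True)

prototypes = np.array([[0.0, 0.0], [2.0, 2.0]])
proto_labels = np.array([0, 1])
proba = softmin_proba(np.array([[0.0, 0.0]]), prototypes, proto_labels)
```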
- class prosemble.models.GRLVQ(beta=10.0, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)
Generalized Relevance Learning Vector Quantization.
Learns per-feature relevance weights \(\lambda_j\) such that the weighted distance is:
\[d(x, w) = \sum_j \lambda_j (x_j - w_j)^2\]
- Parameters:
beta (float) – Transfer function steepness parameter.
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6: sync every k steps) and ‘slow_step_size’ (float, default 0.5: interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
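The relevance-weighted distance \(d(x, w) = \sum_j \lambda_j (x_j - w_j)^2\) used by GRLVQ can be illustrated with a few lines of NumPy. The helper name is hypothetical, and the example relevances are chosen by hand rather than learned.

```python
import numpy as np

def relevance_distance(x, w, lam):
    """Weighted squared distance: sum_j lam_j * (x_j - w_j)^2 (hypothetical helper)."""
    return float(np.sum(lam * (x - w) ** 2))

x = np.array([1.0, 2.0, 3.0])
w = np.array([0.0, 0.0, 0.0])
# A relevance of 0 removes the third feature from the distance entirely.
lam = np.array([0.5, 0.5, 0.0])
d_rel = relevance_distance(x, w, lam)
```

With these relevances the third feature has no influence, so large deviations in it no longer affect which prototype wins.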
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial prototypes differs from what n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict(X)
Predict class labels via winner-takes-all competition.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- predict_proba(X)
Predict class probabilities via softmin of distances.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
proba
- Return type:
array of shape (n_samples, n_classes)
- class prosemble.models.GMLVQ(latent_dim=None, beta=10.0, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)
Generalized Matrix Learning Vector Quantization.
Learns a global linear mapping \(\Omega\) (d x latent_dim) such that distances are computed in the transformed space:
\[d(x, w) = (x - w)^T \Omega^T \Omega (x - w)\]
The relevance matrix \(\Lambda = \Omega^T \Omega\) captures feature correlations.
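The matrix distance can be checked numerically with a short NumPy sketch. For the product \(\Omega (x - w)\) to be defined as written in the formula, the sketch stores \(\Omega\) with shape (latent_dim, d); the library may store the transpose, so treat the layout here as an assumption. The function name is hypothetical.

```python
import numpy as np

def gmlvq_distance(x, w, omega):
    """Squared distance in the space mapped by omega, shape (latent_dim, d) assumed."""
    projected = omega @ (x - w)
    return float(projected @ projected)

x = np.array([1.0, 2.0, 3.0])
w = np.zeros(3)
# Maps 3 input features to a 2-dimensional latent space.
omega = np.array([[1.0, 0.0, 0.0],
                  [0.0, 1.0, 1.0]])
d_latent = gmlvq_distance(x, w, omega)

# Equivalent quadratic form with the relevance matrix Lambda = Omega^T Omega.
relevance = omega.T @ omega
d_quadratic = float((x - w) @ relevance @ (x - w))
```

Both expressions give the same value; the off-diagonal entries of `relevance` are what allow pairs of features to contribute jointly.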
- Parameters:
latent_dim (int, optional) – Dimensionality of the latent space. If None, uses input dim.
beta (float) – Transfer function steepness.
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6: sync every k steps) and ‘slow_step_size’ (float, default 0.5: interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
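The ‘balanced’ option of class_weight is described as weighting classes inversely to their frequencies. The exact normalization is not given in this reference; the sketch below assumes the common heuristic n_samples / (n_classes * class_count), and the helper name is hypothetical.

```python
import numpy as np

def balanced_class_weights(y):
    """Weights inversely proportional to class frequency (assumed normalization)."""
    y = np.asarray(y)
    classes, counts = np.unique(y, return_counts=True)
    weights = len(y) / (len(classes) * counts)
    return dict(zip(classes.tolist(), weights.tolist()))

# Class 0 appears three times, class 1 once, so class 1 gets the larger weight.
weights = balanced_class_weights([0, 0, 0, 1])
```

The resulting dict has the same form as a manually specified class_weight mapping from class label to weight.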
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial prototypes differs from what n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict_proba(X)
Predict class probabilities via softmin of distances.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
proba
- Return type:
array of shape (n_samples, n_classes)
- class prosemble.models.LGMLVQ(latent_dim=None, beta=10.0, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)
Localized Generalized Matrix Learning Vector Quantization.
Each prototype \(k\) has its own \(\Omega_k\) matrix. The distance from sample \(x\) to prototype \(w_k\) is:
\[d(x, w_k) = (x - w_k)^T \Omega_k^T \Omega_k (x - w_k)\]
- Parameters:
latent_dim (int, optional) – Latent space dimensionality per prototype. If None, uses input dim.
beta (float) – Transfer function steepness.
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6: sync every k steps) and ‘slow_step_size’ (float, default 0.5: interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
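The localized distance of LGMLVQ, where every prototype \(w_k\) carries its own \(\Omega_k\), can be sketched with a batched matrix product. The array layout (n_prototypes, latent_dim, d) for the stacked matrices and the helper name are assumptions made for illustration.

```python
import numpy as np

def lgmlvq_distances(x, prototypes, omegas):
    """Distance from one sample to each prototype under that prototype's own omega."""
    # Differences to every prototype, shape (n_prototypes, d).
    diffs = x[None, :] - prototypes
    # Apply omega_k to the k-th difference vector, shape (n_prototypes, latent_dim).
    projected = np.einsum("kld,kd->kl", omegas, diffs)
    return (projected ** 2).sum(axis=1)

prototypes = np.array([[0.0, 0.0], [1.0, 1.0]])
# Prototype 0 uses the identity; prototype 1 stretches its local metric by 2.
omegas = np.stack([np.eye(2), 2.0 * np.eye(2)])
local_d = lgmlvq_distances(np.array([1.0, 0.0]), prototypes, omegas)
```

Both prototypes are at Euclidean distance 1 from the sample, yet their local metrics rank them differently.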
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial prototypes differs from what n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict_proba(X)
Predict class probabilities via softmin of distances.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
proba
- Return type:
array of shape (n_samples, n_classes)
- class prosemble.models.GTLVQ(subspace_dim=2, beta=10.0, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)
Generalized Tangent Learning Vector Quantization.
Each prototype \(k\) has a subspace basis \(\Omega_k\). The tangent distance is:
\[d(x, w_k) = \|P_k(x - w_k)\|^2\]
where \(P_k = I - \Omega_k \Omega_k^T\) is the orthogonal projector.
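The tangent distance measures only the part of \(x - w_k\) that lies outside the prototype's subspace. A NumPy sketch follows, assuming \(\Omega_k\) has shape (d, subspace_dim) with orthonormal columns, which is what makes \(I - \Omega_k \Omega_k^T\) a projector; the helper name is hypothetical.

```python
import numpy as np

def tangent_distance(x, w, omega):
    """Squared norm of (I - omega omega^T)(x - w); omega needs orthonormal columns."""
    diff = x - w
    # Remove the component of diff that lies inside the tangent subspace.
    residual = diff - omega @ (omega.T @ diff)
    return float(residual @ residual)

# One-dimensional tangent subspace along the first coordinate axis.
omega = np.array([[1.0], [0.0], [0.0]])
x = np.array([3.0, 4.0, 0.0])
w = np.zeros(3)
d_tangent = tangent_distance(x, w, omega)
```

Moving the sample along the first axis leaves the distance unchanged, which is the invariance the tangent subspace is meant to encode.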
- Parameters:
subspace_dim (int) – Dimension of each prototype’s tangent subspace.
beta (float) – Transfer function steepness.
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6: sync every k steps) and ‘slow_step_size’ (float, default 0.5: interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
- predict(X)
Predict class labels via winner-takes-all competition.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial prototypes differs from what n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict_proba(X)
Predict class probabilities via softmin of distances.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
proba
- Return type:
array of shape (n_samples, n_classes)
- class prosemble.models.CELVQ(n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, gradient_checkpointing=False, devices=None)
Cross-Entropy Learning Vector Quantization.
Computes per-class minimum distances, negates them to get logits, then applies cross-entropy loss against true labels.
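The three steps described above (per-class minimum distance, negation to logits, cross-entropy) can be sketched in NumPy. Squared Euclidean distance and mean reduction over samples are assumptions, and `celvq_loss` is a hypothetical helper rather than the library's loss function.

```python
import numpy as np

def celvq_loss(X, y, prototypes, proto_labels):
    """Cross-entropy over negated per-class minimum distances (illustrative sketch)."""
    d = ((X[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)
    classes = np.unique(proto_labels)
    # Step 1: minimum distance to any prototype of each class, shape (n_samples, n_classes).
    class_d = np.stack([d[:, proto_labels == c].min(axis=1) for c in classes], axis=1)
    # Step 2: negate so that a smaller distance means a larger logit.
    logits = -class_d
    # Step 3: numerically stable log-softmax, then pick the true-class entry.
    logits = logits - logits.max(axis=1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    true_idx = np.searchsorted(classes, y)
    return float(-log_probs[np.arange(len(y)), true_idx].mean())

prototypes = np.array([[0.0, 0.0], [2.0, 2.0]])
proto_labels = np.array([0, 1])
loss = celvq_loss(np.array([[0.0, 0.0]]), np.array([0]), prototypes, proto_labels)
```

A sample sitting exactly on a prototype of its own class yields a loss close to zero, and the loss grows as prototypes of other classes come closer.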
- Parameters:
n_prototypes_per_class (int) – Prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float)
random_seed (int)
margin (float)
callbacks (list)
use_scan (bool)
batch_size (int | None)
lr_scheduler_kwargs (dict | None)
patience (int | None)
restore_best (bool)
gradient_accumulation_steps (int | None)
ema_decay (float | None)
freeze_params (list | None)
lookahead (dict | None)
mixed_precision (str | None)
gradient_checkpointing (bool)
devices (list | None)
See also
SupervisedPrototypeModel – Full list of base parameters (optimizer, distance_fn, lr_scheduler, callbacks, patience, etc.).
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial prototypes differs from what n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict(X)¶
Predict class labels via Winner-Takes-All Competition.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
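Winner-takes-all prediction assigns each sample the label of its nearest prototype. A minimal sketch, assuming squared Euclidean distance; `predict_wta` is a hypothetical name and not part of prosemble:

```python
def sq_euclidean(a, b):
    # Squared Euclidean distance between two equal-length vectors.
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))

def predict_wta(X, prototypes, proto_labels):
    # Hypothetical helper: label of the closest prototype for every sample.
    labels = []
    for x in X:
        dists = [sq_euclidean(x, w) for w in prototypes]
        winner = min(range(len(prototypes)), key=dists.__getitem__)
        labels.append(proto_labels[winner])
    return labels

preds = predict_wta([[0.2, 0.1], [1.8, 2.1]], [[0.0, 0.0], [2.0, 2.0]], [0, 1])
```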
- predict_proba(X)¶
Predict class probabilities via softmin of distances.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
proba
- Return type:
array of shape (n_samples, n_classes)
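A softmin turns distances into probabilities by applying a softmax to their negatives, so smaller distances receive larger probabilities. The sketch below assumes one distance per class (for example the per-class minimum); how the library reduces prototype distances to class distances is not stated here, and `softmin` is a hypothetical helper.

```python
import math

def softmin(dists):
    # Shift by the minimum for numerical stability; the result is unchanged.
    m = min(dists)
    exps = [math.exp(-(d - m)) for d in dists]
    z = sum(exps)
    return [e / z for e in exps]

proba = softmin([0.5, 2.0, 4.0])
```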
- class prosemble.models.LVQ1(n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)[source]¶
Learning Vector Quantization 1.
For each sample:
Find nearest prototype (winner)
If same class: \(w \leftarrow w + \eta (x - w)\) (attract)
If different class: \(w \leftarrow w - \eta (x - w)\) (repel)
Uses batch updates (all samples per iteration).
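The rule above, applied as a batch update, can be sketched as follows. This is an assumption-laden illustration: it accumulates the per-sample attract/repel steps against the current prototypes and applies them once, which is one plausible reading of "batch updates". `lvq1_batch_step` is a hypothetical name.

```python
def sq_euclidean(a, b):
    # Squared Euclidean distance between two equal-length vectors.
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))

def lvq1_batch_step(X, y, prototypes, proto_labels, lr):
    # Accumulate one update per sample, then apply all of them together.
    deltas = [[0.0] * len(w) for w in prototypes]
    for x, label in zip(X, y):
        dists = [sq_euclidean(x, w) for w in prototypes]
        j = min(range(len(prototypes)), key=dists.__getitem__)  # winner
        sign = 1.0 if proto_labels[j] == label else -1.0  # attract or repel
        for i, (xi, wi) in enumerate(zip(x, prototypes[j])):
            deltas[j][i] += sign * lr * (xi - wi)
    return [[wi + di for wi, di in zip(w, d)] for w, d in zip(prototypes, deltas)]

protos = [[0.0, 0.0], [1.0, 1.0]]
new_protos = lvq1_batch_step([[0.2, 0.0], [0.9, 1.0]], [0, 0], protos, [0, 1], lr=0.5)
```

In this example the first prototype moves toward its same-class sample, while the second prototype wins a sample of the other class and is pushed away from it.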
- Parameters:
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6): sync every k steps, and ‘slow_step_size’ (float, default 0.5): interpolation factor. Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
- predict(X)¶
Predict class labels via Winner-Takes-All Competition.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
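The class_weight=‘balanced’ option in the parameter list above is described as inversely proportional to class frequencies. One common convention, assumed here and not confirmed for prosemble, is n_samples / (n_classes * class_count), the same formula scikit-learn uses. `balanced_class_weights` is a hypothetical helper.

```python
from collections import Counter

def balanced_class_weights(y):
    # Assumed convention: n_samples / (n_classes * count of that class).
    counts = Counter(y)
    n_samples, n_classes = len(y), len(counts)
    return {c: n_samples / (n_classes * n) for c, n in counts.items()}

weights = balanced_class_weights([0, 0, 0, 1])
```

With three samples of class 0 and one of class 1, the minority class receives the larger weight.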
- class prosemble.models.LVQ21(n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)[source]¶
Learning Vector Quantization 2.1.
For each sample:
Find closest same-class prototype \(w^+\) and closest different-class \(w^-\)
\(w^+ \leftarrow w^+ + \eta (x - w^+)\) (attract \(w^+\))
\(w^- \leftarrow w^- - \eta (x - w^-)\) (repel \(w^-\))
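For a single sample, the two updates above can be sketched like this. The function name `lvq21_step` is hypothetical, squared Euclidean distance is assumed, and any window rule or batching the library may apply is not shown.

```python
def sq_euclidean(a, b):
    # Squared Euclidean distance between two equal-length vectors.
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))

def lvq21_step(x, label, prototypes, proto_labels, lr):
    def closest(indices):
        return min(indices, key=lambda j: sq_euclidean(x, prototypes[j]))

    same = [j for j, c in enumerate(proto_labels) if c == label]
    diff = [j for j, c in enumerate(proto_labels) if c != label]
    jp, jm = closest(same), closest(diff)  # indices of w+ and w-
    new = [list(w) for w in prototypes]
    new[jp] = [wi + lr * (xi - wi) for xi, wi in zip(x, prototypes[jp])]  # attract
    new[jm] = [wi - lr * (xi - wi) for xi, wi in zip(x, prototypes[jm])]  # repel
    return new

updated = lvq21_step([1.0, 0.0], 0, [[0.0, 0.0], [2.0, 0.0]], [0, 1], lr=0.1)
```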
- Parameters:
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6): sync every k steps, and ‘slow_step_size’ (float, default 0.5): interpolation factor. Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
- predict(X)¶
Predict class labels via Winner-Takes-All Competition.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
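The ema_decay parameter above refers to an exponential moving average of the parameters. The standard update is sketched below; when prosemble applies it (per step or per epoch) and how it initializes the average are not stated here, so starting from the initial parameters is an assumption. `ema_update` is a hypothetical helper.

```python
def ema_update(ema_params, params, decay):
    # ema <- decay * ema + (1 - decay) * params, element-wise.
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema_params, params)]

ema = [0.0, 0.0]  # assumed: the average starts at the initial parameters
for step_params in ([1.0, 2.0], [1.0, 2.0], [1.0, 2.0]):
    ema = ema_update(ema, step_params, decay=0.5)
```

A decay close to 1, such as the typical values 0.999 or 0.9999, makes the average follow the raw parameters far more slowly than the 0.5 used here for illustration.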
- class prosemble.models.MedianLVQ(n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)[source]¶
Median Learning Vector Quantization.
Prototypes are always actual data points. The algorithm alternates:
E-step: compute soft assignments (GLVQ-like weights)
M-step: for each prototype, find the data point that minimizes the weighted sum of distances
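The M-step can be sketched as an exhaustive search over the data set. The per-sample weights are taken as given (the E-step that produces them is not shown), all data points are treated as candidates, and `median_m_step` is a hypothetical name; the library may restrict candidates or use a different weighting.

```python
def sq_euclidean(a, b):
    # Squared Euclidean distance between two equal-length vectors.
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))

def median_m_step(X, weights):
    # Index of the data point that minimizes the weighted sum of distances.
    def cost(c):
        return sum(w * sq_euclidean(x, X[c]) for x, w in zip(X, weights))
    return min(range(len(X)), key=cost)

X = [[0.0], [1.0], [2.0], [10.0]]
best = median_m_step(X, weights=[1.0, 1.0, 1.0, 0.0])
```

Because the result is an index into X, the updated prototype is always an actual data point. The search costs O(n^2) distance evaluations per prototype.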
- Parameters:
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6): sync every k steps, and ‘slow_step_size’ (float, default 0.5): interpolation factor. Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
- predict(X)¶
Predict class labels via Winner-Takes-All Competition.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- class prosemble.models.SLVQ(sigma=1.0, rejection_confidence=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)[source]¶
Soft Learning Vector Quantization.
Uses Gaussian mixture probabilities:
\[p(k|x) = \frac{\exp(-d_k^2 / 2\sigma^2)}{\sum_j \exp(-d_j^2 / 2\sigma^2)}\]
\[P(\text{class}|x) = \sum_{k \in \text{class}} p(k|x)\]
Loss: \(-\log(P(\text{correct}) / P(\text{wrong}))\)
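The probabilities and loss above can be sketched for one sample as follows. This is an illustration under assumptions: the squared distance in the exponent is taken to be the squared Euclidean distance to each prototype, and `slvq_loss` is a hypothetical name.

```python
import math

def sq_euclidean(a, b):
    # Squared Euclidean distance between two equal-length vectors.
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))

def slvq_loss(x, y, prototypes, proto_labels, sigma):
    # Gaussian responsibilities p(k|x) over all prototypes.
    g = [math.exp(-sq_euclidean(x, w) / (2.0 * sigma ** 2)) for w in prototypes]
    z = sum(g)
    p = [gi / z for gi in g]
    # Class probabilities: sum of p(k|x) over prototypes of that class.
    p_correct = sum(pk for pk, c in zip(p, proto_labels) if c == y)
    p_wrong = sum(pk for pk, c in zip(p, proto_labels) if c != y)
    return -math.log(p_correct / p_wrong)

protos, labels = [[0.0, 0.0], [2.0, 0.0]], [0, 1]
loss_y0 = slvq_loss([0.0, 0.0], 0, protos, labels, sigma=1.0)
loss_y1 = slvq_loss([0.0, 0.0], 1, protos, labels, sigma=1.0)
```

The loss is negative when the correct class is more probable than the wrong ones and positive otherwise, so it has no lower bound.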
- Parameters:
sigma (float) – Bandwidth of Gaussian mixture.
rejection_confidence (float, optional) – Minimum class probability for a confident prediction (0 to 1). Samples below this threshold are rejected (labeled -1) when using predict_with_rejection(). Default is None (no rejection).
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6): sync every k steps, and ‘slow_step_size’ (float, default 0.5): interpolation factor. Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict(X)¶
Predict class labels via Winner-Takes-All Competition.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- predict_proba(X)¶
Predict class probabilities via softmin of distances.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
proba
- Return type:
array of shape (n_samples, n_classes)
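The rejection_confidence parameter is documented as labeling samples -1 when their class probability falls below the threshold. A minimal sketch of that thresholding, applied to an array shaped like the output of predict_proba; it assumes class labels are the column indices, and `reject_below` is a hypothetical helper, not the library's predict_with_rejection():

```python
def reject_below(proba, confidence):
    # Keep the most probable class only if its probability reaches the threshold.
    labels = []
    for row in proba:
        best = max(range(len(row)), key=row.__getitem__)
        labels.append(best if row[best] >= confidence else -1)
    return labels

labels = reject_below([[0.9, 0.1], [0.55, 0.45]], confidence=0.8)
```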
- class prosemble.models.RSLVQ(sigma=1.0, rejection_confidence=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)[source]¶
Robust Soft Learning Vector Quantization.
Like SLVQ but with a more robust denominator:
\[\text{loss} = -\log\frac{P(\text{correct}|x)}{P(\text{all}|x)}\]
- Parameters:
sigma (float) – Bandwidth of Gaussian mixture.
rejection_confidence (float, optional) – Minimum class probability for a confident prediction (0 to 1). Samples below this threshold are rejected (labeled -1) when using predict_with_rejection(). Default is None (no rejection).
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6): sync every k steps, and ‘slow_step_size’ (float, default 0.5): interpolation factor. Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict(X)¶
Predict class labels via Winner-Takes-All Competition.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- predict_proba(X)¶
Predict class probabilities via softmin of distances.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
proba
- Return type:
array of shape (n_samples, n_classes)
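The RSLVQ loss given in the class description can be sketched for one sample as follows. The sketch works with unnormalized Gaussian terms, reading P(correct|x) as the sum over correct-class prototypes and P(all|x) as the sum over all prototypes; that reading, the squared Euclidean distance, and the name `rslvq_loss` are assumptions.

```python
import math

def sq_euclidean(a, b):
    # Squared Euclidean distance between two equal-length vectors.
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))

def rslvq_loss(x, y, prototypes, proto_labels, sigma):
    # Unnormalized Gaussian term for every prototype.
    g = [math.exp(-sq_euclidean(x, w) / (2.0 * sigma ** 2)) for w in prototypes]
    correct = sum(gi for gi, c in zip(g, proto_labels) if c == y)
    # The ratio is at most 1, so the loss is never negative.
    return -math.log(correct / sum(g))

loss = rslvq_loss([0.0, 0.0], 0, [[0.0, 0.0], [2.0, 0.0]], [0, 1], sigma=1.0)
```

Because the denominator includes the correct-class terms, the ratio stays in (0, 1] and the loss is bounded below by zero.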
- class prosemble.models.MRSLVQ(sigma=1.0, latent_dim=None, rejection_confidence=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
Matrix Robust Soft Learning Vector Quantization.
Combines the RSLVQ probabilistic loss with a learned global linear mapping \(\Omega\) (d x latent_dim) for metric adaptation:
\[d(x, w) = (x - w)^T \Omega^T \Omega (x - w)\]
The relevance matrix \(\Lambda = \Omega^T \Omega\) captures feature correlations in the probabilistic framework.
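The adaptive distance can be sketched as the squared norm of the projected difference. For the formula as written to type-check with column vectors, the sketch stores \(\Omega\) as latent_dim rows of length d; this storage layout is an assumption and may differ from the library's. `omega_distance` and `relevance_matrix` are hypothetical helpers.

```python
def omega_distance(x, w, omega):
    # d(x, w) = || Omega (x - w) ||^2, with omega given as a list of rows.
    diff = [xi - wi for xi, wi in zip(x, w)]
    projected = [sum(o * d for o, d in zip(row, diff)) for row in omega]
    return sum(p * p for p in projected)

def relevance_matrix(omega):
    # Lambda = Omega^T Omega, a d x d symmetric matrix.
    n = len(omega[0])
    return [[sum(row[i] * row[j] for row in omega) for j in range(n)] for i in range(n)]

omega = [[1.0, 1.0]]  # latent_dim = 1, d = 2
dist = omega_distance([1.0, 2.0], [0.0, 0.0], omega)
lam = relevance_matrix(omega)
```

Off-diagonal entries of `lam` are what allow the metric to weight pairs of features jointly rather than one at a time.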
- Parameters:
sigma (float) – Bandwidth of Gaussian mixture.
latent_dim (int, optional) – Dimensionality of the latent space. If None, uses input dim.
rejection_confidence (float, optional) – Minimum class probability for a confident prediction (0 to 1). Samples below this threshold are rejected (labeled -1) when using predict_with_rejection(). Default is None (no rejection).
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6): sync every k steps, and ‘slow_step_size’ (float, default 0.5): interpolation factor. Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- class prosemble.models.LMRSLVQ(sigma=1.0, latent_dim=None, rejection_confidence=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
Localized Matrix Robust Soft Learning Vector Quantization.
Each prototype \(k\) has its own \(\Omega_k\) matrix. The distance from sample \(x\) to prototype \(w_k\) is:
\[d(x, w_k) = (x - w_k)^T \Omega_k^T \Omega_k (x - w_k)\]
Combined with the RSLVQ probabilistic loss for metric-adaptive soft classification with local relevance learning.
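The local distance above can be evaluated as the squared norm of \(\Omega_k (x - w_k)\). The following is a minimal pure-Python sketch of that formula only; the helper names and the example matrices are illustrative assumptions, not the library's internals.

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector (list of floats)."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, vector)) for row in matrix]

def local_omega_distance(x, prototype, omega):
    """Compute (x - w)^T Omega^T Omega (x - w) as ||Omega (x - w)||^2."""
    diff = [x_i - w_i for x_i, w_i in zip(x, prototype)]
    projected = matvec(omega, diff)
    return sum(p * p for p in projected)

# Two prototypes, each with its own (illustrative) Omega_k.
prototypes = [[0.0, 0.0], [2.0, 2.0]]
omegas = [
    [[1.0, 0.0], [0.0, 1.0]],  # identity: plain squared Euclidean distance
    [[2.0, 0.0], [0.0, 0.5]],  # stretches feature 0, shrinks feature 1
]
x = [1.0, 2.0]
distances = [local_omega_distance(x, w, om) for w, om in zip(prototypes, omegas)]
print(distances)
```

Because each prototype carries its own matrix, the same sample can be measured with a different metric near each prototype.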
- Parameters:
sigma (float) – Bandwidth of Gaussian mixture.
latent_dim (int, optional) – Latent space dimensionality per prototype. If None, uses input dim.
rejection_confidence (float, optional) – Minimum class probability for a confident prediction (0 to 1). Samples below this threshold are rejected (labeled -1) when using predict_with_rejection(). Default: None (no rejection).
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
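To make the ema_decay parameter concrete, here is a small sketch of an exponential moving average over parameter values. It assumes the standard update ema <- decay * ema + (1 - decay) * params and that the average starts from the initial parameters; the library's initialization and any bias correction are not documented here, so treat both as assumptions.

```python
def ema_update(ema_params, params, decay):
    """One EMA step: ema <- decay * ema + (1 - decay) * params."""
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema_params, params)]

decay = 0.9  # deliberately small so the smoothing is visible in a few steps
trajectory = [[1.0, -1.0], [1.2, -0.8], [0.9, -1.1]]  # hypothetical per-step parameters

ema = list(trajectory[0])  # assumption: start the average at the initial parameters
for params in trajectory[1:]:
    ema = ema_update(ema, params, decay)
print(ema)
```

With the typical values 0.999 or 0.9999 the average moves far more slowly, which is why the smoothed parameters are only swapped in after training.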
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when initial_prototypes have a different number than what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- class prosemble.models.RSLVQ_NG(sigma=1.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, rejection_confidence=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
Robust Soft LVQ with Neural Gas Cooperation.
Combines:
RSLVQ probabilistic loss: \(-\log(P(\text{correct}|x))\)
Neural Gas cooperation: all prototypes weighted by rank via \(\exp(-\text{rank} / \gamma)\)
Euclidean distance
The NG neighborhood modulates RSLVQ’s Gaussian probabilities, emphasizing nearby prototypes. \(\gamma\) decays during training from \(\gamma_{\text{init}}\) to \(\gamma_{\text{final}}\).
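The rank weighting and the decay of \(\gamma\) can be sketched as follows. This is an illustration under stated assumptions: rank 0 is given to the closest prototype, and the decay factor is chosen as (gamma_final / gamma_init)^(1 / n_steps) so that \(\gamma\) reaches gamma_final on the last step. How the weights enter the RSLVQ probabilities is not shown, and the function names are hypothetical.

```python
import math

def ng_rank_weights(distances, gamma):
    """Weight each prototype by exp(-rank / gamma); rank 0 is the closest."""
    order = sorted(range(len(distances)), key=lambda k: distances[k])
    ranks = [0] * len(distances)
    for rank, k in enumerate(order):
        ranks[k] = rank
    return [math.exp(-r / gamma) for r in ranks]

def gamma_schedule(gamma_init, gamma_final, n_steps):
    """Multiplicative decay: gamma_t = gamma_init * decay**t for t = 0..n_steps."""
    decay = (gamma_final / gamma_init) ** (1.0 / n_steps)
    return [gamma_init * decay ** t for t in range(n_steps + 1)]

distances = [0.4, 2.5, 1.1]
gammas = gamma_schedule(gamma_init=1.5, gamma_final=0.01, n_steps=100)
early = ng_rank_weights(distances, gammas[0])   # broad cooperation
late = ng_rank_weights(distances, gammas[-1])   # nearly winner-takes-all
print(early, late)
```

Early in training every prototype receives a noticeable weight; once \(\gamma\) is small, only the closest prototype contributes.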
- Parameters:
sigma (float) – Bandwidth for RSLVQ Gaussian mixture probability computation.
gamma_init (float, optional) – Initial neighborhood range for NG cooperation. Default: max prototypes per class / 2.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma. Default: computed from max_iter so gamma reaches gamma_final.
rejection_confidence (float, optional) – Minimum class probability for confident prediction (0 to 1). Samples below this threshold are rejected (label -1).
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when initial_prototypes have a different number than what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict(X)¶
Predict class labels via Winner-Takes-All Competition.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- predict_proba(X)¶
Predict class probabilities via softmin of distances.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
proba
- Return type:
array of shape (n_samples, n_classes)
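A softmin over distances is a softmax over the negated distances. The sketch below shows that computation for a single sample; it assumes one distance per class (for example, the distance to that class's closest prototype), which is an illustrative choice and may differ from how the library aggregates prototype distances.

```python
import math

def softmin(distances):
    """Softmin: softmax of the negated distances, shifted for numerical stability."""
    d_min = min(distances)
    exps = [math.exp(-(d - d_min)) for d in distances]  # largest term is exp(0) = 1
    total = sum(exps)
    return [e / total for e in exps]

# Assumed input: one distance per class for a single sample.
class_distances = [0.5, 2.0, 3.5]
proba = softmin(class_distances)
print(proba)
```

Smaller distances map to larger probabilities, and the entries sum to one.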
- class prosemble.models.MRSLVQ_NG(sigma=1.0, latent_dim=None, gamma_init=None, gamma_final=0.01, gamma_decay=None, rejection_confidence=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
Matrix Robust Soft LVQ with Neural Gas Cooperation.
Combines:
RSLVQ probabilistic loss: \(-\log(P(\text{correct}|x))\)
Neural Gas cooperation: all prototypes weighted by rank via \(\exp(-\text{rank} / \gamma)\)
Global \(\Omega\) matrix for metric adaptation:
\[d(x, w) = (x - w)^T \Omega^T \Omega (x - w)\]
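When latent_dim is smaller than the input dimension, \(\Omega\) is rectangular and the distance is measured after projecting into the latent space. The sketch below illustrates this with a hand-picked \(\Omega\) of shape (latent_dim, n_features); the matrix values and the helper name are assumptions for illustration only.

```python
def omega_distance(x, w, omega):
    """d(x, w) = ||Omega (x - w)||^2 for an Omega of shape (latent_dim, n_features)."""
    diff = [a - b for a, b in zip(x, w)]
    projected = [sum(o * d for o, d in zip(row, diff)) for row in omega]
    return sum(p * p for p in projected)

# One shared Omega mapping 3 input features to a 2-dimensional latent space.
omega = [
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
]
w = [0.0, 0.0, 0.0]
d_a = omega_distance([1.0, 2.0, 0.0], w, omega)
d_b = omega_distance([1.0, 2.0, 9.0], w, omega)  # differs only in the third feature
print(d_a, d_b)
```

Both samples end up at the same distance because this particular \(\Omega\) gives the third feature zero weight, which is the sense in which a learned \(\Omega\) expresses feature relevance.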
- Parameters:
sigma (float) – Bandwidth for RSLVQ Gaussian mixture probability computation.
latent_dim (int, optional) – Dimensionality of the \(\Omega\) projection space. If None, uses input dim.
gamma_init (float, optional) – Initial neighborhood range for NG cooperation. Default: max prototypes per class / 2.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma. Default: computed from max_iter so gamma reaches gamma_final.
rejection_confidence (float, optional) – Minimum class probability for confident prediction (0 to 1). Samples below this threshold are rejected (label -1).
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
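The two lookahead keys can be read against the usual lookahead scheme: a fast optimizer takes sync_period steps, then the slow weights move a fraction slow_step_size toward the fast weights and the fast weights are reset to them. The sketch below applies this to plain gradient descent on a scalar; it is a generic illustration, not the library's wrapper, and lookahead_sgd is a hypothetical name.

```python
def lookahead_sgd(grad_fn, w0, lr, n_steps, sync_period=6, slow_step_size=0.5):
    """Gradient descent on the fast weights with a lookahead sync every sync_period steps."""
    slow = w0
    fast = w0
    for step in range(1, n_steps + 1):
        fast = fast - lr * grad_fn(fast)  # inner (fast) optimizer step
        if step % sync_period == 0:
            slow = slow + slow_step_size * (fast - slow)  # interpolate toward fast
            fast = slow  # restart the fast weights from the slow weights
    return slow, fast

# Minimize f(w) = w^2, whose gradient is 2w.
slow, fast = lookahead_sgd(lambda w: 2.0 * w, w0=1.0, lr=0.1, n_steps=12)
print(slow, fast)
```

A larger slow_step_size trusts the fast trajectory more; a longer sync_period lets the fast weights explore further between syncs.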
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when initial_prototypes have a different number than what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- class prosemble.models.LMRSLVQ_NG(sigma=1.0, latent_dim=None, gamma_init=None, gamma_final=0.01, gamma_decay=None, rejection_confidence=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
Localized Matrix Robust Soft LVQ with Neural Gas Cooperation.
Each prototype \(k\) has its own \(\Omega_k\) matrix. Combined with RSLVQ probabilistic loss and NG rank-based neighborhood cooperation.
- Parameters:
sigma (float) – Bandwidth for RSLVQ Gaussian mixture probability computation.
latent_dim (int, optional) – Latent space dimensionality per prototype. If None, uses input dim.
gamma_init (float, optional) – Initial neighborhood range for NG cooperation. Default: max prototypes per class / 2.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma. Default: computed from max_iter so gamma reaches gamma_final.
rejection_confidence (float, optional) – Minimum class probability for confident prediction (0 to 1). Samples below this threshold are rejected (label -1).
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
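The statement that the effective batch size equals batch_size * gradient_accumulation_steps can be checked on a toy problem. The sketch below assumes equally sized micro-batches and that accumulated gradients are averaged before the update (whether the library averages or sums is not stated here); under those assumptions the accumulated gradient matches the gradient of one large batch.

```python
def mean_grad(w, batch):
    """Gradient of the mean of 0.5 * (w * x - t)^2 with respect to the scalar w."""
    return sum((w * x - t) * x for x, t in batch) / len(batch)

data = [(1.0, 2.0), (2.0, 3.9), (3.0, 6.1), (4.0, 8.2)]  # (x, target) pairs
w = 0.5
batch_size, accumulation_steps = 2, 2

# Accumulate gradients over micro-batches, then average before a single update.
micro_batches = [data[i:i + batch_size] for i in range(0, len(data), batch_size)]
accumulated = sum(mean_grad(w, mb) for mb in micro_batches) / accumulation_steps

# One batch of size batch_size * accumulation_steps.
full_batch = mean_grad(w, data)
print(accumulated, full_batch)
```

The benefit is memory: only one micro-batch needs to be held at a time while the update behaves like a larger batch.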
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when initial_prototypes have a different number than what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- class prosemble.models.GLVQ1(n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, gradient_checkpointing=False, devices=None)[source]¶
GLVQ with LVQ1-style loss (gradient-based).
Loss: \(d^+\) when correct, \(-d^-\) when wrong.
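A per-sample version of this loss can be sketched as below, where \(d^+\) is the distance to the closest prototype of the true class and \(d^-\) the distance to the closest prototype of any other class. Treating a sample as "correct" when \(d^+ < d^-\) (ties counted as wrong) and using squared Euclidean distance are assumptions made for the illustration.

```python
def squared_euclidean(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def lvq1_style_loss(x, y, prototypes, proto_labels):
    """Return d_plus if the nearest prototype carries label y, otherwise -d_minus."""
    dists = [squared_euclidean(x, w) for w in prototypes]
    d_plus = min(d for d, c in zip(dists, proto_labels) if c == y)
    d_minus = min(d for d, c in zip(dists, proto_labels) if c != y)
    return d_plus if d_plus < d_minus else -d_minus

prototypes = [[0.0, 0.0], [3.0, 0.0]]
proto_labels = [0, 1]
x = [1.0, 0.0]  # closest to the class-0 prototype
loss_correct = lvq1_style_loss(x, 0, prototypes, proto_labels)
loss_wrong = lvq1_style_loss(x, 1, prototypes, proto_labels)
print(loss_correct, loss_wrong)
```

Minimizing the first case pulls the correct prototype toward the sample; minimizing the second pushes the wrong winner away, mirroring the LVQ1 attract/repel rule.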
- Parameters:
n_prototypes_per_class (int) – Prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float)
random_seed (int)
margin (float)
callbacks (list)
use_scan (bool)
batch_size (int | None)
lr_scheduler_kwargs (dict | None)
patience (int | None)
restore_best (bool)
gradient_accumulation_steps (int | None)
ema_decay (float | None)
freeze_params (list | None)
lookahead (dict | None)
mixed_precision (str | None)
gradient_checkpointing (bool)
devices (list | None)
See also
SupervisedPrototypeModel – Full list of base parameters (optimizer, distance_fn, lr_scheduler, callbacks, patience, etc.).
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when initial_prototypes have a different number than what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict(X)¶
Predict class labels via Winner-Takes-All Competition.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- class prosemble.models.GLVQ21(n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, gradient_checkpointing=False, devices=None)[source]¶
GLVQ with LVQ2.1-style loss (gradient-based, unnormalized).
Loss: \(d^+ - d^-\) (no normalization by \(d^+ + d^-\)).
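One practical consequence of dropping the normalization is sensitivity to the scale of the distances. The short sketch below compares the unnormalized loss with the normalized GLVQ term \(\mu = (d^+ - d^-)/(d^+ + d^-)\) on the same pair of distances before and after rescaling; the numbers are made up for illustration.

```python
def glvq21_loss(d_plus, d_minus):
    """Unnormalized LVQ2.1-style loss."""
    return d_plus - d_minus

def glvq_mu(d_plus, d_minus):
    """Normalized GLVQ term; lies in [-1, 1] for non-negative distances."""
    return (d_plus - d_minus) / (d_plus + d_minus)

d_plus, d_minus = 1.0, 4.0
scale = 100.0  # e.g. the same data expressed in different units

unnormalized = (glvq21_loss(d_plus, d_minus),
                glvq21_loss(scale * d_plus, scale * d_minus))
normalized = (glvq_mu(d_plus, d_minus),
              glvq_mu(scale * d_plus, scale * d_minus))
print(unnormalized, normalized)
```

The unnormalized loss grows with the distance scale while \(\mu\) does not, so lr may need retuning when features are rescaled.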
- Parameters:
n_prototypes_per_class (int) – Prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float)
random_seed (int)
margin (float)
callbacks (list)
use_scan (bool)
batch_size (int | None)
lr_scheduler_kwargs (dict | None)
patience (int | None)
restore_best (bool)
gradient_accumulation_steps (int | None)
ema_decay (float | None)
freeze_params (list | None)
lookahead (dict | None)
mixed_precision (str | None)
gradient_checkpointing (bool)
devices (list | None)
See also
SupervisedPrototypeModel – Full list of base parameters (optimizer, distance_fn, lr_scheduler, callbacks, patience, etc.).
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when initial_prototypes have a different number than what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict(X)¶
Predict class labels via Winner-Takes-All Competition.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
Deep and Siamese Variants¶
- class prosemble.models.LVQMLN(hidden_sizes=None, latent_dim=2, activation='sigmoid', beta=10.0, bb_lr=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
LVQ Multi-Layer Network.
An MLP backbone maps inputs into a latent space. Prototypes reside directly in that latent space. The GLVQ loss trains both the backbone and the prototypes jointly via gradient descent.
Architecture:
Input (d) -> MLP -> Latent (latent_dim) -> distance(latent_x, prototypes) -> GLVQ loss
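The data flow of this architecture can be traced by hand for a single sample. The sketch below uses a tiny fixed MLP in pure Python; the weights, the linear output layer, the squared Euclidean distance, and the helper names are all illustrative assumptions, and no training step is shown.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def dense(x, weights, biases, activation=None):
    """One fully connected layer; weights is a list of rows, one per output unit."""
    out = [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weights, biases)]
    return [activation(o) for o in out] if activation else out

def glvq_mu(latent_x, y, prototypes, proto_labels):
    """mu = (d+ - d-) / (d+ + d-) with squared Euclidean distance in latent space."""
    dists = [sum((a - b) ** 2 for a, b in zip(latent_x, w)) for w in prototypes]
    d_plus = min(d for d, c in zip(dists, proto_labels) if c == y)
    d_minus = min(d for d, c in zip(dists, proto_labels) if c != y)
    return (d_plus - d_minus) / (d_plus + d_minus)

# Hypothetical fixed weights: 3 inputs -> 2 hidden units (sigmoid) -> 2 latent dims.
W1, b1 = [[0.5, -0.2, 0.1], [0.3, 0.8, -0.5]], [0.0, 0.1]
W2, b2 = [[1.0, -1.0], [0.5, 0.5]], [0.0, 0.0]

x, y = [1.0, 2.0, 3.0], 0
hidden = dense(x, W1, b1, activation=sigmoid)
latent_x = dense(hidden, W2, b2)        # the sample mapped into latent space
prototypes = [[0.0, 0.5], [1.0, 1.0]]   # prototypes are defined in latent space
mu = glvq_mu(latent_x, y, prototypes, [0, 1])
print(latent_x, mu)
```

A negative mu means the sample lies closer to a prototype of its own class; in the model, gradients of this loss flow into both the prototypes and the MLP weights.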
- Parameters:
hidden_sizes (list of int) – Sizes of hidden layers. e.g. [10] for one hidden layer of 10 units.
latent_dim (int) – Dimension of the latent/embedding space where prototypes live.
activation (str) – Activation function: ‘sigmoid’, ‘relu’, ‘tanh’, ‘leaky_relu’, ‘selu’.
beta (float) – Transfer function parameter for GLVQ loss.
bb_lr (float, optional) – Separate learning rate for the backbone network. If None, uses the same lr as prototypes. Default: None.
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
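To make the two lookahead keys concrete, here is a small NumPy sketch of the general lookahead scheme wrapped around plain SGD: fast weights take ordinary steps, and every `sync_period` steps the slow weights move a fraction `slow_step_size` toward them. The function `lookahead_sgd` and the quadratic objective are hypothetical illustrations; how prosemble wires this into its optimizer is not shown here.

```python
import numpy as np

def lookahead_sgd(grad_fn, w0, lr=0.1, sync_period=6, slow_step_size=0.5, n_steps=60):
    """Lookahead around SGD: fast weights explore, slow weights interpolate."""
    slow = np.array(w0, dtype=float)
    fast = slow.copy()
    for step in range(1, n_steps + 1):
        # Inner (fast) optimizer step.
        fast = fast - lr * grad_fn(fast)
        if step % sync_period == 0:
            # Move the slow weights part of the way toward the fast weights,
            # then restart the fast weights from the new slow weights.
            slow = slow + slow_step_size * (fast - slow)
            fast = slow.copy()
    return slow

# Toy objective ||w - target||^2 with gradient 2 * (w - target).
target = np.array([1.0, -2.0])
w = lookahead_sgd(lambda w: 2.0 * (w - target), w0=[0.0, 0.0])
```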
- predict(X)[source]¶
Predict class labels.
Transforms X through the backbone, then finds nearest prototype.
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted params) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
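The interaction between validation_data, patience, and restore_best can be pictured with a short, library-independent sketch. It assumes a strict-decrease improvement test, whereas the library's own criterion may involve epsilon; `best_and_stop_step` is a hypothetical helper, not a prosemble function.

```python
def best_and_stop_step(val_losses, patience=3):
    """Return (step with the lowest loss so far, step at which training stops)."""
    best_loss = float("inf")
    best_step = -1
    bad_steps = 0
    for step, loss in enumerate(val_losses):
        if loss < best_loss:
            # Improvement: remember this step and reset the patience counter.
            best_loss, best_step, bad_steps = loss, step, 0
        else:
            bad_steps += 1
            if bad_steps >= patience:
                # Patience exhausted: stop here, restore parameters from best_step.
                return best_step, step
    return best_step, len(val_losses) - 1

losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.6]
result = best_and_stop_step(losses, patience=3)
```

In this run training stops before the final, lower loss is reached, which is the usual trade-off of a small patience value.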
- class prosemble.models.PLVQ(hidden_sizes=None, latent_dim=2, activation='sigmoid', sigma=1.0, loss_type='rslvq', n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
Probabilistic LVQ with learned nonlinear transformation.
Combines an MLP backbone (learned metric) with probabilistic soft assignment via Gaussian mixtures. The loss is the negative log-likelihood of the correct class:
\[p(k|x) = \frac{\exp\left(-d(f(x), w_k)^2 / (2\sigma^2)\right)}{Z}\]
\[P(\text{class}|x) = \sum_{k \in \text{class}} p(k|x)\]
\[\text{loss} = -\log\frac{P(\text{correct}|x)}{P(\text{all}|x)}\]
- Parameters:
hidden_sizes (list of int) – Hidden layer sizes for the backbone MLP.
latent_dim (int) – Latent space dimension.
activation (str) – Activation: ‘sigmoid’, ‘relu’, ‘tanh’, ‘leaky_relu’, ‘selu’.
sigma (float) – Bandwidth of the Gaussian mixture.
loss_type (str) – ‘rslvq’ (robust, default) or ‘nllr’ (likelihood ratio).
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
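The PLVQ probability and loss formulas can be traced numerically with a short NumPy sketch. It treats \(d(f(x), w_k)^2\) as the squared Euclidean distance between latent points and prototypes and takes the latent vectors as given; `plvq_loss` is an illustrative name, not the library's implementation.

```python
import numpy as np

def plvq_loss(latent, prototypes, proto_labels, y, sigma=1.0):
    """Per-sample -log(P(correct|x) / P(all|x)) under Gaussian soft assignment."""
    sq_dist = np.sum((latent[:, None, :] - prototypes[None, :, :]) ** 2, axis=-1)
    logits = -sq_dist / (2.0 * sigma ** 2)
    # Subtract the row maximum for numerical stability; the ratio is unchanged.
    logits = logits - logits.max(axis=1, keepdims=True)
    unnorm = np.exp(logits)
    correct = proto_labels[None, :] == y[:, None]
    p_correct = np.sum(unnorm * correct, axis=1)   # mass on correct-class prototypes
    p_all = np.sum(unnorm, axis=1)                 # mass on all prototypes
    return -np.log(p_correct / p_all)

latent = np.array([[0.0, 0.0], [1.0, 1.0]])
prototypes = np.array([[0.0, 0.0], [1.0, 1.0]])
proto_labels = np.array([0, 1])
y = np.array([0, 1])
loss = plvq_loss(latent, prototypes, proto_labels, y, sigma=1.0)
```

Each sample here sits exactly on its own class prototype and at squared distance 2 from the other, so both losses equal \(\log(1 + e^{-1})\).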
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted params) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- class prosemble.models.SiameseGLVQ(hidden_sizes=None, latent_dim=2, activation='sigmoid', beta=10.0, bb_lr=None, both_path_gradients=True, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
Siamese GLVQ — GLVQ with a learned embedding network.
Both inputs and prototypes are transformed through the same MLP backbone before computing squared Euclidean distances.
- Parameters:
hidden_sizes (list of int) – Hidden layer sizes for the backbone MLP.
latent_dim (int) – Dimension of the embedding space.
activation (str) – Activation function for the backbone MLP. Supported values: ‘sigmoid’, ‘relu’, ‘tanh’, ‘leaky_relu’, ‘selu’.
beta (float) – Transfer function parameter for GLVQ loss.
bb_lr (float, optional) – Separate learning rate for the backbone network. If None, uses the same lr as prototypes. Default: None.
both_path_gradients (bool) – If True, compute gradients through both input and prototype paths. If False, prototype path gradients are stopped. Default: True.
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
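The ‘balanced’ option for class_weight is described only as inversely proportional to class frequencies. One common convention, assumed here and not confirmed for prosemble, is n_samples / (n_classes * class_count); the helper below is a hypothetical illustration of that convention.

```python
import numpy as np

def balanced_class_weights(y):
    """Weights inversely proportional to class frequency (assumed convention)."""
    classes, counts = np.unique(y, return_counts=True)
    weights = len(y) / (len(classes) * counts)
    return {int(c): float(w) for c, w in zip(classes, weights)}

y = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 2])   # class counts: 6, 3, 1
weights = balanced_class_weights(y)
# Expand to one weight per sample, as a weighted loss would consume them.
sample_weight = np.array([weights[int(label)] for label in y])
```

With this convention every class contributes the same total weight, and the weights sum to the number of samples.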
- predict(X)[source]¶
Predict class labels via Winner-Takes-All Competition.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- predict_proba(X)[source]¶
Predict class probabilities via softmin of distances.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
proba
- Return type:
array of shape (n_samples, n_classes)
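A softmin over distances turns small distances into large probabilities. The sketch below assumes that each class is scored by the minimum distance to any of its prototypes and that no temperature is applied; both are assumptions, and `softmin_proba` is an illustrative name.

```python
import numpy as np

def softmin_proba(dists, proto_labels):
    """Class probabilities from a (n_samples, n_prototypes) distance matrix."""
    classes = np.unique(proto_labels)
    # Assumed reduction: distance to the closest prototype of each class.
    class_d = np.stack(
        [dists[:, proto_labels == c].min(axis=1) for c in classes], axis=1
    )
    # Softmin is a softmax over negated distances (stabilized by the row max).
    logits = -class_d
    logits = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(logits)
    return e / e.sum(axis=1, keepdims=True)

dists = np.array([[0.1, 2.0, 1.5],
                  [3.0, 0.2, 0.4]])
proto_labels = np.array([0, 1, 1])
proba = softmin_proba(dists, proto_labels)
```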
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted params) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- class prosemble.models.SiameseGMLVQ(hidden_sizes=None, latent_dim=2, omega_dim=None, activation='sigmoid', beta=10.0, bb_lr=None, both_path_gradients=True, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
Siamese GMLVQ — GMLVQ with a learned embedding network.
Both inputs and prototypes are transformed through the same MLP, then distances are computed using a learned \(\Omega\) matrix in the latent space:
\[d = \|\Omega(f(x) - f(w))\|^2\]
- Parameters:
hidden_sizes (list of int) – Hidden layer sizes for the backbone MLP.
latent_dim (int) – Dimension of the backbone output (embedding space).
omega_dim (int, optional) – Omega mapping dimension (number of rows in Omega). If None, uses latent_dim (square matrix). Default: None.
activation (str) – Activation function for the backbone MLP. Supported values: ‘sigmoid’, ‘relu’, ‘tanh’, ‘leaky_relu’, ‘selu’.
beta (float) – Transfer function parameter for GLVQ loss.
bb_lr (float, optional) – Separate learning rate for the backbone network. If None, uses the same lr as prototypes. Default: None.
both_path_gradients (bool) – If True, compute gradients through both input and prototype paths. If False, prototype path gradients are stopped. Default: True.
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
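The Siamese GMLVQ distance \(d = \|\Omega(f(x) - f(w))\|^2\) and the role of omega_dim can be checked with a few lines of NumPy. The sketch takes the embedded inputs and prototypes as given arrays; `omega_distance` is an illustrative name, not the library function.

```python
import numpy as np

def omega_distance(fx, fw, omega):
    """Squared norm of Omega (f(x) - f(w)) for every sample/prototype pair.

    fx: (n_samples, latent_dim), fw: (n_prototypes, latent_dim),
    omega: (omega_dim, latent_dim).
    """
    diff = fx[:, None, :] - fw[None, :, :]     # (n_samples, n_prototypes, latent_dim)
    projected = diff @ omega.T                 # (n_samples, n_prototypes, omega_dim)
    return np.sum(projected ** 2, axis=-1)

fx = np.array([[1.0, 2.0]])
fw = np.array([[0.0, 0.0], [1.0, 0.0]])

# Square Omega (omega_dim == latent_dim): the identity recovers squared Euclidean.
d_square = omega_distance(fx, fw, np.eye(2))
# Rectangular Omega (omega_dim = 1): only the first latent coordinate counts.
d_rect = omega_distance(fx, fw, np.array([[1.0, 0.0]]))
```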
- predict(X)[source]¶
Predict class labels via Winner-Takes-All Competition.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted params) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict_proba(X)¶
Predict class probabilities via softmin of distances.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
proba
- Return type:
array of shape (n_samples, n_classes)
- class prosemble.models.SiameseGTLVQ(hidden_sizes=None, latent_dim=4, subspace_dim=2, activation='sigmoid', beta=10.0, bb_lr=None, both_path_gradients=True, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
Siamese GTLVQ — GTLVQ with a learned embedding network.
Both inputs and prototypes are transformed through the same MLP, then tangent distances are computed in the latent space using per-prototype subspace bases.
- Parameters:
hidden_sizes (list of int) – Hidden layer sizes for the backbone MLP.
latent_dim (int) – Dimension of the backbone output (embedding space).
subspace_dim (int) – Tangent subspace dimension per prototype. Each prototype gets a learned orthonormal basis of this rank in latent space.
activation (str) – Activation function for the backbone MLP. Supported values: ‘sigmoid’, ‘relu’, ‘tanh’, ‘leaky_relu’, ‘selu’.
beta (float) – Transfer function parameter for GLVQ loss.
bb_lr (float, optional) – Separate learning rate for the backbone network. If None, uses the same lr as prototypes. Default: None.
both_path_gradients (bool) – If True, compute gradients through both input and prototype paths. If False, prototype path gradients are stopped. Default: True.
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
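A tangent distance with a per-prototype orthonormal basis is commonly written as the squared norm of the difference after removing its component inside the prototype's subspace. The NumPy sketch below assumes that formulation and takes embedded points as given; whether prosemble computes it exactly this way is not stated here, and `tangent_distance` is an illustrative name.

```python
import numpy as np

def tangent_distance(fx, fw, bases):
    """Squared distance from each point to each prototype's affine subspace.

    fx: (n, D), fw: (p, D), bases: (p, D, s) with orthonormal columns.
    """
    diff = fx[:, None, :] - fw[None, :, :]                   # (n, p, D)
    # Coordinates of diff in each prototype's subspace.
    coeffs = np.einsum("npd,pds->nps", diff, bases)          # (n, p, s)
    # Remove the in-subspace component; only the residual is penalized.
    residual = diff - np.einsum("nps,pds->npd", coeffs, bases)
    return np.sum(residual ** 2, axis=-1)

rng = np.random.default_rng(0)
fx = rng.normal(size=(3, 4))                                 # latent_dim = 4
fw = rng.normal(size=(2, 4))
# Orthonormal bases of rank subspace_dim = 2, obtained via QR.
bases = np.stack([np.linalg.qr(rng.normal(size=(4, 2)))[0] for _ in range(2)])

d = tangent_distance(fx, fw, bases)
# Shifting inputs along a basis direction of prototype 0 leaves its distance unchanged.
d_shifted = tangent_distance(fx + 2.5 * bases[0][:, 0], fw, bases)
euclid = np.sum((fx[:, None, :] - fw[None, :, :]) ** 2, axis=-1)
```

The invariance to shifts inside the subspace is what lets each prototype tolerate variation along its learned tangent directions.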
- predict(X)[source]¶
Predict class labels via Winner-Takes-All Competition.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted params) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict_proba(X)¶
Predict class probabilities via softmin of distances.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
proba
- Return type:
array of shape (n_samples, n_classes)
Image LVQ¶
- class prosemble.models.ImageGLVQ(input_shape=(28, 28, 1), channels=None, kernel_sizes=None, latent_dim=32, activation='relu', beta=10.0, bb_lr=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
Image GLVQ — GLVQ with a CNN embedding network.
Both input images and prototype images are passed through the same CNN backbone before computing squared Euclidean distances.
- Parameters:
input_shape (tuple) – Shape of input images as (height, width, channels).
channels (list of int) – CNN output channels per convolutional layer, e.g. [16, 32].
kernel_sizes (list of int) – Kernel sizes per convolutional layer, e.g. [3, 3].
latent_dim (int) – Dimension of the CNN embedding space.
activation (str) – Activation function for the CNN backbone. Supported values: ‘relu’, ‘sigmoid’, ‘tanh’, ‘leaky_relu’, ‘selu’.
beta (float) – Transfer function parameter for GLVQ loss.
bb_lr (float, optional) – Separate learning rate for the backbone network. If None, uses the same lr as prototypes. Default: None.
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
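The lookahead option is easier to picture with a small sketch. The code below is a generic lookahead loop around plain gradient descent, not the library's implementation. The function lookahead_sgd and the quadratic test objective are hypothetical; only the key names sync_period and slow_step_size come from the parameter description.

```python
import numpy as np

def lookahead_sgd(grad_fn, params, lr=0.1, n_steps=12,
                  sync_period=6, slow_step_size=0.5):
    """Gradient descent wrapped in lookahead (fast and slow weights)."""
    fast = np.array(params, dtype=float)
    slow = fast.copy()
    for step in range(1, n_steps + 1):
        # Inner optimizer step on the fast weights.
        fast = fast - lr * grad_fn(fast)
        if step % sync_period == 0:
            # Every sync_period steps, move the slow weights part of the way
            # towards the fast weights, then restart the fast weights there.
            slow = slow + slow_step_size * (fast - slow)
            fast = slow.copy()
    return slow

# Toy objective f(w) = 0.5 * ||w||^2, whose gradient is w itself.
start = np.array([1.0, -2.0])
final = lookahead_sgd(lambda w: w, start)
```

With slow_step_size=1.0 the wrapper reduces to the inner optimizer; smaller values damp the trajectory of the fast weights.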
- predict(X)[source]¶
Predict class labels via Winner-Takes-All Competition.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- predict_proba(X)[source]¶
Predict class probabilities via softmin of distances.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
proba
- Return type:
array of shape (n_samples, n_classes)
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial prototypes differs from what n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
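As a rough illustration of the ImageGLVQ idea (inputs and prototype images share one embedding before squared Euclidean distances are taken), the sketch below replaces the CNN with a hypothetical flatten, linear map, and ReLU stand-in. The functions embed and glvq_mu and all array names are assumptions made for this example.

```python
import numpy as np

rng = np.random.default_rng(0)

def embed(images, weights):
    """Hypothetical stand-in for the CNN backbone: flatten, linear map, ReLU."""
    flat = images.reshape(images.shape[0], -1)
    return np.maximum(flat @ weights, 0.0)

def glvq_mu(x_emb, proto_emb, proto_labels, y):
    """GLVQ relative distance (d+ - d-) / (d+ + d-) per sample."""
    # Squared Euclidean distances in latent space: (n_samples, n_prototypes).
    d = ((x_emb[:, None, :] - proto_emb[None, :, :]) ** 2).sum(axis=-1)
    same = proto_labels[None, :] == y[:, None]
    d_plus = np.where(same, d, np.inf).min(axis=1)   # closest correct prototype
    d_minus = np.where(same, np.inf, d).min(axis=1)  # closest wrong prototype
    return (d_plus - d_minus) / (d_plus + d_minus)

images = rng.normal(size=(4, 28, 28, 1))
labels = np.array([0, 1, 0, 1])
proto_images = images[:2].copy()      # one prototype image per class
proto_labels = np.array([0, 1])
weights = rng.normal(size=(28 * 28, 32)) * 0.05   # latent_dim = 32

# The same embedding is applied to inputs and to prototype images.
mu = glvq_mu(embed(images, weights), embed(proto_images, weights),
             proto_labels, labels)
```

Negative values of mu mean the closest prototype has the correct label. The first two samples coincide with their own prototypes, so their value is exactly -1.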
- class prosemble.models.ImageGMLVQ(input_shape=(28, 28, 1), channels=None, kernel_sizes=None, latent_dim=32, omega_dim=None, activation='relu', beta=10.0, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
Image GMLVQ — GMLVQ with a CNN embedding network.
Like ImageGLVQ but with a learned \(\Omega\) matrix in latent space:
\[d = \|\Omega(\text{CNN}(x) - \text{CNN}(w))\|^2\]
- Parameters:
input_shape (tuple) – Shape of input images as (height, width, channels).
channels (list of int) – CNN output channels per convolutional layer, e.g. [16, 32].
kernel_sizes (list of int) – Kernel sizes per convolutional layer, e.g. [3, 3].
latent_dim (int) – Dimension of the CNN embedding space.
omega_dim (int, optional) – Omega mapping dimension (number of rows in Omega). If None, uses latent_dim (square matrix). Default: None.
activation (str) – Activation function for the CNN backbone. Supported values: ‘relu’, ‘sigmoid’, ‘tanh’, ‘leaky_relu’, ‘selu’.
beta (float) – Transfer function parameter for GLVQ loss.
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
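The ema_decay option refers to an exponential moving average of the parameters. A minimal sketch of that update is shown below. Initializing the average from the first parameter value and omitting bias correction are assumptions, and ema_update is a hypothetical helper.

```python
import numpy as np

def ema_update(ema, params, decay):
    """One exponential-moving-average step: keep `decay` of the old average."""
    return decay * ema + (1.0 - decay) * params

# Parameter values observed after three consecutive optimizer steps.
params_history = [np.array([1.0, 0.0]),
                  np.array([0.0, 1.0]),
                  np.array([1.0, 1.0])]

ema = params_history[0].copy()   # assumption: start from the first value
for params in params_history[1:]:
    ema = ema_update(ema, params, decay=0.9)
```

A decay close to 1 (such as the typical 0.999) makes the average react slowly, which smooths out step-to-step noise in the final parameters.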
- predict(X)[source]¶
Predict class labels via Winner-Takes-All Competition.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial prototypes differs from what n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict_proba(X)¶
Predict class probabilities via softmin of distances.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
proba
- Return type:
array of shape (n_samples, n_classes)
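The ImageGMLVQ distance \(d = \|\Omega(\text{CNN}(x) - \text{CNN}(w))\|^2\) can be sketched on precomputed embeddings. The arrays z_x and z_w below are random placeholders for CNN outputs, and omega_distance is a hypothetical helper, not a library function.

```python
import numpy as np

def omega_distance(z_x, z_w, omega):
    """Squared distance after mapping latent differences through omega.

    z_x: (n_samples, latent_dim), z_w: (n_prototypes, latent_dim),
    omega: (omega_dim, latent_dim).
    """
    diff = z_x[:, None, :] - z_w[None, :, :]   # (n_samples, n_prototypes, latent_dim)
    projected = diff @ omega.T                 # (n_samples, n_prototypes, omega_dim)
    return (projected ** 2).sum(axis=-1)

rng = np.random.default_rng(0)
z_x = rng.normal(size=(5, 32))       # placeholder embeddings of inputs
z_w = rng.normal(size=(3, 32))       # placeholder embeddings of prototypes
omega = rng.normal(size=(8, 32))     # omega_dim=8 rows, latent_dim=32 columns

d = omega_distance(z_x, z_w, omega)
d_identity = omega_distance(z_x, z_w, np.eye(32))
```

With an identity Omega the result equals the plain squared Euclidean distance used by ImageGLVQ. A rectangular Omega with omega_dim < latent_dim measures distance in a lower-dimensional projection.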
- class prosemble.models.ImageGTLVQ(input_shape=(28, 28, 1), channels=None, kernel_sizes=None, latent_dim=32, subspace_dim=2, activation='relu', beta=10.0, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
Image GTLVQ — GTLVQ with a CNN embedding network.
Like ImageGLVQ but with per-prototype tangent subspace bases in latent space:
\[d = \|P_k(\text{CNN}(x) - \text{CNN}(w_k))\|^2\]
- Parameters:
input_shape (tuple) – Shape of input images as (height, width, channels).
channels (list of int) – CNN output channels per convolutional layer, e.g. [16, 32].
kernel_sizes (list of int) – Kernel sizes per convolutional layer, e.g. [3, 3].
latent_dim (int) – Dimension of the CNN embedding space.
subspace_dim (int) – Tangent subspace dimension per prototype. Each prototype gets a learned orthonormal basis of this rank in latent space.
activation (str) – Activation function for the CNN backbone. Supported values: ‘relu’, ‘sigmoid’, ‘tanh’, ‘leaky_relu’, ‘selu’.
beta (float) – Transfer function parameter for GLVQ loss.
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
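The statement that the effective batch size equals batch_size * gradient_accumulation_steps can be checked numerically on a toy problem. The sketch assumes accumulated gradients are averaged rather than summed. The least-squares objective and the helper mse_grad are chosen purely for illustration.

```python
import numpy as np

def mse_grad(w, X, y):
    """Gradient of 0.5 * mean((X @ w - y) ** 2) with respect to w."""
    return X.T @ (X @ w - y) / len(y)

rng = np.random.default_rng(0)
X = rng.normal(size=(32, 4))
y = rng.normal(size=32)
w = np.zeros(4)

batch_size, accumulation_steps = 8, 4
accumulated = np.zeros_like(w)
for i in range(accumulation_steps):
    rows = slice(i * batch_size, (i + 1) * batch_size)
    # Accumulate the mini-batch gradient without updating w.
    accumulated += mse_grad(w, X[rows], y[rows])
accumulated /= accumulation_steps   # assumption: average, not sum

# Gradient of a single batch of size batch_size * accumulation_steps.
full_batch = mse_grad(w, X, y)
```

For equally sized mini-batches the averaged accumulated gradient matches the gradient of one batch of 32 samples, while only 8 samples need to be processed at a time.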
- predict(X)[source]¶
Predict class labels via Winner-Takes-All Competition.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial prototypes differs from what n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict_proba(X)¶
Predict class probabilities via softmin of distances.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
proba
- Return type:
array of shape (n_samples, n_classes)
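The ImageGTLVQ distance can be illustrated on precomputed embeddings. The reference does not define \(P_k\). The sketch assumes the usual tangent-distance choice \(P_k = I - B_k B_k^\top\), where \(B_k\) is the orthonormal basis of prototype k, so that displacements inside the tangent subspace cost nothing. The helper tangent_distance and all array names are hypothetical.

```python
import numpy as np

def tangent_distance(z_x, z_w, bases):
    """Squared distance to each prototype's affine tangent subspace.

    z_x: (n, d) embeddings, z_w: (k, d) prototype embeddings,
    bases: (k, d, s) with orthonormal columns per prototype.
    """
    diff = z_x[:, None, :] - z_w[None, :, :]                 # (n, k, d)
    coeffs = np.einsum('nkd,kds->nks', diff, bases)          # coordinates in each subspace
    in_subspace = np.einsum('nks,kds->nkd', coeffs, bases)   # B_k B_k^T diff
    residual = diff - in_subspace                            # (I - B_k B_k^T) diff
    return (residual ** 2).sum(axis=-1)

rng = np.random.default_rng(0)
latent_dim, subspace_dim, n_protos = 32, 2, 3
z_w = rng.normal(size=(n_protos, latent_dim))
# The reduced QR factorization yields orthonormal columns for each basis.
bases = np.stack([
    np.linalg.qr(rng.normal(size=(latent_dim, subspace_dim)))[0]
    for _ in range(n_protos)
])
# A point displaced from prototype 0 only along its own tangent directions.
z_x = (z_w[0] + bases[0] @ np.array([1.5, -0.7]))[None, :]
d = tangent_distance(z_x, z_w, bases)
```

The distance to prototype 0 is numerically zero even though the point differs from the prototype, which is the invariance the tangent subspace is meant to provide.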
Classification-By-Components¶
- class prosemble.models.CBC(n_components=5, n_classes=2, sigma=1.0, margin=0.3, components_initializer=None, reasonings_initializer=None, similarity_fn=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
Classification-By-Components.
Components detect patterns in the input (via similarity), then reasoning matrices determine how each detection contributes evidence for/against each class.
- Parameters:
n_components (int) – Number of components (analogous to prototypes, but classless).
n_classes (int) – Number of output classes.
sigma (float) – Bandwidth for Gaussian similarity in component detection.
margin (float) – Margin for the margin loss.
components_initializer (callable, optional) – Initializer for component vectors. Signature: (X, key, n_components) -> components. Default: None (selects random training samples).
reasonings_initializer (callable, optional) – Initializer for the reasoning matrix. Signature: (n_components, n_classes, key) -> reasonings. Default: None (initializes near-uniform with small noise).
similarity_fn (callable, optional) – Similarity function for component detection. Default: None (uses Gaussian similarity).
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘components’] to freeze the components and only train reasonings. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial prototypes differs from what n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
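The detect-then-reason flow described for CBC can be sketched as follows. The exact formulas are assumptions: detection uses exp(-d² / (2σ²)), and the reasoning step follows the positive/negative evidence form from the original CBC formulation, which may differ from what prosemble implements. The names gaussian_detection, cbc_reasoning, pos, and neg are hypothetical.

```python
import numpy as np

def gaussian_detection(X, components, sigma=1.0):
    """Similarity in (0, 1] between every sample and every component."""
    sq = ((X[:, None, :] - components[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq / (2.0 * sigma ** 2))

def cbc_reasoning(detection, pos, neg):
    """Combine detections with positive/negative reasoning per class.

    pos[j, k]: component j must be present for class k.
    neg[j, k]: component j must be absent for class k.
    """
    evidence = detection @ pos + (1.0 - detection) @ neg
    return evidence / (pos + neg).sum(axis=0)

X = np.array([[0.0, 0.0], [3.0, 3.0]])
components = np.array([[0.0, 0.0], [3.0, 3.0]])
# Class 0 wants component 0 present and component 1 absent; class 1 the reverse.
pos = np.array([[1.0, 0.0], [0.0, 1.0]])
neg = np.array([[0.0, 1.0], [1.0, 0.0]])

detection = gaussian_detection(X, components, sigma=1.0)
probs = cbc_reasoning(detection, pos, neg)
pred = probs.argmax(axis=1)
```

In this toy setup each sample sits on one component, so it strongly detects that component, fails to detect the other, and is assigned the class whose reasoning matches that pattern.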
- class prosemble.models.ImageCBC(input_shape=(28, 28, 1), channels=None, kernel_sizes=None, latent_dim=32, n_components=5, n_classes=2, sigma=1.0, activation='relu', margin=0.3, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
Image CBC — CBC with a CNN embedding network.
Both input images and component images pass through the same CNN backbone. Detection similarity and reasoning matrices then produce class probabilities.
- Parameters:
input_shape (tuple) – Shape of input images as (height, width, channels).
channels (list of int) – CNN output channels per convolutional layer, e.g. [16, 32].
kernel_sizes (list of int) – Kernel sizes per convolutional layer, e.g. [3, 3].
latent_dim (int) – Dimension of the CNN embedding space.
n_components (int) – Number of components (classless prototypes).
n_classes (int) – Number of output classes.
sigma (float) – Bandwidth for Gaussian similarity in component detection.
activation (str) – Activation function for the CNN backbone. Supported values: ‘relu’, ‘sigmoid’, ‘tanh’, ‘leaky_relu’, ‘selu’.
margin (float) – Margin for the margin loss.
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train components. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
- predict(X)[source]¶
Predict class labels via Winner-Takes-All Competition.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- predict_proba(X)[source]¶
Predict class probabilities via softmin of distances.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
proba
- Return type:
array of shape (n_samples, n_classes)
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial prototypes differs from what n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
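The class_weight=‘balanced’ option is described as weighting classes inversely to their frequency. One common convention for that, assumed here and not confirmed for prosemble, is n_samples / (n_classes * count). The helper balanced_class_weight is hypothetical.

```python
import numpy as np

def balanced_class_weight(y):
    """Weights inversely proportional to class frequency (assumed convention)."""
    classes, counts = np.unique(y, return_counts=True)
    weights = len(y) / (len(classes) * counts)
    return {int(c): float(w) for c, w in zip(classes, weights)}

y = np.array([0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 2, 2])
class_weight = balanced_class_weight(y)

# Expanding to per-sample weights, as could be passed via sample_weight.
sample_weight = np.array([class_weight[int(label)] for label in y])
```

Under this convention every class contributes the same total weight, and the weights sum to the number of samples, so the overall loss scale is unchanged.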
Supervised Neural Gas¶
- class prosemble.models.SRNG(beta=10.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
Supervised Relevance Neural Gas.
Combines three key ideas:
GLVQ loss: \((d^+ - d^-) / (d^+ + d^-)\) for margin-based classification
Neural Gas cooperation: all same-class prototypes participate in the loss, weighted by rank via \(\exp(-\text{rank} / \gamma)\)
Relevance weighting: per-feature \(\lambda_j\) learned during training
The neighborhood range \(\gamma\) decays during training from \(\gamma_{\text{init}}\) to \(\gamma_{\text{final}}\). When \(\gamma \to 0\), SRNG recovers standard GRLVQ.
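The rank weighting and the decay of \(\gamma\) can be sketched numerically. The geometric schedule below is one way to make gamma reach gamma_final after max_iter steps with a constant multiplicative factor. The normalization of the weights and the helper names gamma_schedule and rank_weights are assumptions for illustration.

```python
import numpy as np

def gamma_schedule(gamma_init, gamma_final, max_iter):
    """Multiplicative decay chosen so gamma hits gamma_final after max_iter steps."""
    decay = (gamma_final / gamma_init) ** (1.0 / max_iter)
    return gamma_init * decay ** np.arange(max_iter + 1)

def rank_weights(same_class_distances, gamma):
    """Neural Gas weights exp(-rank / gamma), normalized to sum to one."""
    d = np.asarray(same_class_distances, dtype=float)
    ranks = np.argsort(np.argsort(d))   # rank 0 for the closest prototype
    w = np.exp(-ranks / gamma)
    return w / w.sum()                  # assumption: normalized weights

gammas = gamma_schedule(gamma_init=1.5, gamma_final=0.01, max_iter=100)
d_same = [0.9, 0.2, 0.5]                # distances to three same-class prototypes
early = rank_weights(d_same, gammas[0])
late = rank_weights(d_same, gammas[-1])
```

Early in training all same-class prototypes receive noticeable weight. At the end nearly all weight sits on the closest one, which is the winner-only behaviour the text describes for \(\gamma \to 0\).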
- Parameters:
beta (float) – Transfer function steepness parameter for sigmoid shaping.
gamma_init (float, optional) – Initial neighborhood range for NG cooperation. Default: half the maximum number of prototypes per class.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma. Default: computed from max_iter so gamma reaches gamma_final.
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
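The default for gamma_decay is described only as being computed from max_iter so that gamma reaches gamma_final. One formula consistent with that description is sketched below; the helper names and the exact expression are assumptions, not prosemble's code.

```python
def default_gamma_decay(gamma_init, gamma_final, max_iter):
    """Factor such that gamma_init * decay**max_iter == gamma_final."""
    return (gamma_final / gamma_init) ** (1.0 / max_iter)

def gamma_at(step, gamma_init, gamma_decay):
    """Neighborhood range after `step` multiplicative decay steps."""
    return gamma_init * gamma_decay ** step

gamma_init, gamma_final, max_iter = 2.0, 0.01, 100
decay = default_gamma_decay(gamma_init, gamma_final, max_iter)
schedule = [gamma_at(t, gamma_init, decay) for t in range(max_iter + 1)]
```

Passing an explicit gamma_decay closer to 1 than this value would leave gamma above gamma_final at the end of training under the same assumption.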
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when initial_prototypes have a different number than what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict(X)¶
Predict class labels via winner-takes-all competition: each sample receives the label of its closest prototype.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- predict_proba(X)¶
Predict class probabilities via softmin of distances.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
proba
- Return type:
array of shape (n_samples, n_classes)
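A softmin over distances can be sketched as a softmax over negated distances. The example below assumes that each class is represented by its closest prototype before the softmin is applied; that reduction step and the helper names are assumptions, not a description of prosemble's internals.

```python
import math

def softmin(values):
    """Softmax of the negated values, shifted for numerical stability."""
    lo = min(values)
    exps = [math.exp(-(v - lo)) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def class_probabilities(distances, proto_labels, classes):
    """Softmin over each class's closest-prototype distance."""
    per_class = [
        min(d for d, c in zip(distances, proto_labels) if c == k)
        for k in classes
    ]
    return softmin(per_class)

distances = [0.2, 1.5, 0.9, 3.0]   # distances from one sample to four prototypes
proto_labels = [0, 0, 1, 2]
proba = class_probabilities(distances, proto_labels, classes=[0, 1, 2])
predicted = max(range(len(proba)), key=proba.__getitem__)
```

The class with the smallest distance receives the largest probability, so the arg-max of `proba` agrees with the winner-takes-all label in this sketch.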
- class prosemble.models.SMNG(latent_dim=None, beta=10.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
Supervised Matrix Neural Gas.
Combines three key ideas:
GLVQ loss: \((d^+ - d^-) / (d^+ + d^-)\) for margin-based classification
Neural Gas cooperation: all same-class prototypes participate in the loss, weighted by rank via \(\exp(-\text{rank} / \gamma)\)
Global \(\Omega\) projection:
\[d(x, w) = \|\Omega(x - w)\|^2\]
learns feature correlations and a discriminative subspace
The neighborhood range \(\gamma\) decays during training from \(\gamma_{\text{init}}\) to \(\gamma_{\text{final}}\). When \(\gamma \to 0\), SMNG recovers standard GMLVQ.
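The \(\Omega\) distance can be evaluated directly from its definition. The sketch below assumes \(\Omega\) is stored as a list of rows with shape (latent_dim, n_features); the storage layout and the helper names are illustrative assumptions.

```python
def matvec(omega, v):
    """Multiply a matrix given as a list of rows by a vector."""
    return [sum(o * vi for o, vi in zip(row, v)) for row in omega]

def omega_distance(x, w, omega):
    """d(x, w) = ||Omega (x - w)||^2."""
    diff = [xi - wi for xi, wi in zip(x, w)]
    return sum(p * p for p in matvec(omega, diff))

w = [0.0, 0.0]

# latent_dim = 1, n_features = 2: the projection keeps only the first feature.
omega = [[1.0, 0.0]]
d_along_kept = omega_distance([3.0, 0.0], w, omega)
d_along_dropped = omega_distance([0.0, 3.0], w, omega)

# With the identity matrix the distance is the squared Euclidean distance.
identity = [[1.0, 0.0], [0.0, 1.0]]
d_euclid = omega_distance([3.0, 4.0], w, identity)
```

A latent_dim smaller than the input dimension therefore discards directions entirely, which is what makes the learned subspace discriminative.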
- Parameters:
latent_dim (int, optional) – Dimensionality of the \(\Omega\) projection space. If None, uses input dim.
beta (float) – Transfer function steepness parameter for sigmoid shaping.
gamma_init (float, optional) – Initial neighborhood range for NG cooperation. Default: half the maximum number of prototypes per class.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma. Default: computed from max_iter so gamma reaches gamma_final.
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
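For class_weight='balanced', weights inversely proportional to class frequencies can be computed as sketched below. The normalization n_samples / (n_classes * count) is the convention used by scikit-learn; whether prosemble applies the same scaling is an assumption, and `balanced_class_weights` is a hypothetical helper.

```python
from collections import Counter

def balanced_class_weights(y):
    """weight_c = n_samples / (n_classes * count_c)."""
    counts = Counter(y)
    n_samples, n_classes = len(y), len(counts)
    return {c: n_samples / (n_classes * n) for c, n in counts.items()}

y = [0, 0, 0, 0, 0, 0, 1, 1, 1, 2]   # class 0 is six times as frequent as class 2
weights = balanced_class_weights(y)

# Expanding to per-sample weights gives every class the same total weight.
sample_weight = [weights[label] for label in y]
```

Under this normalization the per-sample weights sum to n_samples, so the overall scale of the loss is unchanged while rare classes contribute as much as frequent ones.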
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when initial_prototypes have a different number than what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict_proba(X)¶
Predict class probabilities via softmin of distances.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
proba
- Return type:
array of shape (n_samples, n_classes)
- class prosemble.models.SLNG(latent_dim=None, beta=10.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
Supervised Localized Matrix Neural Gas.
Combines three key ideas:
GLVQ loss: \((d^+ - d^-) / (d^+ + d^-)\) for margin-based classification
Neural Gas cooperation: all same-class prototypes participate in the loss, weighted by rank via \(\exp(-\text{rank} / \gamma)\)
Per-prototype \(\Omega_k\):
\[d(x, w_k) = \|\Omega_k(x - w_k)\|^2\]
learns local metrics adapted to each prototype’s region
The neighborhood range \(\gamma\) decays during training from \(\gamma_{\text{init}}\) to \(\gamma_{\text{final}}\). When \(\gamma \to 0\), SLNG recovers standard LGMLVQ.
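The effect of per-prototype metrics can be illustrated with two prototypes that each carry their own \(\Omega_k\). The matrices below are hand-picked for illustration and the helper name is hypothetical; nothing here reflects learned values or prosemble's storage format.

```python
def local_omega_distance(x, w_k, omega_k):
    """d(x, w_k) = ||Omega_k (x - w_k)||^2 with a prototype-specific Omega_k."""
    diff = [xi - wi for xi, wi in zip(x, w_k)]
    proj = [sum(o * d for o, d in zip(row, diff)) for row in omega_k]
    return sum(p * p for p in proj)

protos = [[0.0, 0.0], [4.0, 0.0]]
omegas = [
    [[1.0, 0.0], [0.0, 1.0]],   # prototype 0: plain Euclidean metric
    [[0.1, 0.0], [0.0, 1.0]],   # prototype 1: first feature is nearly ignored
]
x = [1.5, 0.0]

local = [local_omega_distance(x, w, om) for w, om in zip(protos, omegas)]
euclid = [local_omega_distance(x, w, omegas[0]) for w in protos]

winner_local = local.index(min(local))
winner_euclid = euclid.index(min(euclid))
```

The sample is closer to prototype 0 in Euclidean terms but is assigned to prototype 1 under the local metrics, because prototype 1 tolerates large deviations along the first feature.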
- Parameters:
latent_dim (int, optional) – Latent space dimensionality per prototype. If None, uses input dim.
beta (float) – Transfer function steepness parameter for sigmoid shaping.
gamma_init (float, optional) – Initial neighborhood range for NG cooperation. Default: half the maximum number of prototypes per class.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma. Default: computed from max_iter so gamma reaches gamma_final.
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
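The smoothing performed by ema_decay can be sketched as the usual exponential moving average over parameter iterates. Initializing the average at the first iterate and omitting bias correction are assumptions of this sketch; `ema_update` is a hypothetical helper.

```python
def ema_update(ema, params, decay):
    """ema <- decay * ema + (1 - decay) * params, element-wise."""
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema, params)]

decay = 0.9
# Parameters that oscillate between two points during training.
trajectory = [[1.0, -1.0], [0.0, 1.0], [1.0, -1.0], [0.0, 1.0]]

ema = list(trajectory[0])            # start the average at the first iterate
for params in trajectory[1:]:
    ema = ema_update(ema, params, decay)
```

The averaged parameters move far less than the raw iterates, which is why values close to 1 such as 0.999 are typical for long training runs.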
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when initial_prototypes have a different number than what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict_proba(X)¶
Predict class probabilities via softmin of distances.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
proba
- Return type:
array of shape (n_samples, n_classes)
- class prosemble.models.STNG(subspace_dim=2, beta=10.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
Supervised Tangent Neural Gas.
Combines three key ideas:
GLVQ loss: \((d^+ - d^-) / (d^+ + d^-)\) for margin-based classification
Neural Gas cooperation: all same-class prototypes participate in the loss, weighted by rank via \(\exp(-\text{rank} / \gamma)\)
Tangent subspaces:
\[d(x, w_k) = \|(I - \Omega_k \Omega_k^T)(x - w_k)\|^2\]
measures distance in the orthogonal complement of each prototype’s learned invariance subspace
The neighborhood range \(\gamma\) decays during training from \(\gamma_{\text{init}}\) to \(\gamma_{\text{final}}\). When \(\gamma \to 0\), STNG recovers standard GTLVQ.
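The tangent distance removes the component of \(x - w_k\) that lies inside the subspace spanned by the columns of \(\Omega_k\). The sketch below assumes those columns are orthonormal and passes them as a list of basis vectors; both the layout and the helper name are illustrative assumptions.

```python
def tangent_distance(x, w, omega_columns):
    """||(I - Omega Omega^T)(x - w)||^2 for orthonormal columns of Omega."""
    diff = [xi - wi for xi, wi in zip(x, w)]
    residual = list(diff)
    for basis in omega_columns:
        # Coefficient of diff along this basis vector, then subtract it.
        coeff = sum(b * d for b, d in zip(basis, diff))
        residual = [r - coeff * b for r, b in zip(residual, basis)]
    return sum(r * r for r in residual)

w = [0.0, 0.0, 0.0]
omega_columns = [[1.0, 0.0, 0.0]]    # subspace_dim = 1: invariance along axis 0

d_inside = tangent_distance([5.0, 0.0, 0.0], w, omega_columns)
d_mixed = tangent_distance([5.0, 3.0, 4.0], w, omega_columns)
```

Movement along the invariance direction costs nothing, so only the components orthogonal to the subspace contribute to the distance.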
- Parameters:
subspace_dim (int) – Dimension of each prototype’s tangent subspace.
beta (float) – Transfer function steepness parameter for sigmoid shaping.
gamma_init (float, optional) – Initial neighborhood range for NG cooperation. Default: half the maximum number of prototypes per class.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma. Default: computed from max_iter so gamma reaches gamma_final.
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
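The roles of sync_period and slow_step_size can be seen in a small lookahead loop wrapped around plain gradient descent. This is a generic sketch of the lookahead idea with a hypothetical `lookahead_sgd` helper; it does not reproduce how prosemble wires the wrapper into its optimizers.

```python
def lookahead_sgd(grad_fn, params, lr, n_steps, sync_period=6, slow_step_size=0.5):
    """Gradient descent with a lookahead-style slow/fast weight update."""
    slow = list(params)
    fast = list(params)
    for step in range(1, n_steps + 1):
        grads = grad_fn(fast)
        fast = [p - lr * g for p, g in zip(fast, grads)]
        if step % sync_period == 0:
            # Move the slow weights part of the way towards the fast weights,
            # then restart the fast weights from the new slow weights.
            slow = [s + slow_step_size * (f - s) for s, f in zip(slow, fast)]
            fast = list(slow)
    return slow, fast

# Minimise f(p) = sum(p_i ** 2), whose gradient is 2 * p.
slow, fast = lookahead_sgd(
    lambda p: [2.0 * pi for pi in p], [1.0, -2.0], lr=0.1, n_steps=12
)
```

With slow_step_size=0.5 the slow weights land halfway between their previous value and the fast weights at every synchronization, which damps the fast optimizer's trajectory.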
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when initial_prototypes have a different number than what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict_proba(X)¶
Predict class probabilities via softmin of distances.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
proba
- Return type:
array of shape (n_samples, n_classes)
Cross-Entropy Neural Gas¶
- class prosemble.models.CELVQ_NG(gamma_init=None, gamma_final=0.01, gamma_decay=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
Cross-Entropy LVQ with Neural Gas neighborhood cooperation.
For each class, prototypes are ranked by distance and weighted by \(\exp(-\text{rank} / \gamma)\). The NG-weighted class distances become logits for cross-entropy loss over all classes simultaneously.
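The description above can be sketched for one sample as follows. Using the negated, normalized rank-weighted distance as the logit is an assumption about sign and scaling, and the helper names are hypothetical; the sketch is not prosemble's implementation.

```python
import math

def ng_class_distance(dists, gamma):
    """Rank-weighted average of one class's prototype distances."""
    ranked = sorted(dists)                      # rank 0 = closest prototype
    h = [math.exp(-rank / gamma) for rank in range(len(ranked))]
    return sum(hk * dk for hk, dk in zip(h, ranked)) / sum(h)

def celvq_ng_loss(dists_by_class, y, gamma):
    """Cross-entropy with logits = -(NG-weighted class distance)."""
    logits = [-ng_class_distance(d, gamma) for d in dists_by_class]
    top = max(logits)
    log_norm = top + math.log(sum(math.exp(l - top) for l in logits))
    return log_norm - logits[y]

# Distances from one sample to the two prototypes of each of three classes.
dists_by_class = [[0.3, 2.0], [1.0, 1.2], [4.0, 5.0]]
loss_correct = celvq_ng_loss(dists_by_class, y=0, gamma=0.5)
loss_wrong = celvq_ng_loss(dists_by_class, y=2, gamma=0.5)
nearest_only = ng_class_distance([0.3, 2.0], gamma=1e-3)
```

The loss is small when the true class has the smallest weighted distance and grows when it is far away; for a very small `gamma` the class distance collapses to the closest prototype's distance.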
- Parameters:
gamma_init (float, optional) – Initial neighborhood range for NG cooperation. Default: half the maximum number of prototypes per class.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma. Default: computed from max_iter so gamma reaches gamma_final.
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
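The interaction between patience and epsilon can be sketched as a counter over a loss history. Treating a decrease larger than epsilon as an improvement is an assumption about the exact comparison, and `early_stopping_epoch` is a hypothetical helper.

```python
def early_stopping_epoch(losses, patience, epsilon=1e-6):
    """Return the index of the epoch at which training would stop."""
    best = float("inf")
    stale = 0
    for epoch, loss in enumerate(losses):
        if best - loss > epsilon:       # improvement: reset the counter
            best = loss
            stale = 0
        else:                           # no improvement this epoch
            stale += 1
            if stale >= patience:
                return epoch
    return len(losses) - 1              # ran through all epochs

losses = [1.0, 0.8, 0.81, 0.79, 0.79, 0.79, 0.79, 0.5]
stop_patient = early_stopping_epoch(losses, patience=3)
stop_eager = early_stopping_epoch(losses, patience=1)
```

A patience of 1 stops at the first non-improving epoch, matching the default behaviour described for patience=None, while a larger value tolerates short plateaus.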
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when initial_prototypes have a different number than what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict(X)¶
Predict class labels via winner-takes-all competition: each sample receives the label of its closest prototype.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- predict_proba(X)¶
Predict class probabilities via softmin of distances.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
proba
- Return type:
array of shape (n_samples, n_classes)
- class prosemble.models.MCELVQ_NG(latent_dim=None, gamma_init=None, gamma_final=0.01, gamma_decay=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
Matrix Cross-Entropy LVQ with Neural Gas neighborhood cooperation.
Combines three key ideas:
Cross-entropy loss: softmax over all-class NG-weighted distances
Neural Gas cooperation: all same-class prototypes participate, weighted by rank via \(\exp(-\text{rank} / \gamma)\)
Global \(\Omega\) projection:
\[d(x, w) = \|\Omega(x - w)\|^2\]learns feature correlations and a discriminative subspace
The neighborhood range \(\gamma\) decays during training from \(\gamma_{\text{init}}\) to \(\gamma_{\text{final}}\). When \(\gamma \to 0\), MCELVQ-NG recovers a matrix CELVQ.
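The two formulas above can be illustrated with a small pure-Python sketch. It computes \(d(x, w) = \|\Omega(x - w)\|^2\) for the prototypes of one class and turns the distance ranks into weights \(\exp(-\text{rank} / \gamma)\). Normalizing the weights to sum to one is an assumption made here for readability, and `omega_distance` and `ng_weights` are hypothetical names, not prosemble functions.

```python
import math


def omega_distance(omega, x, w):
    """d(x, w) = ||Omega (x - w)||^2, with Omega given as a list of rows."""
    diff = [xi - wi for xi, wi in zip(x, w)]
    projected = [sum(o * d for o, d in zip(row, diff)) for row in omega]
    return sum(p * p for p in projected)


def ng_weights(distances, gamma):
    """Weights exp(-rank / gamma), normalized; rank 0 is the closest prototype."""
    order = sorted(range(len(distances)), key=distances.__getitem__)
    ranks = [0] * len(distances)
    for rank, idx in enumerate(order):
        ranks[idx] = rank
    raw = [math.exp(-r / gamma) for r in ranks]
    total = sum(raw)
    return [r / total for r in raw]


omega = [[1.0, 0.0], [0.0, 2.0]]  # second feature counts double after projection
x = [1.0, 1.0]
same_class_prototypes = [[0.0, 0.0], [1.0, 0.5], [3.0, 3.0]]
distances = [omega_distance(omega, x, w) for w in same_class_prototypes]

wide = ng_weights(distances, gamma=1.5)     # all prototypes cooperate
narrow = ng_weights(distances, gamma=0.01)  # effectively only the winner counts
```

With a small \(\gamma\) nearly all weight falls on the closest prototype, which is the limit in which the text says the model reduces to a matrix CELVQ.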
- Parameters:
latent_dim (int, optional) – Dimensionality of the \(\Omega\) projection space. If None, uses input dim.
gamma_init (float, optional) – Initial neighborhood range for NG cooperation. Default: max prototypes per class / 2.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma. Default: computed from max_iter so gamma reaches gamma_final.
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
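As an illustration of the gamma_decay default, the sketch below derives a per-step multiplicative factor so that gamma reaches gamma_final after max_iter steps. The closed form \((\gamma_{\text{final}} / \gamma_{\text{init}})^{1 / \text{max\_iter}}\) is an assumption that matches the description in the parameter list; the library's exact formula is not shown in this reference, and `default_gamma_decay` is a hypothetical name.

```python
def default_gamma_decay(gamma_init, gamma_final, max_iter):
    """Per-step factor such that gamma_init * decay**max_iter == gamma_final."""
    return (gamma_final / gamma_init) ** (1.0 / max_iter)


# With 4 prototypes per class, the documented default gamma_init is 4 / 2 = 2.0.
gamma_init, gamma_final, max_iter = 2.0, 0.01, 100
decay = default_gamma_decay(gamma_init, gamma_final, max_iter)

gamma = gamma_init
schedule = []
for _ in range(max_iter):
    gamma *= decay  # multiplicative decay applied once per training step
    schedule.append(gamma)
```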
- predict_proba(X)[source]¶
Predict calibrated probabilities using \(\Omega\)-transformed distances.
Uses NG-weighted pooling with the learned \(\Omega\) metric, matching the training objective exactly.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
proba
- Return type:
array of shape (n_samples, n_classes)
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- class prosemble.models.LCELVQ_NG(latent_dim=None, gamma_init=None, gamma_final=0.01, gamma_decay=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
Localized Matrix Cross-Entropy LVQ with Neural Gas cooperation.
Combines three key ideas:
Cross-entropy loss: softmax over all-class NG-weighted distances
Neural Gas cooperation: all same-class prototypes participate, weighted by rank via \(\exp(-\text{rank} / \gamma)\)
Per-prototype \(\Omega_k\):
\[d(x, w_k) = \|\Omega_k(x - w_k)\|^2\]learns local metrics adapted to each prototype’s region
The neighborhood range \(\gamma\) decays during training from \(\gamma_{\text{init}}\) to \(\gamma_{\text{final}}\). When \(\gamma \to 0\), LCELVQ-NG recovers a localized matrix CELVQ.
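To make the per-prototype metric concrete, here is a small sketch that evaluates \(d(x, w_k) = \|\Omega_k(x - w_k)\|^2\) with a separate matrix for each prototype. The function name `local_omega_distance` and the example matrices are hypothetical and chosen only for illustration.

```python
def local_omega_distance(omegas, prototypes, x):
    """Return d(x, w_k) = ||Omega_k (x - w_k)||^2 for every prototype k."""
    distances = []
    for omega_k, w_k in zip(omegas, prototypes):
        diff = [xi - wi for xi, wi in zip(x, w_k)]
        projected = [sum(o * d for o, d in zip(row, diff)) for row in omega_k]
        distances.append(sum(p * p for p in projected))
    return distances


prototypes = [[0.0, 0.0], [2.0, 0.0]]
omegas = [
    [[1.0, 0.0], [0.0, 1.0]],  # prototype 0: plain Euclidean metric
    [[0.1, 0.0], [0.0, 1.0]],  # prototype 1: first feature nearly ignored
]
x = [1.0, 0.0]  # equidistant from both prototypes in Euclidean terms
distances = local_omega_distance(omegas, prototypes, x)
```

Although `x` lies halfway between the two prototypes, the local metric of prototype 1 shrinks its distance, so prototype 1 wins. This is the sense in which each \(\Omega_k\) adapts to its own region.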
- Parameters:
latent_dim (int, optional) – Latent space dimensionality per prototype. If None, uses input dim.
gamma_init (float, optional) – Initial neighborhood range for NG cooperation. Default: max prototypes per class / 2.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma. Default: computed from max_iter so gamma reaches gamma_final.
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
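The ‘balanced’ option of class_weight is described as inversely proportional to class frequencies. A sketch of one common way to compute such weights follows. The normalization n_samples / (n_classes * count) is the convention used by scikit-learn and is an assumption here; this reference does not state which normalization prosemble applies, and `balanced_class_weights` is a hypothetical helper.

```python
from collections import Counter


def balanced_class_weights(y):
    """Weights inversely proportional to class frequency (assumed normalization)."""
    counts = Counter(y)
    n_samples, n_classes = len(y), len(counts)
    return {label: n_samples / (n_classes * count) for label, count in counts.items()}


y = [0, 0, 0, 0, 0, 0, 1, 1, 1, 2]  # class 0 is six times as frequent as class 2
weights = balanced_class_weights(y)
```

The resulting dict has the same form as a hand-written class_weight such as {0: 1.0, 1: 2.0, 2: 1.5}, with rare classes receiving larger weights.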
- predict_proba(X)[source]¶
Predict calibrated probabilities using per-prototype \(\Omega_k\) distances.
Uses per-class min pooling with the learned local \(\Omega_k\) metrics, consistent with the training objective.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
proba
- Return type:
array of shape (n_samples, n_classes)
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- class prosemble.models.TCELVQ_NG(subspace_dim=2, gamma_init=None, gamma_final=0.01, gamma_decay=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
Tangent Cross-Entropy LVQ with Neural Gas neighborhood cooperation.
Combines three key ideas:
Cross-entropy loss: softmax over all-class NG-weighted distances
Neural Gas cooperation: all same-class prototypes participate, weighted by rank via \(\exp(-\text{rank} / \gamma)\)
Tangent subspaces:
\[d(x, w_k) = \|(I - \Omega_k \Omega_k^T)(x - w_k)\|^2\]measures distance in the orthogonal complement of each prototype’s learned invariance subspace
The neighborhood range \(\gamma\) decays during training from \(\gamma_{\text{init}}\) to \(\gamma_{\text{final}}\). When \(\gamma \to 0\), TCELVQ-NG recovers a tangent CELVQ.
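The tangent distance above removes the component of \(x - w_k\) that lies inside the prototype's subspace and measures what remains. The sketch below assumes that the columns of \(\Omega_k\) are orthonormal, so that \(\Omega_k \Omega_k^T\) is an orthogonal projector; how prosemble enforces or parameterizes this is not stated here. `tangent_distance` is a hypothetical name.

```python
def tangent_distance(omega_k, x, w_k):
    """||(I - Omega_k Omega_k^T)(x - w_k)||^2, assuming orthonormal columns."""
    diff = [xi - wi for xi, wi in zip(x, w_k)]
    n_rows, n_cols = len(omega_k), len(omega_k[0])
    # Coordinates of diff in the tangent basis: Omega_k^T diff
    coeffs = [sum(omega_k[i][j] * diff[i] for i in range(n_rows)) for j in range(n_cols)]
    # Component of diff inside the subspace: Omega_k (Omega_k^T diff)
    inside = [sum(omega_k[i][j] * coeffs[j] for j in range(n_cols)) for i in range(n_rows)]
    residual = [d - p for d, p in zip(diff, inside)]
    return sum(r * r for r in residual)


# One tangent direction along the first axis of a 3-D input space.
omega_k = [[1.0], [0.0], [0.0]]
w_k = [0.0, 0.0, 0.0]
along = tangent_distance(omega_k, [5.0, 0.0, 0.0], w_k)   # 0.0
across = tangent_distance(omega_k, [5.0, 3.0, 4.0], w_k)  # 25.0
```

Moving along the tangent direction costs nothing, which is why the subspace is described as an invariance subspace.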
- Parameters:
subspace_dim (int) – Dimension of each prototype’s tangent subspace.
gamma_init (float, optional) – Initial neighborhood range for NG cooperation. Default: max prototypes per class / 2.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma. Default: computed from max_iter so gamma reaches gamma_final.
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
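The effect of ema_decay can be sketched with the standard exponential moving average update. Initializing the average from the first parameter value and omitting bias correction are assumptions made for this illustration; `ema_update` is a hypothetical helper, not a prosemble function.

```python
def ema_update(ema_params, params, decay):
    """One EMA step: ema <- decay * ema + (1 - decay) * params."""
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema_params, params)]


# A single parameter that oscillates between 1 and 0 during training.
trajectory = [[1.0], [0.0], [1.0], [0.0], [1.0], [0.0]]
ema = list(trajectory[0])
for params in trajectory[1:]:
    ema = ema_update(ema, params, decay=0.9)
```

The averaged value changes far less per step than the raw parameter, which is the smoothing the option is meant to provide. The decay of 0.9 is used only to keep the example short; the parameter list cites 0.999 and 0.9999 as typical values.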
- predict_proba(X)[source]¶
Predict calibrated probabilities using tangent distances.
Uses per-class min pooling with tangent distance, consistent with the training objective.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
proba
- Return type:
array of shape (n_samples, n_classes)
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
One-Class GLVQ¶
- class prosemble.models.OCGLVQ(n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
One-Class Generalized Learning Vector Quantization.
Combines GLVQ’s \(\mu\)-based hypothesis testing with per-prototype visibility thresholds \(\theta_k\) for one-class classification.
- Parameters:
n_prototypes (int) – Number of prototypes for the target class.
target_label (int, optional) – Which label is the target (normal) class. Default: auto-detect as the most frequent class.
beta (float) – Sigmoid steepness for the transfer function. Default: 10.0.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
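The two lookahead keys can be illustrated on a one-dimensional quadratic. In this sketch the fast weights take plain gradient steps; every sync_period steps the slow weights move a fraction slow_step_size toward the fast weights, and the fast weights are reset to the slow ones. Plain gradient descent as the inner optimizer is an illustrative choice, and `lookahead_sgd` is a hypothetical function, not the wrapper prosemble uses internally.

```python
def lookahead_sgd(grad_fn, w0, lr, steps, sync_period=6, slow_step_size=0.5):
    """Lookahead around plain gradient descent on a scalar parameter."""
    slow = w0
    fast = w0
    for step in range(1, steps + 1):
        fast = fast - lr * grad_fn(fast)  # inner (fast) update
        if step % sync_period == 0:
            # Interpolate the slow weights toward the fast ones, then reset.
            slow = slow + slow_step_size * (fast - slow)
            fast = slow
    return slow, fast


def grad(w):
    """Gradient of the toy loss (w - 3)^2."""
    return 2.0 * (w - 3.0)


slow, fast = lookahead_sgd(grad, w0=0.0, lr=0.1, steps=60)
```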
- thetas_¶
Learned per-prototype visibility thresholds.
- Type:
array of shape (n_prototypes,)
- decision_function(X)[source]¶
Compute target-likeness scores.
Scores near 1.0 indicate target class, near 0.0 indicate outlier. The decision boundary is at score = 0.5 (where \(d = \theta\)).
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
scores
- Return type:
array of shape (n_samples,)
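One scoring rule consistent with this description is a sigmoid applied to a GLVQ-style relative difference between the distance \(d\) to a prototype and its threshold \(\theta\). The exact expression used by prosemble is not given in this reference, so the formula below is an assumption chosen to satisfy the stated properties (a score of 0.5 at \(d = \theta\), near 1 for targets, near 0 for outliers). `target_score` is a hypothetical name.

```python
import math


def target_score(d, theta, beta=10.0):
    """Assumed score: sigmoid(-beta * mu) with mu = (d - theta) / (d + theta)."""
    mu = (d - theta) / (d + theta)
    return 1.0 / (1.0 + math.exp(beta * mu))


on_boundary = target_score(d=2.0, theta=2.0)  # exactly 0.5
inside = target_score(d=0.5, theta=2.0)       # close to 1.0 (target-like)
outside = target_score(d=8.0, theta=2.0)      # close to 0.0 (outlier-like)
```

A larger beta makes the transition around the boundary sharper.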
- predict(X)[source]¶
Predict target or non-target labels.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)[source]¶
Predict with a reject option for uncertain samples.
Samples with scores in [lower, upper) are rejected.
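- Parameters:
X (array-like of shape (n_samples, n_features))
upper (float, default=0.5) – Upper bound of the reject interval.
lower (float, optional) – Lower bound of the reject interval. Default: None.
reject_label (int, default=-1) – Label assigned to rejected samples.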
- Returns:
labels
- Return type:
array of shape (n_samples,)
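The reject rule can be sketched directly on a list of scores. Scores in [lower, upper) receive reject_label, scores at or above upper are treated as target, and scores below lower as non-target. The label values 1 and 0 and the explicit lower bound are assumptions for this example; how the library handles lower=None is not documented here. `reject_rule` is a hypothetical helper.

```python
def reject_rule(scores, upper=0.5, lower=0.3, reject_label=-1,
                target_label=1, other_label=0):
    """Label scores, rejecting those in the half-open interval [lower, upper)."""
    labels = []
    for s in scores:
        if lower <= s < upper:
            labels.append(reject_label)   # too uncertain to decide
        elif s >= upper:
            labels.append(target_label)
        else:
            labels.append(other_label)
    return labels


scores = [0.9, 0.5, 0.4, 0.3, 0.1]
labels = reject_rule(scores)  # [1, 1, -1, -1, 0]
```

Because the interval is half-open, a score equal to upper is accepted as target while a score equal to lower is rejected.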
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- class prosemble.models.OCGRLVQ(n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
One-Class GRLVQ with per-feature relevance weighting.
Learns which features are most important for distinguishing target from non-target data in a one-class setting.
- Parameters:
n_prototypes (int) – Number of prototypes for the target class.
target_label (int, optional) – Target (normal) class label. Default: auto-detect.
beta (float) – Sigmoid steepness. Default: 10.0.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
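The interaction between epsilon and patience can be sketched as a loop over a recorded loss history. The improvement criterion used here (the loss must drop by more than epsilon below the best value so far) is an assumption; the reference only says that epsilon is a threshold on loss change and that patience counts consecutive non-improving epochs. `stopping_epoch` is a hypothetical helper.

```python
def stopping_epoch(losses, epsilon=1e-6, patience=None):
    """Index of the epoch at which training would stop under the assumed rule."""
    limit = 1 if patience is None else patience  # None: stop at first stale epoch
    best = float("inf")
    stale = 0
    for epoch, loss in enumerate(losses):
        if best - loss > epsilon:
            best = loss
            stale = 0
        else:
            stale += 1
            if stale >= limit:
                return epoch
    return len(losses) - 1  # never triggered: ran to the end


losses = [1.0, 0.8, 0.8, 0.7, 0.7, 0.7, 0.7, 0.6]
without_patience = stopping_epoch(losses)          # 2
with_patience = stopping_epoch(losses, patience=3)  # 6
```

With patience set, the short plateau at epoch 2 is tolerated and training continues until three consecutive epochs bring no improvement. As the parameter list notes, this kind of data-dependent exit requires use_scan=False.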
- relevances_¶
Learned per-feature relevance weights (softmax-normalized).
- Type:
array of shape (n_features,)
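A sketch of how softmax-normalized relevances can enter a distance follows. The relevance-weighted squared distance \(\sum_j \lambda_j (x_j - w_j)^2\) is the usual GRLVQ form and is assumed here; the raw values and the helper names `softmax` and `relevance_distance` are hypothetical.

```python
import math


def softmax(raw):
    """Map unconstrained values to positive weights that sum to one."""
    shift = max(raw)  # subtract the maximum for numerical stability
    exps = [math.exp(r - shift) for r in raw]
    total = sum(exps)
    return [e / total for e in exps]


def relevance_distance(x, w, relevances):
    """Relevance-weighted squared Euclidean distance."""
    return sum(lam * (xi - wi) ** 2 for lam, xi, wi in zip(relevances, x, w))


raw_relevances = [2.0, 0.0, -2.0]  # hypothetical unnormalized values
relevances = softmax(raw_relevances)

origin = [0.0, 0.0, 0.0]
d_first = relevance_distance([1.0, 0.0, 0.0], origin, relevances)
d_last = relevance_distance([0.0, 0.0, 1.0], origin, relevances)
```

A unit shift along the first feature costs far more than the same shift along the last one, so features with large relevance dominate the decision.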
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict(X)¶
Predict target or non-target labels.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)¶
Predict with a reject option for uncertain samples.
Samples with scores in [lower, upper) are rejected.
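- Parameters:
X (array-like of shape (n_samples, n_features))
upper (float, default=0.5) – Upper bound of the reject interval.
lower (float, optional) – Lower bound of the reject interval. Default: None.
reject_label (int, default=-1) – Label assigned to rejected samples.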
- Returns:
labels
- Return type:
array of shape (n_samples,)
- class prosemble.models.OCGMLVQ(latent_dim=None, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
One-Class GMLVQ with global Omega projection.
Learns a global linear projection \(\Omega\) that captures feature correlations for one-class classification.
- Parameters:
latent_dim (int, optional) – Dimensionality of the projected space. Default: n_features.
n_prototypes (int) – Number of prototypes for the target class.
target_label (int, optional) – Target (normal) class label. Default: auto-detect.
beta (float) – Sigmoid steepness. Default: 10.0.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
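As an illustration of class_weight='balanced', the sketch below computes weights inversely proportional to class frequencies. It assumes the common n_samples / (n_classes * count) normalization; whether prosemble uses exactly this normalization is an assumption, and balanced_class_weights is a hypothetical helper.

```python
from collections import Counter


def balanced_class_weights(y):
    """Return {label: weight} with weight inversely proportional to frequency.

    Assumed normalization: n_samples / (n_classes * count_of_label).
    """
    counts = Counter(y)
    n_samples = len(y)
    n_classes = len(counts)
    return {label: n_samples / (n_classes * count)
            for label, count in counts.items()}


# Six samples of class 0 and two of class 1.
y = [0, 0, 0, 0, 0, 0, 1, 1]
weights = balanced_class_weights(y)
```

With this normalization every class contributes the same total weight (count times weight), so the minority class is not drowned out.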
- omega_¶
Learned projection matrix.
- Type:
array of shape (n_features, latent_dim)
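A minimal pure-Python sketch of how a projection of this shape can define a distance, assuming the standard GMLVQ form \(d(x, w) = \lVert \Omega^\top (x - w) \rVert^2\). The helper omega_distance is hypothetical; the exact distance used by the library is not confirmed here.

```python
def omega_distance(x, w, omega):
    """Squared Euclidean distance after projecting (x - w) with omega.

    omega is a nested list of shape (n_features, latent_dim).
    Assumed form: ||omega^T (x - w)||^2.
    """
    diff = [xi - wi for xi, wi in zip(x, w)]
    latent_dim = len(omega[0])
    projected = [
        sum(diff[i] * omega[i][j] for i in range(len(diff)))
        for j in range(latent_dim)
    ]
    return sum(p * p for p in projected)


# Three features projected to one latent dimension that ignores feature 2.
omega = [[1.0], [1.0], [0.0]]
d_same = omega_distance([1.0, 2.0, 9.0], [1.0, 2.0, -4.0], omega)
d_diff = omega_distance([2.0, 2.0, 0.0], [0.0, 1.0, 0.0], omega)
```

Differences along a feature with a zero row in omega do not contribute to the distance, which is how a learned projection can suppress uninformative features.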
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict(X)¶
Predict target or non-target labels.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)¶
Predict with a reject option for uncertain samples.
Samples with scores in [lower, upper) are rejected.
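- Parameters:
X (array-like of shape (n_samples, n_features)) – Input samples.
upper (float, default=0.5) – Upper score bound (exclusive) of the rejection band.
lower (float, optional) – Lower score bound (inclusive) of the rejection band. Default: None.
reject_label (int, default=-1) – Label assigned to rejected samples.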
- Returns:
labels
- Return type:
array of shape (n_samples,)
- class prosemble.models.OCLGMLVQ(latent_dim=None, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
One-Class LGMLVQ with per-prototype Omega projections.
Each prototype learns its own local metric via \(\Omega_k\), allowing different prototypes to attend to different feature subspaces.
- Parameters:
latent_dim (int, optional) – Dimensionality of each projected space. Default: n_features.
n_prototypes (int) – Number of prototypes for the target class.
target_label (int, optional) – Target (normal) class label. Default: auto-detect.
beta (float) – Sigmoid steepness. Default: 10.0.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
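The effect of ema_decay can be sketched on a single scalar parameter. The update rule below is the usual exponential moving average without bias correction; that the library applies exactly this rule is an assumption, and ema_update is a hypothetical helper.

```python
def ema_update(ema, param, decay):
    """One exponential-moving-average step: keep `decay` of the old value."""
    return decay * ema + (1.0 - decay) * param


# A noisy scalar parameter that oscillates by +/- 0.5 around 1.0.
trajectory = [1.0 + (0.5 if step % 2 == 0 else -0.5) for step in range(200)]

ema = trajectory[0]
for value in trajectory[1:]:
    ema = ema_update(ema, value, decay=0.9)
```

The raw parameter ends half a unit away from 1.0, while the smoothed value stays close to it; larger decay values smooth more but react more slowly.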
- omegas_¶
Learned per-prototype projection matrices.
- Type:
array of shape (n_prototypes, n_features, latent_dim)
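A small sketch of per-prototype projections, assuming each prototype k measures \(\lVert \Omega_k^\top (x - w_k) \rVert^2\) with its own matrix. The helper local_omega_distances is hypothetical and only illustrates how the array shape above would be consumed.

```python
def local_omega_distances(x, prototypes, omegas):
    """Distance from x to each prototype k under its own projection omegas[k].

    omegas has shape (n_prototypes, n_features, latent_dim).
    """
    distances = []
    for w, omega in zip(prototypes, omegas):
        diff = [xi - wi for xi, wi in zip(x, w)]
        latent_dim = len(omega[0])
        projected = [
            sum(diff[i] * omega[i][j] for i in range(len(diff)))
            for j in range(latent_dim)
        ]
        distances.append(sum(p * p for p in projected))
    return distances


prototypes = [[0.0, 0.0], [0.0, 0.0]]
omegas = [
    [[1.0], [0.0]],  # prototype 0 only measures feature 0
    [[0.0], [1.0]],  # prototype 1 only measures feature 1
]
distances = local_omega_distances([3.0, 4.0], prototypes, omegas)
```

Both prototypes sit at the same position, yet they report different distances because each attends to a different feature subspace.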
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict(X)¶
Predict target or non-target labels.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)¶
Predict with a reject option for uncertain samples.
Samples with scores in [lower, upper) are rejected.
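- Parameters:
X (array-like of shape (n_samples, n_features)) – Input samples.
upper (float, default=0.5) – Upper score bound (exclusive) of the rejection band.
lower (float, optional) – Lower score bound (inclusive) of the rejection band. Default: None.
reject_label (int, default=-1) – Label assigned to rejected samples.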
- Returns:
labels
- Return type:
array of shape (n_samples,)
- class prosemble.models.OCGTLVQ(subspace_dim=2, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
One-Class GTLVQ with per-prototype tangent subspaces.
Each prototype learns an orthonormal basis \(\Omega_k\) that defines directions of local invariance. Only the distance orthogonal to this subspace is used for classification.
- Parameters:
subspace_dim (int) – Dimensionality of each tangent subspace. Default: 2.
n_prototypes (int) – Number of prototypes for the target class.
target_label (int, optional) – Target (normal) class label. Default: auto-detect.
beta (float) – Sigmoid steepness. Default: 10.0.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
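The roles of sync_period and slow_step_size can be sketched with a toy lookahead loop around plain gradient descent on \(f(w) = w^2\). This is an illustration of the general lookahead scheme, not the library's optimizer wrapper; lookahead_minimize and its arguments are hypothetical.

```python
def lookahead_minimize(grad_fn, w0, lr=0.1, sync_period=6,
                       slow_step_size=0.5, n_steps=60):
    """Toy lookahead: fast weights step every iteration, slow weights
    move toward them every `sync_period` steps."""
    slow = w0
    fast = w0
    for step in range(1, n_steps + 1):
        fast = fast - lr * grad_fn(fast)  # inner (fast) optimizer step
        if step % sync_period == 0:
            # Interpolate slow weights toward the fast weights ...
            slow = slow + slow_step_size * (fast - slow)
            # ... and restart the fast weights from the slow weights.
            fast = slow
    return slow


# Gradient of f(w) = w^2 is 2w; the minimum is at w = 0.
w_final = lookahead_minimize(lambda w: 2.0 * w, w0=5.0)
```

A slow_step_size of 1.0 would reduce this loop to the plain inner optimizer, while smaller values make the slow weights follow the fast weights more cautiously.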
- omegas_¶
Learned per-prototype orthonormal tangent bases.
- Type:
array of shape (n_prototypes, n_features, subspace_dim)
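The idea of using only the distance orthogonal to the tangent subspace can be sketched as follows. For an orthonormal basis \(\Omega_k\), the orthogonal part satisfies \(\lVert x - w \rVert^2 - \lVert \Omega_k^\top (x - w) \rVert^2\); that the library computes it in exactly this way is an assumption, and tangent_distance is a hypothetical helper.

```python
def tangent_distance(x, w, omega):
    """Squared distance from x to the affine subspace w + span(omega).

    omega: nested list of shape (n_features, subspace_dim) whose columns
    are assumed to be orthonormal.
    """
    diff = [xi - wi for xi, wi in zip(x, w)]
    subspace_dim = len(omega[0])
    # Coordinates of diff inside the tangent subspace.
    coords = [
        sum(diff[i] * omega[i][j] for i in range(len(diff)))
        for j in range(subspace_dim)
    ]
    total = sum(d * d for d in diff)
    inside = sum(c * c for c in coords)
    return total - inside  # only the orthogonal component remains


# Tangent subspace spanned by the first axis of a 3-D space.
omega = [[1.0], [0.0], [0.0]]
w = [0.0, 0.0, 0.0]
d_along = tangent_distance([7.0, 0.0, 0.0], w, omega)
d_off = tangent_distance([7.0, 3.0, 4.0], w, omega)
```

Moving a sample along the tangent direction leaves its distance unchanged, which is what makes the subspace a set of directions of local invariance.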
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict(X)¶
Predict target or non-target labels.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)¶
Predict with a reject option for uncertain samples.
Samples with scores in [lower, upper) are rejected.
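- Parameters:
X (array-like of shape (n_samples, n_features)) – Input samples.
upper (float, default=0.5) – Upper score bound (exclusive) of the rejection band.
lower (float, optional) – Lower score bound (inclusive) of the rejection band. Default: None.
reject_label (int, default=-1) – Label assigned to rejected samples.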
- Returns:
labels
- Return type:
array of shape (n_samples,)
One-Class GLVQ with Neural Gas¶
- class prosemble.models.OCGLVQ_NG(gamma_init=None, gamma_final=0.01, gamma_decay=None, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)[source]¶
One-Class GLVQ with Neural Gas neighborhood cooperation.
All prototypes participate in the loss, weighted by their distance rank via exp(-rank / gamma). Uses squared Euclidean distance.
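The rank weighting exp(-rank / gamma) can be sketched in pure Python. The sketch assumes rank 0 for the closest prototype and unnormalized weights; neural_gas_weights is a hypothetical helper, not a prosemble function.

```python
import math


def neural_gas_weights(distances, gamma):
    """Weight exp(-rank / gamma) per prototype, with rank 0 = closest."""
    order = sorted(range(len(distances)), key=lambda k: distances[k])
    ranks = [0] * len(distances)
    for rank, k in enumerate(order):
        ranks[k] = rank
    return [math.exp(-r / gamma) for r in ranks]


# Distances from one sample to three prototypes.
distances = [4.0, 0.5, 2.0]
wide = neural_gas_weights(distances, gamma=1.5)     # early training
narrow = neural_gas_weights(distances, gamma=0.01)  # late training
```

With a large gamma all prototypes receive a noticeable share of the loss; as gamma shrinks toward gamma_final, effectively only the closest prototype is updated.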
- Parameters:
gamma_init (float, optional) – Initial neighborhood range. Default: n_prototypes / 2.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay for gamma. Default: computed from max_iter so gamma reaches gamma_final.
n_prototypes (int) – Number of prototypes for the target class.
target_label (int, optional) – Target (normal) class label. Default: auto-detect.
beta (float) – Sigmoid steepness. Default: 10.0.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
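One decay factor consistent with the gamma_decay default (gamma reaches gamma_final after max_iter multiplicative steps) is the max_iter-th root of gamma_final / gamma_init. The sketch below assumes this closed form; default_gamma_decay is a hypothetical helper and the library's exact computation is not confirmed here.

```python
def default_gamma_decay(gamma_init, gamma_final, max_iter):
    """Per-step factor so that gamma_init * decay**max_iter == gamma_final."""
    return (gamma_final / gamma_init) ** (1.0 / max_iter)


n_prototypes = 3
gamma_init = n_prototypes / 2  # documented default for gamma_init
gamma_final = 0.01
max_iter = 100

decay = default_gamma_decay(gamma_init, gamma_final, max_iter)

# Apply the multiplicative decay once per training step.
gamma = gamma_init
for _ in range(max_iter):
    gamma *= decay
```

Passing an explicit gamma_decay closer to 1.0 than this value would keep the neighborhood wide for longer.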
- decision_function(X)¶
Compute target-likeness scores.
Scores near 1.0 indicate target class, near 0.0 indicate outlier. The decision boundary is at score = 0.5 (where \(d = \theta\)).
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
scores
- Return type:
array of shape (n_samples,)
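One score function with the stated properties is a sigmoid of the gap between a threshold \(\theta\) and the distance \(d\), with beta as steepness. The form below is an assumption chosen to match the description (score 0.5 exactly at \(d = \theta\)); target_score is a hypothetical helper, not the library's implementation.

```python
import math


def target_score(d, theta, beta=10.0):
    """Sigmoid score: near 1 for d << theta, near 0 for d >> theta."""
    return 1.0 / (1.0 + math.exp(-beta * (theta - d)))


theta = 1.0
on_boundary = target_score(1.0, theta)  # distance equals the threshold
inside = target_score(0.2, theta)       # well inside the target region
outside = target_score(2.5, theta)      # far from every prototype
```

A larger beta makes the transition around the boundary sharper, so fewer samples receive intermediate scores.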
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict(X)¶
Predict target or non-target labels.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)¶
Predict with a reject option for uncertain samples.
Samples with scores in [lower, upper) are rejected.
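- Parameters:
X (array-like of shape (n_samples, n_features)) – Input samples.
upper (float, default=0.5) – Upper score bound (exclusive) of the rejection band.
lower (float, optional) – Lower score bound (inclusive) of the rejection band. Default: None.
reject_label (int, default=-1) – Label assigned to rejected samples.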
- Returns:
labels
- Return type:
array of shape (n_samples,)
- class prosemble.models.OCGRLVQ_NG(gamma_init=None, gamma_final=0.01, gamma_decay=None, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)[source]¶
One-Class GRLVQ with Neural Gas neighborhood cooperation.
Learns per-feature relevance weights with NG rank-weighted loss.
- Parameters:
gamma_init (float, optional) – Initial neighborhood range. Default: n_prototypes / 2.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay for gamma. Default: computed from max_iter so gamma reaches gamma_final.
n_prototypes (int) – Number of prototypes for the target class.
target_label (int, optional) – Target (normal) class label. Default: auto-detect.
beta (float) – Sigmoid steepness. Default: 10.0.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
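The statement that the effective batch size equals batch_size * gradient_accumulation_steps can be checked on a toy problem. The sketch assumes that accumulated gradients are averaged before the update and that micro-batches have equal size; the loss, data, and mean_grad helper are hypothetical.

```python
def mean_grad(w, batch):
    """Gradient w.r.t. w of the mean of 0.5 * (w * x - y)**2 over a batch."""
    return sum((w * x - y) * x for x, y in batch) / len(batch)


data = [(1.0, 2.0), (2.0, 4.1), (3.0, 5.9), (4.0, 8.2)]
batch_size = 2
gradient_accumulation_steps = 2
w = 0.0

# Split the data into micro-batches of size batch_size.
micro_batches = [data[i:i + batch_size]
                 for i in range(0, len(data), batch_size)]

# Accumulate gradients over several micro-batches, then average them.
accumulated = 0.0
for batch in micro_batches[:gradient_accumulation_steps]:
    accumulated += mean_grad(w, batch)
accumulated /= gradient_accumulation_steps

# Gradient computed on one batch of the effective size (2 * 2 = 4).
full_batch = mean_grad(w, data)
```

Under these assumptions the averaged accumulated gradient matches the gradient of a single batch of four samples, so accumulation trades extra steps for lower peak memory.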
- relevances_¶
Learned per-feature relevance weights.
- Type:
array of shape (n_features,)
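A relevance vector of this shape is typically used to weight each feature's squared difference. The sketch assumes the standard GRLVQ form \(d(x, w) = \sum_j \lambda_j (x_j - w_j)^2\) with non-negative relevances; relevance_distance is a hypothetical helper.

```python
def relevance_distance(x, w, relevances):
    """Squared Euclidean distance with one non-negative weight per feature."""
    return sum(r * (xi - wi) ** 2 for xi, wi, r in zip(x, w, relevances))


# The third feature has zero relevance and is ignored entirely.
relevances = [0.9, 0.1, 0.0]
d = relevance_distance([1.0, 2.0, 50.0], [0.0, 0.0, 0.0], relevances)
```

Features with small learned relevance barely affect the distance, which also makes relevances_ usable as a rough feature-importance readout.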
- decision_function(X)¶
Compute scores using relevance-weighted distances.
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict(X)¶
Predict target or non-target labels.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)¶
Predict with a reject option for uncertain samples.
Samples with scores in [lower, upper) are rejected.
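- Parameters:
X (array-like of shape (n_samples, n_features)) – Input samples.
upper (float, default=0.5) – Upper score bound (exclusive) of the rejection band.
lower (float, optional) – Lower score bound (inclusive) of the rejection band. Default: None.
reject_label (int, default=-1) – Label assigned to rejected samples.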
- Returns:
labels
- Return type:
array of shape (n_samples,)
- class prosemble.models.OCGMLVQ_NG(gamma_init=None, gamma_final=0.01, gamma_decay=None, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)[source]¶
One-Class GMLVQ with Neural Gas neighborhood cooperation.
Learns a global Omega projection with NG rank-weighted loss.
- Parameters:
latent_dim (int, optional) – Dimensionality of the projected space. Default: n_features.
gamma_init (float, optional) – Initial neighborhood range. Default: n_prototypes / 2.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay for gamma. Default: computed from max_iter so gamma reaches gamma_final.
n_prototypes (int) – Number of prototypes for the target class.
target_label (int, optional) – Target (normal) class label. Default: auto-detect.
beta (float) – Sigmoid steepness. Default: 10.0.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys: ‘sync_period’ (int, default 6), sync every k steps; ‘slow_step_size’ (float, default 0.5), interpolation factor. Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
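To make the gamma_init, gamma_final and gamma_decay parameters concrete, here is a hedged sketch of a multiplicative schedule whose decay factor is derived so that gamma lands on gamma_final after max_iter steps. The helper name and the choice of exponent (1 / max_iter rather than, say, 1 / (max_iter - 1)) are assumptions; the library may compute its default differently.

```python
def gamma_schedule(gamma_init, gamma_final, max_iter):
    """Return (decay, gammas) with gammas[t] = gamma_init * decay**t.

    decay is chosen so that gammas[max_iter] equals gamma_final.
    """
    decay = (gamma_final / gamma_init) ** (1.0 / max_iter)
    gammas = [gamma_init * decay ** t for t in range(max_iter + 1)]
    return decay, gammas


n_prototypes = 3
decay, gammas = gamma_schedule(gamma_init=n_prototypes / 2,
                               gamma_final=0.01, max_iter=100)
print(round(decay, 4), gammas[0], gammas[-1])
```

With the documented defaults (n_prototypes=3, gamma_final=0.01, max_iter=100) the neighborhood range shrinks by roughly 5% per step.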
- omega_¶
Learned projection matrix.
- Type:
array of shape (n_features, latent_dim)
- decision_function(X)¶
Compute scores using Omega-projected distances.
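As an illustration of what an Omega-projected distance means, the sketch below computes the squared norm of (x - w) after projecting with a matrix of shape (n_features, latent_dim), matching the documented shape of omega_. The function name and the row-vector convention (x - w) @ omega are assumptions, and the mapping from distance to score is deliberately left out.

```python
def omega_distance(x, w, omega):
    """Squared distance ||(x - w) @ omega||^2.

    omega is a nested list of shape (n_features, latent_dim).
    """
    diff = [xi - wi for xi, wi in zip(x, w)]
    latent_dim = len(omega[0])
    projected = [
        sum(diff[i] * omega[i][j] for i in range(len(diff)))
        for j in range(latent_dim)
    ]
    return sum(p * p for p in projected)


# Hypothetical projection from 3 features to 2 latent dimensions.
# The third feature has an all-zero row, so it does not affect the distance.
omega = [[1.0, 0.0],
         [0.0, 2.0],
         [0.0, 0.0]]
d = omega_distance([1.0, 1.0, 5.0], [0.0, 0.0, 0.0], omega)
print(d)
```

A zero row in omega removes that feature from the metric entirely, which is how a learned projection can suppress irrelevant inputs.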
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
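The interaction between patience, epsilon and restore_best described above can be sketched on a plain list of per-epoch losses. This is an assumption-laden illustration, not the library's training loop: the helper name is hypothetical, and "improvement" is taken to mean a drop of more than epsilon below the best loss seen so far.

```python
def train_with_patience(losses, patience, epsilon=1e-6):
    """Return (stop_index, best_index) for a sequence of per-epoch losses.

    Training stops after `patience` consecutive epochs without an
    improvement larger than epsilon. best_index marks the epoch whose
    parameters a restore_best-style option would bring back.
    """
    best, best_idx, bad_epochs = float("inf"), -1, 0
    for i, loss in enumerate(losses):
        if loss < best - epsilon:
            best, best_idx, bad_epochs = loss, i, 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return i, best_idx
    return len(losses) - 1, best_idx


val_losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.69, 0.9]
stop, best = train_with_patience(val_losses, patience=2)
print(stop, best)
```

In this run the loop stops at epoch 4 and never sees the lower loss at epoch 5, which is the usual trade-off of a small patience value.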
- predict(X)¶
Predict target or non-target labels.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)¶
Predict with a reject option for uncertain samples.
Samples with scores in [lower, upper) are rejected.
- Parameters:
- Returns:
labels
- Return type:
array of shape (n_samples,)
- class prosemble.models.OCLGMLVQ_NG(gamma_init=None, gamma_final=0.01, gamma_decay=None, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)[source]¶
One-Class LGMLVQ with Neural Gas neighborhood cooperation.
Learns per-prototype local Omega projections with NG rank-weighted loss.
- Parameters:
latent_dim (int, optional) – Dimensionality of each projected space. Default: n_features.
gamma_init (float, optional) – Initial neighborhood range. Default: n_prototypes / 2.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay for gamma. Default: computed from max_iter so gamma reaches gamma_final.
n_prototypes (int) – Number of prototypes for the target class.
target_label (int, optional) – Target (normal) class label. Default: auto-detect.
beta (float) – Sigmoid steepness. Default: 10.0.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys: ‘sync_period’ (int, default 6), sync every k steps; ‘slow_step_size’ (float, default 0.5), interpolation factor. Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
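For readers unfamiliar with Neural Gas cooperation, the following sketch shows the classic rank-based neighborhood weight h_k = exp(-rank_k / gamma) that the gamma parameters control. Normalizing the weights to sum to 1 and the helper name are assumptions; how the library folds these weights into its loss is not shown here.

```python
import math


def ng_rank_weights(distances, gamma):
    """Neural Gas weights exp(-rank / gamma), normalized to sum to 1.

    rank 0 is assigned to the closest prototype.
    """
    order = sorted(range(len(distances)), key=lambda k: distances[k])
    ranks = [0] * len(distances)
    for rank, k in enumerate(order):
        ranks[k] = rank
    raw = [math.exp(-r / gamma) for r in ranks]
    total = sum(raw)
    return [h / total for h in raw]


distances = [0.2, 1.5, 0.7]                        # one sample vs. three prototypes
wide = ng_rank_weights(distances, gamma=1.5)       # early training: all cooperate
narrow = ng_rank_weights(distances, gamma=0.01)    # late training: winner takes all
print([round(h, 3) for h in wide], [round(h, 3) for h in narrow])
```

As gamma decays from gamma_init toward gamma_final, the weighting moves from broad cooperation between prototypes to an update dominated by the closest one.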
- omegas_¶
Learned per-prototype projection matrices.
- Type:
array of shape (n_prototypes, n_features, latent_dim)
- decision_function(X)¶
Compute scores using per-prototype Omega-projected distances.
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict(X)¶
Predict target or non-target labels.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)¶
Predict with a reject option for uncertain samples.
Samples with scores in [lower, upper) are rejected.
- Parameters:
- Returns:
labels
- Return type:
array of shape (n_samples,)
- class prosemble.models.OCGTLVQ_NG(gamma_init=None, gamma_final=0.01, gamma_decay=None, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)[source]¶
One-Class GTLVQ with Neural Gas neighborhood cooperation.
Learns per-prototype tangent subspaces with NG rank-weighted loss.
- Parameters:
subspace_dim (int) – Dimensionality of each tangent subspace. Default: 2.
gamma_init (float, optional) – Initial neighborhood range. Default: n_prototypes / 2.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay for gamma. Default: computed from max_iter so gamma reaches gamma_final.
n_prototypes (int) – Number of prototypes for the target class.
target_label (int, optional) – Target (normal) class label. Default: auto-detect.
beta (float) – Sigmoid steepness. Default: 10.0.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys: ‘sync_period’ (int, default 6), sync every k steps; ‘slow_step_size’ (float, default 0.5), interpolation factor. Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
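The effect of ema_decay can be seen in a small sketch of the standard exponential moving average update. The helper name, the initialization from the first parameter vector, and the absence of bias correction are assumptions; a deliberately small decay of 0.5 is used so the arithmetic is easy to follow, whereas the documented typical values are 0.999 or 0.9999.

```python
def ema_update(ema, params, decay):
    """One EMA step: ema <- decay * ema + (1 - decay) * params."""
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema, params)]


# Hypothetical parameter trajectory that oscillates between two points.
trajectory = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]
ema = list(trajectory[0])
for params in trajectory[1:]:
    ema = ema_update(ema, params, decay=0.5)
print(ema)
```

The averaged parameters sit between the two extremes of the oscillation, which is the smoothing that replaces the raw parameters after training when EMA is enabled.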
- omegas_¶
Learned per-prototype orthonormal tangent bases.
- Type:
array of shape (n_prototypes, n_features, subspace_dim)
- decision_function(X)¶
Compute scores using tangent distances.
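A tangent distance measures how far a sample lies from a prototype once movement inside the prototype's tangent subspace is ignored. The sketch below assumes the usual formulation, the squared norm of the residual after removing the component in the span of an orthonormal basis of shape (n_features, subspace_dim); the function name is hypothetical and the conversion from distance to score is omitted.

```python
def tangent_distance(x, w, basis):
    """Squared norm of (x - w) after removing its component in span(basis).

    basis is a nested list of shape (n_features, subspace_dim) whose
    columns are assumed to be orthonormal.
    """
    diff = [xi - wi for xi, wi in zip(x, w)]
    n, m = len(diff), len(basis[0])
    # Coordinates of diff inside the tangent subspace.
    coeffs = [sum(basis[i][j] * diff[i] for i in range(n)) for j in range(m)]
    # Residual orthogonal to the subspace.
    residual = [
        diff[i] - sum(basis[i][j] * coeffs[j] for j in range(m))
        for i in range(n)
    ]
    return sum(r * r for r in residual)


# Tangent plane spanned by the first two coordinate axes (subspace_dim = 2).
basis = [[1.0, 0.0],
         [0.0, 1.0],
         [0.0, 0.0]]
d = tangent_distance([3.0, -4.0, 2.0], [0.0, 0.0, 0.0], basis)
print(d)
```

Only the offset along the third axis contributes, so large displacements within the tangent plane cost nothing.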
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict(X)¶
Predict target or non-target labels.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)¶
Predict with a reject option for uncertain samples.
Samples with scores in [lower, upper) are rejected.
- Parameters:
- Returns:
labels
- Return type:
array of shape (n_samples,)
One-Class RSLVQ¶
- class prosemble.models.OCRSLVQ(sigma=1.0, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
One-Class Robust Soft LVQ.
Combines one-class threshold detection with probabilistic soft-weighting of all prototypes via Gaussian mixture responsibilities.
All prototypes contribute to the one-class decision via Gaussian proximity weights, with standard Euclidean distances.
- Parameters:
sigma (float) – Bandwidth of Gaussian mixture for prototype weighting.
n_prototypes (int) – Number of prototypes for the target class. Default: 3.
target_label (int, optional) – Target (normal) class label. Default: auto-detect.
beta (float) – Sigmoid steepness. Default: 10.0.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys: ‘sync_period’ (int, default 6), sync every k steps; ‘slow_step_size’ (float, default 0.5), interpolation factor. Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
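The two lookahead keys, sync_period and slow_step_size, map onto the standard lookahead scheme: fast weights take ordinary optimizer steps, and every sync_period steps the slow weights move a fraction slow_step_size toward them. The sketch below wraps plain gradient descent on a toy quadratic; the function name, the inner optimizer, and the toy objective are assumptions, and the library's own wrapper may differ in detail.

```python
def lookahead_sgd(grad_fn, params, lr, n_steps, sync_period=6, slow_step_size=0.5):
    """Gradient descent with a lookahead wrapper; returns the slow weights."""
    slow = list(params)
    fast = list(params)
    for step in range(1, n_steps + 1):
        grads = grad_fn(fast)
        fast = [p - lr * g for p, g in zip(fast, grads)]   # inner (fast) update
        if step % sync_period == 0:
            # Interpolate the slow weights toward the fast weights, then reset.
            slow = [s + slow_step_size * (f - s) for s, f in zip(slow, fast)]
            fast = list(slow)
    return slow


# Toy objective f(p) = p0^2 + p1^2 with gradient (2*p0, 2*p1).
final = lookahead_sgd(lambda p: [2.0 * p[0], 2.0 * p[1]],
                      [4.0, -2.0], lr=0.1, n_steps=60)
print(final)
```

A slow_step_size of 1.0 would reduce this to the plain inner optimizer, while smaller values damp the fast trajectory more strongly.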
- thetas_¶
Learned per-prototype acceptance thresholds.
- Type:
array of shape (n_prototypes,)
- decision_function(X)[source]¶
Compute target-likeness scores using soft-weighted distances.
Scores near 1.0 indicate target class, near 0.0 indicate outlier.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
scores
- Return type:
array of shape (n_samples,)
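The Gaussian soft weighting behind these scores can be pictured with a small sketch of mixture responsibilities: each prototype receives a weight proportional to exp(-d_k / (2 sigma^2)), where d_k is the squared distance to prototype k. Equal mixture priors, the helper name, and the example distances are assumptions; how the weights are combined with thetas_ to produce the final score is not reproduced here.

```python
import math


def responsibilities(sq_distances, sigma):
    """Softmax of -d_k / (2 sigma^2) over prototypes (equal priors assumed)."""
    logits = [-d / (2.0 * sigma ** 2) for d in sq_distances]
    shift = max(logits)                        # subtract max for numerical stability
    exps = [math.exp(l - shift) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


sq_distances = [0.5, 2.0, 8.0]                        # one sample vs. three prototypes
soft = responsibilities(sq_distances, sigma=1.0)      # several prototypes contribute
sharp = responsibilities(sq_distances, sigma=0.1)     # nearly nearest-prototype
print([round(p, 3) for p in soft], [round(p, 3) for p in sharp])
```

A larger sigma spreads the weight across prototypes, while a small sigma approaches a hard nearest-prototype assignment.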
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict(X)¶
Predict target or non-target labels.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)¶
Predict with a reject option for uncertain samples.
Samples with scores in [lower, upper) are rejected.
- Parameters:
- Returns:
labels
- Return type:
array of shape (n_samples,)
- class prosemble.models.OCMRSLVQ(sigma=1.0, latent_dim=None, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
One-Class Matrix Robust Soft LVQ.
Combines one-class threshold detection with a learned global Omega projection matrix and probabilistic soft-weighting of all prototypes.
All prototypes contribute to the one-class decision via Gaussian proximity weights, with distances computed in the Omega-projected space.
- Parameters:
sigma (float) – Bandwidth of Gaussian mixture for prototype weighting.
latent_dim (int, optional) – Dimensionality of the projected space. Default: n_features.
n_prototypes (int) – Number of prototypes for the target class. Default: 3.
target_label (int, optional) – Target (normal) class label. Default: auto-detect.
beta (float) – Sigmoid steepness. Default: 10.0.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys: ‘sync_period’ (int, default 6), sync every k steps; ‘slow_step_size’ (float, default 0.5), interpolation factor. Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
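The ‘balanced’ option for class_weight is described as inversely proportional to class frequencies. One common way to realize that, shown below, is n_samples / (n_classes * count), the heuristic used by scikit-learn. Whether this library uses the same normalization is an assumption, as is the helper name.

```python
from collections import Counter


def balanced_class_weights(y):
    """Weights inversely proportional to class frequency.

    Uses n_samples / (n_classes * count); the exact normalization used
    by the library is an assumption of this sketch.
    """
    counts = Counter(y)
    n_samples, n_classes = len(y), len(counts)
    return {label: n_samples / (n_classes * count)
            for label, count in counts.items()}


y = [0, 0, 0, 0, 0, 0, 1, 1]          # six target samples, two outliers
weights = balanced_class_weights(y)
print(weights)
```

The minority class receives three times the weight of the majority class here, mirroring the 6:2 imbalance in the labels.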
- omega_¶
Learned projection matrix.
- Type:
array of shape (n_features, latent_dim)
- thetas_¶
Learned per-prototype visibility thresholds.
- Type:
array of shape (n_prototypes,)
- decision_function(X)[source]¶
Compute target-likeness scores using soft-weighted Omega distances.
Scores near 1.0 indicate target class, near 0.0 indicate outlier.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
scores
- Return type:
array of shape (n_samples,)
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict(X)¶
Predict target or non-target labels.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)¶
Predict with a reject option for uncertain samples.
Samples with scores in [lower, upper) are rejected.
- Parameters:
- Returns:
labels
- Return type:
array of shape (n_samples,)
- class prosemble.models.OCLMRSLVQ(sigma=1.0, latent_dim=None, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
One-Class Localized Matrix Robust Soft LVQ.
Each prototype \(k\) has its own \(\Omega_k\) matrix for local metric adaptation, combined with probabilistic soft-weighting and one-class threshold detection.
- Parameters:
sigma (float) – Bandwidth of Gaussian mixture for prototype weighting.
latent_dim (int, optional) – Latent space dimensionality per prototype. Default: n_features.
n_prototypes (int) – Number of prototypes for the target class. Default: 3.
target_label (int, optional) – Target (normal) class label. Default: auto-detect.
beta (float) – Sigmoid steepness. Default: 10.0.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
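The ‘balanced’ option of class_weight is described above as inversely proportional to class frequencies. A minimal sketch of one common convention, n_samples / (n_classes * class_count), is given here. Whether prosemble uses exactly this normalization is an assumption, and balanced_class_weights is a hypothetical helper, not part of the library.

```python
from collections import Counter


def balanced_class_weights(y):
    """Weight each class by n_samples / (n_classes * class_count).

    Assumed convention; the library's exact normalization may differ.
    """
    counts = Counter(y)
    n_samples = len(y)
    n_classes = len(counts)
    return {label: n_samples / (n_classes * count)
            for label, count in counts.items()}


# Three samples of class 0 and one of class 1: the rare class gets
# the larger weight.
weights = balanced_class_weights([0, 0, 0, 1])
```

The resulting dict has the same form as an explicit class_weight mapping such as {0: 1.0, 1: 2.0}.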
- omegas_¶
Learned per-prototype projection matrices.
- Type:
array of shape (n_prototypes, n_features, latent_dim)
- thetas_¶
Learned per-prototype visibility thresholds.
- Type:
array of shape (n_prototypes,)
- decision_function(X)[source]¶
Compute target-likeness scores using local Omega distances.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
scores
- Return type:
array of shape (n_samples,)
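To illustrate what a "local Omega distance" can look like, the sketch below projects the difference between a sample and each prototype through that prototype's own matrix and takes the squared norm. The formula d_k(x) = ||(x - w_k) Ω_k||² is an assumption chosen to match the documented omegas_ shape (n_prototypes, n_features, latent_dim). local_omega_distances is a hypothetical helper written with plain lists, not the library's implementation.

```python
def local_omega_distances(x, prototypes, omegas):
    """Return d_k(x) = ||(x - w_k) @ Omega_k||^2 for every prototype k.

    x:          list of n_features floats
    prototypes: n_prototypes lists of n_features floats
    omegas:     n_prototypes matrices of shape (n_features, latent_dim)
    """
    distances = []
    for w_k, omega_k in zip(prototypes, omegas):
        diff = [xi - wi for xi, wi in zip(x, w_k)]
        latent_dim = len(omega_k[0])
        # Project the difference vector into prototype k's latent space.
        projected = [
            sum(diff[i] * omega_k[i][j] for i in range(len(diff)))
            for j in range(latent_dim)
        ]
        distances.append(sum(p * p for p in projected))
    return distances


prototypes = [[0.0, 0.0], [1.0, 1.0]]
omegas = [
    [[1.0], [0.0]],  # prototype 0 only looks at feature 0
    [[0.0], [2.0]],  # prototype 1 only looks at feature 1, scaled by 2
]
distances = local_omega_distances([3.0, 4.0], prototypes, omegas)
```

Because each prototype carries its own matrix, the same sample can be close to one prototype in that prototype's metric and far from another in a different metric.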
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict(X)¶
Predict target or non-target labels.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)¶
Predict with a reject option for uncertain samples.
Samples with scores in [lower, upper) are rejected.
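- Parameters:
X (array-like of shape (n_samples, n_features))
upper (float) – Upper bound of the rejection interval. Default: 0.5.
lower (float, optional) – Lower bound of the rejection interval. Default: None.
reject_label (int) – Label assigned to rejected samples. Default: -1.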
- Returns:
labels
- Return type:
array of shape (n_samples,)
One-Class RSLVQ with Neural Gas¶
- class prosemble.models.OCRSLVQ_NG(sigma=1.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)[source]¶
One-Class RSLVQ with Neural Gas neighborhood cooperation.
Combines soft Gaussian mixture responsibilities with NG rank-based cooperation. Uses standard Euclidean distances.
- Parameters:
sigma (float) – Bandwidth of Gaussian mixture for prototype weighting.
gamma_init (float, optional) – Initial neighborhood range. Default: n_prototypes / 2.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay for gamma. Default: computed from max_iter so gamma reaches gamma_final.
n_prototypes (int) – Number of prototypes for the target class.
target_label (int, optional) – Target (normal) class label. Default: auto-detect.
beta (float) – Sigmoid steepness. Default: 10.0.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
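The default for gamma_decay is described as "computed from max_iter so gamma reaches gamma_final". One way to obtain such a per-step multiplicative factor is sketched below. The closed form (gamma_final / gamma_init) ** (1 / max_iter) is an assumption consistent with that description, and default_gamma_decay is a hypothetical helper.

```python
def default_gamma_decay(gamma_init, gamma_final, max_iter):
    """Per-step factor so that gamma_init * decay**max_iter == gamma_final."""
    return (gamma_final / gamma_init) ** (1.0 / max_iter)


n_prototypes, max_iter = 3, 100
gamma_init = n_prototypes / 2   # documented default for gamma_init
gamma_final = 0.01              # documented default for gamma_final

decay = default_gamma_decay(gamma_init, gamma_final, max_iter)

# Apply the decay once per training step.
gamma = gamma_init
for _ in range(max_iter):
    gamma *= decay
```

With these defaults the neighborhood range shrinks geometrically from 1.5 to 0.01 over the 100 iterations.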
- thetas_¶
Learned per-prototype acceptance thresholds.
- Type:
array of shape (n_prototypes,)
- decision_function(X)¶
Compute target-likeness scores using combined Gaussian+NG weights.
Uses final (converged) gamma for NG modulation at inference time.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
scores – Scores near 1.0 indicate target class, near 0.0 indicate outlier.
- Return type:
array of shape (n_samples,)
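As a rough illustration of "combined Gaussian+NG weights", the sketch below multiplies Gaussian responsibilities by Neural Gas rank factors exp(-rank / gamma) and renormalizes. The product-and-renormalize rule is an assumption for illustration only; the documentation does not state how the two terms are combined, and combined_weights is a hypothetical helper.

```python
import math


def combined_weights(distances, sigma=1.0, gamma=0.01):
    """Gaussian responsibilities modulated by NG rank factors (assumed form)."""
    # Gaussian part: softmax over -d / (2 * sigma^2), shifted for stability.
    logits = [-d / (2.0 * sigma ** 2) for d in distances]
    shift = max(logits)
    gauss = [math.exp(value - shift) for value in logits]

    # NG part: rank 0 for the closest prototype, 1 for the next, and so on.
    order = sorted(range(len(distances)), key=lambda k: distances[k])
    ranks = [0] * len(distances)
    for rank, k in enumerate(order):
        ranks[k] = rank
    ng = [math.exp(-r / gamma) for r in ranks]

    combined = [g * h for g, h in zip(gauss, ng)]
    total = sum(combined)
    return [c / total for c in combined]


# Squared distances from one sample to three prototypes.
w = combined_weights([0.5, 2.0, 1.0])
```

With the final gamma of 0.01 the rank factor suppresses everything except the closest prototype, so at inference time the weighting is close to winner-takes-all under this assumed form.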
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict(X)¶
Predict target or non-target labels.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)¶
Predict with a reject option for uncertain samples.
Samples with scores in [lower, upper) are rejected.
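- Parameters:
X (array-like of shape (n_samples, n_features))
upper (float) – Upper bound of the rejection interval. Default: 0.5.
lower (float, optional) – Lower bound of the rejection interval. Default: None.
reject_label (int) – Label assigned to rejected samples. Default: -1.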
- Returns:
labels
- Return type:
array of shape (n_samples,)
- class prosemble.models.OCMRSLVQ_NG(sigma=1.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)[source]¶
One-Class Matrix RSLVQ with Neural Gas neighborhood cooperation.
Learns a global Omega projection with combined Gaussian + NG rank-weighted loss for one-class classification.
- Parameters:
sigma (float) – Bandwidth of Gaussian mixture for prototype weighting.
latent_dim (int, optional) – Dimensionality of the projected space. Default: n_features.
gamma_init (float, optional) – Initial neighborhood range. Default: n_prototypes / 2.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay for gamma. Default: computed from max_iter so gamma reaches gamma_final.
n_prototypes (int) – Number of prototypes for the target class.
target_label (int, optional) – Target (normal) class label. Default: auto-detect.
beta (float) – Sigmoid steepness. Default: 10.0.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
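The roles of ‘sync_period’ and ‘slow_step_size’ in the lookahead option can be shown with a scalar toy problem. The sketch below implements the general lookahead scheme around plain gradient descent; it is not the library's optimizer wrapper, and lookahead_sgd is a hypothetical helper.

```python
def lookahead_sgd(grad_fn, w0, lr=0.1, sync_period=6,
                  slow_step_size=0.5, n_steps=12):
    """Lookahead around plain gradient descent on a scalar parameter."""
    slow = fast = w0
    for step in range(1, n_steps + 1):
        # Fast weights take an ordinary optimizer step every iteration.
        fast = fast - lr * grad_fn(fast)
        if step % sync_period == 0:
            # Every sync_period steps, move the slow weights part of the
            # way towards the fast weights, then restart fast from slow.
            slow = slow + slow_step_size * (fast - slow)
            fast = slow
    return slow, fast


# Minimize f(w) = w**2, whose gradient is 2 * w, starting from w = 1.
slow, fast = lookahead_sgd(lambda w: 2.0 * w, w0=1.0)
```

A smaller slow_step_size makes the slow weights more conservative, while a larger sync_period lets the fast weights explore further between synchronizations.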
- omega_¶
Learned projection matrix.
- Type:
array of shape (n_features, latent_dim)
- thetas_¶
Learned per-prototype acceptance thresholds.
- Type:
array of shape (n_prototypes,)
- decision_function(X)¶
Compute target-likeness scores using combined Gaussian+NG weights.
Uses final (converged) gamma for NG modulation at inference time.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
scores – Scores near 1.0 indicate target class, near 0.0 indicate outlier.
- Return type:
array of shape (n_samples,)
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict(X)¶
Predict target or non-target labels.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)¶
Predict with a reject option for uncertain samples.
Samples with scores in [lower, upper) are rejected.
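- Parameters:
X (array-like of shape (n_samples, n_features))
upper (float) – Upper bound of the rejection interval. Default: 0.5.
lower (float, optional) – Lower bound of the rejection interval. Default: None.
reject_label (int) – Label assigned to rejected samples. Default: -1.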
- Returns:
labels
- Return type:
array of shape (n_samples,)
- class prosemble.models.OCLMRSLVQ_NG(sigma=1.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)[source]¶
One-Class Local Matrix RSLVQ with Neural Gas neighborhood cooperation.
Learns per-prototype \(\Omega_k\) projection matrices with combined Gaussian + NG rank-weighted loss for one-class classification.
- Parameters:
sigma (float) – Bandwidth of Gaussian mixture for prototype weighting.
latent_dim (int, optional) – Latent space dimensionality per prototype. Default: n_features.
gamma_init (float, optional) – Initial neighborhood range. Default: n_prototypes / 2.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay for gamma. Default: computed from max_iter so gamma reaches gamma_final.
n_prototypes (int) – Number of prototypes for the target class.
target_label (int, optional) – Target (normal) class label. Default: auto-detect.
beta (float) – Sigmoid steepness. Default: 10.0.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
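The effect of ema_decay can be seen on a parameter that oscillates between two values. The sketch below applies the standard exponential moving average update. Initializing the average from the first parameter value and omitting bias correction are assumptions, and ema_update is a hypothetical helper rather than a library function.

```python
def ema_update(ema_params, params, decay):
    """One EMA step: ema <- decay * ema + (1 - decay) * params."""
    return [decay * e + (1.0 - decay) * p
            for e, p in zip(ema_params, params)]


# A single parameter that jumps between 1.0 and 0.0 during training.
trajectory = [[1.0], [0.0], [1.0], [0.0]]

ema = trajectory[0]  # assumed initialization: the first parameter value
for params in trajectory[1:]:
    ema = ema_update(ema, params, decay=0.9)
```

The raw parameter ends at 0.0, while the smoothed value changes only slightly at each step, which is why decay values close to 1 are typical.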
- omegas_¶
Learned per-prototype projection matrices.
- Type:
array of shape (n_prototypes, n_features, latent_dim)
- thetas_¶
Learned per-prototype acceptance thresholds.
- Type:
array of shape (n_prototypes,)
- decision_function(X)¶
Compute target-likeness scores using combined Gaussian+NG weights.
Uses final (converged) gamma for NG modulation at inference time.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
scores – Scores near 1.0 indicate target class, near 0.0 indicate outlier.
- Return type:
array of shape (n_samples,)
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict(X)¶
Predict target or non-target labels.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)¶
Predict with a reject option for uncertain samples.
Samples with scores in [lower, upper) are rejected.
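- Parameters:
X (array-like of shape (n_samples, n_features))
upper (float) – Upper bound of the rejection interval. Default: 0.5.
lower (float, optional) – Lower bound of the rejection interval. Default: None.
reject_label (int) – Label assigned to rejected samples. Default: -1.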
- Returns:
labels
- Return type:
array of shape (n_samples,)
SVQ-OCC¶
- class prosemble.models.SVQOCC(n_prototypes=3, target_label=None, alpha=0.5, cost_function='contrastive', response_type='gaussian', sigma=0.1, gamma_resp=1.0, nu=1.0, lambda_init=None, lambda_final=0.01, lambda_decay=None, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
Supervised Vector Quantization One-Class Classification.
Combines Neural Gas representation learning with per-prototype visibility parameters \(\theta_k\) for one-class classification.
- Parameters:
n_prototypes (int) – Number of prototypes for the target class.
target_label (int, optional) – Which label is the target (normal) class. Default: auto-detect as the most frequent class.
alpha (float) – Balance between representation (R) and classification (C) cost. E = alpha * R + (1 - alpha) * C. Default: 0.5.
cost_function (str) – Classification cost variant: ‘contrastive’, ‘brier’, ‘cross_entropy’. Default: ‘contrastive’.
response_type (str) – Response probability model: ‘gaussian’, ‘student_t’, ‘uniform’. Default: ‘gaussian’.
sigma (float) – Sigmoid sharpness for differentiable Heaviside approximation. Smaller = sharper boundary. Default: 0.1.
gamma_resp (float) – Response bandwidth for Gaussian probabilistic assignment. Default: 1.0.
nu (float) – Degrees of freedom for Student-t response. Default: 1.0.
lambda_init (float, optional) – Initial NG neighborhood range. Default: n_prototypes / 2.
lambda_final (float) – Final NG neighborhood range. Default: 0.01.
lambda_decay (float, optional) – Per-step multiplicative decay for lambda. Default: computed from max_iter.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
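As an illustration of class_weight='balanced', the sketch below computes weights inversely proportional to class frequencies. The normalization n_samples / (n_classes * count) follows the scikit-learn convention and is an assumption here; prosemble may scale the weights differently. The helper name balanced_class_weights is hypothetical.

```python
from collections import Counter


def balanced_class_weights(y):
    """Return {label: weight} with weights inversely proportional to class frequency.

    Assumption: uses the n_samples / (n_classes * count) normalization
    (scikit-learn convention); prosemble's exact scaling may differ.
    """
    counts = Counter(y)
    n_samples = len(y)
    n_classes = len(counts)
    return {label: n_samples / (n_classes * count) for label, count in counts.items()}


# Three samples of class 0 and one of class 1: the rare class gets the larger weight.
weights = balanced_class_weights([0, 0, 0, 1])
```

A dict built this way has the same form as an explicit class_weight mapping such as {0: 1.0, 1: 2.0}.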
- predict(X)[source]¶
Predict target or non-target labels.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels – target_label for target, non_target_label for outliers.
- Return type:
array of shape (n_samples,)
- predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)[source]¶
Predict with a reject option for uncertain samples.
Samples with scores between lower and upper are rejected (labeled reject_label) instead of being forced into a class.
- Parameters:
X (array-like of shape (n_samples, n_features))
upper (float) – Scores >= upper are classified as target. Default: 0.5.
lower (float, optional) – Scores < lower are classified as non-target. Scores in [lower, upper) are rejected. Default: same as upper (no rejection zone, equivalent to predict).
reject_label (int) – Label for rejected samples. Default: -1.
- Returns:
labels
- Return type:
array of shape (n_samples,)
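The thresholding rule described for predict_with_reject can be sketched on precomputed scores. This is a minimal NumPy illustration, not the library's implementation; the function name and the label values 1 and 0 for target and non-target are assumptions.

```python
import numpy as np


def predict_with_reject_sketch(scores, target_label=1, non_target_label=0,
                               upper=0.5, lower=None, reject_label=-1):
    """Threshold scores into target / non-target / rejected labels."""
    scores = np.asarray(scores, dtype=float)
    if lower is None:
        lower = upper  # no rejection zone: behaves like a plain threshold
    labels = np.full(scores.shape, reject_label, dtype=int)  # start as rejected
    labels[scores >= upper] = target_label       # confident target
    labels[scores < lower] = non_target_label    # confident non-target
    return labels


scores = [0.9, 0.45, 0.1]
labels = predict_with_reject_sketch(scores, upper=0.6, lower=0.3)
```

Here the middle score falls in [lower, upper) and receives reject_label, while the other two are assigned to a class.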
- decision_function(X)[source]¶
Compute summed responsibility scores.
Scores near 1.0 indicate target class, near 0.0 indicate outlier.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
scores
- Return type:
array of shape (n_samples,)
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
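How validation_data, patience, and restore_best interact can be sketched with a loop over precomputed validation losses. The function name and the exact improvement test (loss < best - epsilon) are assumptions for illustration; the source does not specify how the comparison is made.

```python
def train_with_patience(val_losses, patience=3, epsilon=1e-6):
    """Simulate early stopping over a precomputed sequence of validation losses.

    Returns (best_epoch, stop_epoch). In a real training loop, the parameters
    from best_epoch are the ones that restore_best=True would bring back.
    """
    best_loss = float("inf")
    best_epoch = -1
    bad_epochs = 0
    stop_epoch = len(val_losses) - 1
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss - epsilon:   # assumed improvement criterion
            best_loss, best_epoch = loss, epoch
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:   # too many consecutive non-improving epochs
                stop_epoch = epoch
                break
    return best_epoch, stop_epoch


best_epoch, stop_epoch = train_with_patience(
    [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.6], patience=3
)
```

In this run the loop stops before reaching the final, lower loss, which shows the trade-off a small patience value makes.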
- class prosemble.models.SVQOCC_R(n_prototypes=3, target_label=None, alpha=0.5, cost_function='contrastive', response_type='gaussian', sigma=0.1, gamma_resp=1.0, nu=1.0, lambda_init=None, lambda_final=0.01, lambda_decay=None, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
Relevance-Weighted SVQ-OCC.
Extends SVQ-OCC with per-feature relevance weighting (like GRLVQ). Learns which features are most important for distinguishing target from non-target data.
- Parameters:
n_prototypes (int) – Number of prototypes for the target class.
target_label (int, optional) – Which label is the target (normal) class. Default: auto-detect as the most frequent class.
alpha (float) – Balance between representation (R) and classification (C) cost. E = alpha * R + (1 - alpha) * C. Default: 0.5.
cost_function (str) – Classification cost variant: ‘contrastive’, ‘brier’, ‘cross_entropy’. Default: ‘contrastive’.
response_type (str) – Response probability model: ‘gaussian’, ‘student_t’, ‘uniform’. Default: ‘gaussian’.
sigma (float) – Sigmoid sharpness for differentiable Heaviside approximation. Smaller = sharper boundary. Default: 0.1.
gamma_resp (float) – Response bandwidth for Gaussian probabilistic assignment. Default: 1.0.
nu (float) – Degrees of freedom for Student-t response. Default: 1.0.
lambda_init (float, optional) – Initial NG neighborhood range. Default: n_prototypes / 2.
lambda_final (float) – Final NG neighborhood range. Default: 0.01.
lambda_decay (float, optional) – Per-step multiplicative decay for lambda. Default: computed from max_iter.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
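The parameters lambda_init, lambda_final, and lambda_decay describe a multiplicative schedule for the NG neighborhood range. A minimal sketch follows, assuming the default decay is chosen so that lambda reaches lambda_final after max_iter steps; the source only states that the default is computed from max_iter, so that formula and the function name are assumptions.

```python
def lambda_schedule(n_prototypes=3, lambda_init=None, lambda_final=0.01,
                    lambda_decay=None, max_iter=100):
    """Return the neighborhood range lambda for steps 0..max_iter."""
    if lambda_init is None:
        lambda_init = n_prototypes / 2  # documented default
    if lambda_decay is None:
        # Assumption: geometric decay that hits lambda_final at step max_iter.
        lambda_decay = (lambda_final / lambda_init) ** (1.0 / max_iter)
    return [lambda_init * lambda_decay ** t for t in range(max_iter + 1)]


lams = lambda_schedule()  # starts at 1.5 and shrinks toward 0.01
```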
- relevances_¶
Learned per-feature relevance weights (softmax-normalized).
- Type:
array of shape (n_features,)
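A relevance-weighted distance in the GRLVQ style can be sketched as below. It assumes the distance is a sum of per-feature squared differences weighted by softmax-normalized relevances; the function names and the exact form used inside SVQOCC_R are assumptions.

```python
import numpy as np


def softmax(v):
    """Numerically stable softmax over a 1-D array."""
    e = np.exp(v - np.max(v))
    return e / e.sum()


def relevance_distances(X, prototypes, raw_relevances):
    """Squared Euclidean distances with per-feature weights.

    X: (n_samples, n_features), prototypes: (n_prototypes, n_features),
    raw_relevances: (n_features,) unconstrained parameters.
    """
    lam = softmax(raw_relevances)                    # non-negative, sums to 1
    diff = X[:, None, :] - prototypes[None, :, :]    # (n_samples, n_prototypes, n_features)
    return np.sum(lam * diff ** 2, axis=-1)          # (n_samples, n_prototypes)


X = np.array([[1.0, 0.0], [0.0, 2.0]])
W = np.array([[0.0, 0.0]])
D = relevance_distances(X, W, np.zeros(2))  # equal relevances of 0.5 each
```

A feature whose relevance approaches zero stops contributing to the distance, which is how such a model can learn to ignore uninformative features.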
See also
SVQOCC – Base SVQ-OCC model.
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict(X)¶
Predict target or non-target labels.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels – target_label for target, non_target_label for outliers.
- Return type:
array of shape (n_samples,)
- predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)¶
Predict with a reject option for uncertain samples.
Samples with scores between lower and upper are rejected (labeled reject_label) instead of being forced into a class.
- Parameters:
X (array-like of shape (n_samples, n_features))
upper (float) – Scores >= upper are classified as target. Default: 0.5.
lower (float, optional) – Scores < lower are classified as non-target. Scores in [lower, upper) are rejected. Default: same as upper (no rejection zone, equivalent to predict).
reject_label (int) – Label for rejected samples. Default: -1.
- Returns:
labels
- Return type:
array of shape (n_samples,)
- class prosemble.models.SVQOCC_M(latent_dim=None, n_prototypes=3, target_label=None, alpha=0.5, cost_function='contrastive', response_type='gaussian', sigma=0.1, gamma_resp=1.0, nu=1.0, lambda_init=None, lambda_final=0.01, lambda_decay=None, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
Matrix SVQ-OCC with global Omega projection.
Learns a global linear projection \(\Omega\) that captures feature correlations for one-class classification.
- Parameters:
latent_dim (int, optional) – Dimensionality of the projected space. Default: n_features.
n_prototypes (int) – Number of prototypes for the target class.
target_label (int, optional) – Which label is the target (normal) class. Default: auto-detect as the most frequent class.
alpha (float) – Balance between representation (R) and classification (C) cost. E = alpha * R + (1 - alpha) * C. Default: 0.5.
cost_function (str) – Classification cost variant: ‘contrastive’, ‘brier’, ‘cross_entropy’. Default: ‘contrastive’.
response_type (str) – Response probability model: ‘gaussian’, ‘student_t’, ‘uniform’. Default: ‘gaussian’.
sigma (float) – Sigmoid sharpness for differentiable Heaviside approximation. Smaller = sharper boundary. Default: 0.1.
gamma_resp (float) – Response bandwidth for Gaussian probabilistic assignment. Default: 1.0.
nu (float) – Degrees of freedom for Student-t response. Default: 1.0.
lambda_init (float, optional) – Initial NG neighborhood range. Default: n_prototypes / 2.
lambda_final (float) – Final NG neighborhood range. Default: 0.01.
lambda_decay (float, optional) – Per-step multiplicative decay for lambda. Default: computed from max_iter.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
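The ema_decay option refers to an exponential moving average of the parameters. The sketch below shows the standard update rule on toy arrays; initializing the average with the first parameter values is an assumption, and the helper name is hypothetical.

```python
import numpy as np


def ema_update(ema_params, params, decay):
    """One EMA step: ema <- decay * ema + (1 - decay) * params."""
    return decay * ema_params + (1.0 - decay) * params


params_history = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]
ema = params_history[0].copy()   # assumption: EMA starts at the first parameters
for p in params_history[1:]:
    ema = ema_update(ema, p, decay=0.9)
```

With decay values such as 0.999, each new parameter vector moves the average only slightly, which smooths out step-to-step noise.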
- omega_¶
Learned projection matrix.
- Type:
array of shape (n_features, latent_dim)
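Given the documented shape (n_features, latent_dim), a global projection distance can be sketched as the squared norm of the projected difference. The multiplication order (difference times omega) is an assumption consistent with that shape, and the function name is hypothetical.

```python
import numpy as np


def omega_distances(X, prototypes, omega):
    """Squared distances after projecting differences with omega.

    omega: (n_features, latent_dim), matching the documented omega_ shape.
    """
    diff = X[:, None, :] - prototypes[None, :, :]   # (n_samples, n_prototypes, n_features)
    projected = diff @ omega                        # (n_samples, n_prototypes, latent_dim)
    return np.sum(projected ** 2, axis=-1)          # (n_samples, n_prototypes)


X = np.array([[1.0, 2.0, 3.0]])
W = np.array([[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]])
omega = np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])  # keeps features 0 and 1, drops feature 2
D = omega_distances(X, W, omega)
```

Because the third row of omega is zero in this example, differences in the third feature do not affect the distance.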
See also
SVQOCC – Base SVQ-OCC model.
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict(X)¶
Predict target or non-target labels.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels – target_label for target, non_target_label for outliers.
- Return type:
array of shape (n_samples,)
- predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)¶
Predict with a reject option for uncertain samples.
Samples with scores between lower and upper are rejected (labeled reject_label) instead of being forced into a class.
- Parameters:
X (array-like of shape (n_samples, n_features))
upper (float) – Scores >= upper are classified as target. Default: 0.5.
lower (float, optional) – Scores < lower are classified as non-target. Scores in [lower, upper) are rejected. Default: same as upper (no rejection zone, equivalent to predict).
reject_label (int) – Label for rejected samples. Default: -1.
- Returns:
labels
- Return type:
array of shape (n_samples,)
- class prosemble.models.SVQOCC_LM(latent_dim=None, n_prototypes=3, target_label=None, alpha=0.5, cost_function='contrastive', response_type='gaussian', sigma=0.1, gamma_resp=1.0, nu=1.0, lambda_init=None, lambda_final=0.01, lambda_decay=None, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
Local Matrix SVQ-OCC with per-prototype Omega projections.
Each prototype learns its own local metric via \(\Omega_k\), allowing different prototypes to attend to different feature subspaces.
- Parameters:
latent_dim (int, optional) – Dimensionality of each projected space. Default: n_features.
n_prototypes (int) – Number of prototypes for the target class.
target_label (int, optional) – Which label is the target (normal) class. Default: auto-detect as the most frequent class.
alpha (float) – Balance between representation (R) and classification (C) cost. E = alpha * R + (1 - alpha) * C. Default: 0.5.
cost_function (str) – Classification cost variant: ‘contrastive’, ‘brier’, ‘cross_entropy’. Default: ‘contrastive’.
response_type (str) – Response probability model: ‘gaussian’, ‘student_t’, ‘uniform’. Default: ‘gaussian’.
sigma (float) – Sigmoid sharpness for differentiable Heaviside approximation. Smaller = sharper boundary. Default: 0.1.
gamma_resp (float) – Response bandwidth for Gaussian probabilistic assignment. Default: 1.0.
nu (float) – Degrees of freedom for Student-t response. Default: 1.0.
lambda_init (float, optional) – Initial NG neighborhood range. Default: n_prototypes / 2.
lambda_final (float) – Final NG neighborhood range. Default: 0.01.
lambda_decay (float, optional) – Per-step multiplicative decay for lambda. Default: computed from max_iter.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
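The roles of sync_period and slow_step_size in the lookahead option can be sketched with plain SGD as the inner optimizer. This follows the standard lookahead scheme and is not taken from prosemble's source; the function name and the toy objective are assumptions.

```python
import numpy as np


def lookahead_sgd(grad_fn, params, lr=0.1, n_steps=12, sync_period=6, slow_step_size=0.5):
    """Plain SGD wrapped in lookahead; returns the final slow weights."""
    slow = params.copy()
    fast = params.copy()
    for step in range(1, n_steps + 1):
        fast = fast - lr * grad_fn(fast)                  # inner (fast) optimizer step
        if step % sync_period == 0:
            slow = slow + slow_step_size * (fast - slow)  # interpolate toward fast weights
            fast = slow.copy()                            # restart fast weights from slow ones
    return slow


# Minimize f(w) = 0.5 * ||w||^2, whose gradient is w itself.
w0 = np.array([1.0, -2.0])
final = lookahead_sgd(lambda w: w, w0)
```

With slow_step_size=0.5 the slow weights move halfway toward the fast weights at every synchronization, trading some speed for stability.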
- omegas_¶
Learned per-prototype projection matrices.
- Type:
array of shape (n_prototypes, n_features, latent_dim)
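With one projection matrix per prototype, the distance from a sample to prototype k uses only that prototype's matrix. The sketch below assumes the documented omegas_ layout (n_prototypes, n_features, latent_dim) and the same difference-times-omega convention as a global projection; the function name is hypothetical.

```python
import numpy as np


def local_omega_distances(X, prototypes, omegas):
    """Squared distances using one projection matrix per prototype.

    omegas: (n_prototypes, n_features, latent_dim), matching omegas_.
    """
    diff = X[:, None, :] - prototypes[None, :, :]          # (n, k, d)
    projected = np.einsum("nkd,kdl->nkl", diff, omegas)    # (n, k, latent_dim)
    return np.sum(projected ** 2, axis=-1)                 # (n, k)


X = np.array([[1.0, 2.0]])
W = np.zeros((2, 2))
omegas = np.stack([
    np.array([[1.0], [0.0]]),   # prototype 0 attends to feature 0 only
    np.array([[0.0], [1.0]]),   # prototype 1 attends to feature 1 only
])
D = local_omega_distances(X, W, omegas)
```

Both prototypes sit at the origin here, yet they report different distances to the same sample because each one measures a different feature.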
See also
SVQOCC – Base SVQ-OCC model.
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict(X)¶
Predict target or non-target labels.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels – target_label for target, non_target_label for outliers.
- Return type:
array of shape (n_samples,)
- predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)¶
Predict with a reject option for uncertain samples.
Samples with scores between lower and upper are rejected (labeled reject_label) instead of being forced into a class.
- Parameters:
X (array-like of shape (n_samples, n_features))
upper (float) – Scores >= upper are classified as target. Default: 0.5.
lower (float, optional) – Scores < lower are classified as non-target. Scores in [lower, upper) are rejected. Default: same as upper (no rejection zone, equivalent to predict).
reject_label (int) – Label for rejected samples. Default: -1.
- Returns:
labels
- Return type:
array of shape (n_samples,)
- class prosemble.models.SVQOCC_T(subspace_dim=2, n_prototypes=3, target_label=None, alpha=0.5, cost_function='contrastive', response_type='gaussian', sigma=0.1, gamma_resp=1.0, nu=1.0, lambda_init=None, lambda_final=0.01, lambda_decay=None, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
Tangent SVQ-OCC with per-prototype tangent subspaces.
Each prototype learns an orthonormal basis \(\Omega_k\) that defines directions of local invariance. Only the distance orthogonal to this subspace is used for classification.
- Parameters:
subspace_dim (int) – Dimensionality of each tangent subspace. Default: 2.
n_prototypes (int) – Number of prototypes for the target class.
target_label (int, optional) – Which label is the target (normal) class. Default: auto-detect as the most frequent class.
alpha (float) – Balance between representation (R) and classification (C) cost. E = alpha * R + (1 - alpha) * C. Default: 0.5.
cost_function (str) – Classification cost variant: ‘contrastive’, ‘brier’, ‘cross_entropy’. Default: ‘contrastive’.
response_type (str) – Response probability model: ‘gaussian’, ‘student_t’, ‘uniform’. Default: ‘gaussian’.
sigma (float) – Sigmoid sharpness for differentiable Heaviside approximation. Smaller = sharper boundary. Default: 0.1.
gamma_resp (float) – Response bandwidth for Gaussian probabilistic assignment. Default: 1.0.
nu (float) – Degrees of freedom for Student-t response. Default: 1.0.
lambda_init (float, optional) – Initial NG neighborhood range. Default: n_prototypes / 2.
lambda_final (float) – Final NG neighborhood range. Default: 0.01.
lambda_decay (float, optional) – Per-step multiplicative decay for lambda. Default: computed from max_iter.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
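The statement that the effective batch size equals batch_size * gradient_accumulation_steps can be checked on a toy least-squares problem. The sketch assumes micro-batch gradients are averaged before the update (a library may instead sum them and rescale the learning rate); all names are hypothetical.

```python
import numpy as np


def mse_grad(w, X, y):
    """Gradient of 0.5 * mean((X @ w - y)**2) with respect to w."""
    return X.T @ (X @ w - y) / len(y)


rng = np.random.default_rng(0)
X = rng.normal(size=(8, 3))
y = rng.normal(size=8)
w = np.zeros(3)

batch_size, accumulation_steps = 2, 4    # effective batch size: 2 * 4 = 8
accumulated = np.zeros_like(w)
for i in range(accumulation_steps):
    sl = slice(i * batch_size, (i + 1) * batch_size)
    accumulated += mse_grad(w, X[sl], y[sl])   # gradient of one micro-batch
accumulated /= accumulation_steps              # average the micro-batch gradients

full_batch = mse_grad(w, X, y)                 # gradient on all 8 samples at once
```

For equal-sized micro-batches and a mean loss, the averaged gradient matches the full-batch gradient, so accumulation reproduces a larger batch at lower peak memory.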
- omegas_¶
Learned per-prototype orthonormal tangent bases.
- Type:
array of shape (n_prototypes, n_features, subspace_dim)
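The distance orthogonal to a tangent subspace can be sketched by removing the component of the difference that lies inside the subspace. This assumes each omegas_[k] has orthonormal columns, as the description states; the function name and the einsum formulation are illustrative, not the library's code.

```python
import numpy as np


def tangent_distances(X, prototypes, omegas):
    """Squared distance of each sample to each prototype's affine tangent subspace.

    omegas: (n_prototypes, n_features, subspace_dim) with orthonormal columns.
    """
    diff = X[:, None, :] - prototypes[None, :, :]            # (n, k, d)
    coeffs = np.einsum("nkd,kds->nks", diff, omegas)         # coordinates inside each subspace
    in_subspace = np.einsum("nks,kds->nkd", coeffs, omegas)  # component along the subspace
    residual = diff - in_subspace                            # component orthogonal to it
    return np.sum(residual ** 2, axis=-1)                    # (n, k)


X = np.array([[3.0, 4.0, 5.0]])
W = np.zeros((1, 3))
omegas = np.array([[[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]])  # subspace spanned by axes 0 and 1
D = tangent_distances(X, W, omegas)
```

Movement of the sample along the first two axes leaves the distance unchanged, which is the local invariance the tangent subspace encodes.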
See also
SVQOCC – Base SVQ-OCC model.
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from what n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- predict(X)¶
Predict target or non-target labels.
- Parameters:
X (array-like of shape (n_samples, n_features))
- Returns:
labels – target_label for target, non_target_label for outliers.
- Return type:
array of shape (n_samples,)
- predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)¶
Predict with a reject option for uncertain samples.
Samples with scores between lower and upper are rejected (labeled reject_label) instead of being forced into a class.
- Parameters:
X (array-like of shape (n_samples, n_features))
upper (float) – Scores >= upper are classified as target. Default: 0.5.
lower (float, optional) – Scores < lower are classified as non-target. Scores in [lower, upper) are rejected. Default: same as upper (no rejection zone, equivalent to predict).
reject_label (int) – Label for rejected samples. Default: -1.
- Returns:
labels
- Return type:
array of shape (n_samples,)
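The thresholding rule described for predict_with_reject can be sketched on raw scores as follows. This is an illustration only: the function name, the way scores are obtained, and the concrete values used for target_label and non_target_label are assumptions, not the library's implementation.

```python
import numpy as np

def predict_with_reject_sketch(scores, upper=0.5, lower=None,
                               target_label=1, non_target_label=0,
                               reject_label=-1):
    """Apply the upper/lower thresholds to an array of scores."""
    scores = np.asarray(scores, dtype=float)
    if lower is None:
        lower = upper  # no rejection zone: behaves like a single threshold
    # Start with everything rejected, then fill in the two confident regions.
    labels = np.full(scores.shape, reject_label, dtype=int)
    labels[scores >= upper] = target_label
    labels[scores < lower] = non_target_label
    return labels

scores = np.array([0.1, 0.35, 0.5, 0.9])
print(predict_with_reject_sketch(scores, upper=0.5, lower=0.3))
# [ 0 -1  1  1]
```

Only the score 0.35 falls in [lower, upper) and is rejected. Leaving lower unset collapses the rejection zone, so every sample receives a class label.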
Fuzzy Clustering¶
- class prosemble.models.FCM(n_clusters, fuzzifier=2.0, max_iter=100, epsilon=1e-05, init_method='random', random_seed=42, distance_fn=None, patience=None, restore_best=False, plot_steps=False, show_confidence=True, show_pca_variance=True, save_plot_path=None, callbacks=None)[source]¶
JAX implementation of Fuzzy C-Means clustering.
This implementation provides:
- Full vectorization (no Python loops)
- JIT compilation for speed
- Automatic GPU acceleration
- Immutable state management
- Numerical stability
Key Differences from NumPy Version:¶
Vectorization: All operations use matrix operations.
- Old: Triple nested loops in centroid computation
- New: Single matrix multiplication
Functional: Immutable state using NamedTuple.
- Old: In-place updates (self.fit_cent = …)
- New: Return new state objects
JIT Compilation: Functions compiled to machine code.
- Old: Interpreted Python loops
- New: Compiled XLA code
GPU Support: Automatic device placement.
- Old: CPU-only NumPy
- New: GPU/TPU with JAX
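To make the "single matrix multiplication" point concrete, here is a minimal NumPy sketch of one standard Fuzzy C-Means iteration (Bezdek's update rules). The function name fcm_step and the eps guard are assumptions made for illustration. The library itself runs on JAX, and its internals may differ.

```python
import numpy as np

def fcm_step(X, U, m=2.0, eps=1e-10):
    """One FCM iteration: centroids from U, then U from the new centroids."""
    W = U ** m                                        # (n_samples, n_clusters)
    # Centroid update as one matrix product instead of nested loops.
    centroids = (W.T @ X) / W.sum(axis=0)[:, None]    # (n_clusters, n_features)
    # Squared Euclidean distances between every sample and every centroid.
    d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    d2 = np.maximum(d2, eps)                          # avoid division by zero
    inv = d2 ** (-1.0 / (m - 1.0))
    U_new = inv / inv.sum(axis=1, keepdims=True)      # each row sums to 1
    return centroids, U_new

rng = np.random.default_rng(0)
X = np.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8], [1, 0.6], [9, 11]], dtype=float)
U = rng.dirichlet(np.ones(2), size=X.shape[0])        # random Dirichlet init
for _ in range(20):
    centroids, U = fcm_step(X, U)

print(centroids.shape, U.sum(axis=1))
```

The exponent -1/(m - 1) is where the fuzzifier enters: as m approaches 1 the rows of U approach one-hot vectors, and as m grows they approach the uniform distribution.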
- Parameters:
fuzzifier (float, default=2.0) – Fuzzification parameter (\(m\)). Must be > 1. As \(m \to 1\), memberships approach hard (crisp) clustering; \(m = 2\) is the standard fuzzy setting; as \(m \to \infty\), memberships approach maximum fuzziness (equal membership).
init_method (str, default='random') – Initialization method for the \(U\) matrix: ‘random’ (random Dirichlet distribution) or ‘kmeans++’ (K-means++ centroids, then compute \(U\)).
n_clusters (int) – Number of clusters (must be >= 2).
max_iter (int) – Maximum number of iterations.
epsilon (float) – Convergence threshold.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Pairwise distance function. Default: squared Euclidean.
patience (int, optional) – Epochs with no improvement before early stopping. Default: None.
restore_best (bool) – If True, restore centroids from the lowest-objective epoch. Default: False.
plot_steps (bool) – Whether to visualize clustering progress. Default: False.
show_confidence (bool) – Whether to show confidence in visualization. Default: True.
show_pca_variance (bool) – Whether to show PCA variance in visualization. Default: True.
save_plot_path (str, optional) – Path to save final plot.
callbacks (list, optional) – List of Callback objects for monitoring/visualization.
- centroids_¶
Cluster centroids after fitting
- Type:
array of shape (n_clusters, n_features)
- U_¶
Fuzzy membership matrix after fitting
- Type:
array of shape (n_samples, n_clusters)
Examples
>>> import jax.numpy as jnp
>>> from prosemble.models import FCM
>>>
>>> # Generate sample data
>>> X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8], [1, 0.6], [9, 11]])
>>>
>>> # Fit FCM model
>>> model = FCM(n_clusters=2, fuzzifier=2.0, max_iter=100)
>>> model.fit(X)
>>>
>>> # Get results
>>> labels = model.predict(X)
>>> centroids = model.final_centroids()
>>> membership = model.predict_proba(X)
>>>
>>> print(f"Labels: {labels}")
>>> print(f"Centroids shape: {centroids.shape}")
>>> print(f"Membership shape: {membership.shape}")
References
Bezdek, J. C. (1981). Pattern Recognition with Fuzzy Objective Function Algorithms. Plenum Press, New York.
Dunn, J. C. (1973). A Fuzzy Relative of the ISODATA Process and Its Use in Detecting Compact Well-Separated Clusters. Journal of Cybernetics, 3(3), 32–57.
- predict(X)[source]¶
Predict cluster labels for X.
Assigns each sample to the cluster with highest membership.
- Parameters:
X (Array) – (n_samples, n_features) data
- Returns:
(n_samples,) cluster assignments (0 to n_clusters-1)
- Return type:
labels
- Raises:
RuntimeError – If model not fitted yet
- predict_proba(X)[source]¶
Predict fuzzy membership for X.
- Parameters:
X (Array) – (n_samples, n_features) data
- Returns:
- (n_samples, n_clusters) fuzzy membership matrix
Each row sums to 1, values in [0, 1]
- Return type:
U
- Raises:
RuntimeError – If model not fitted yet
- class prosemble.models.PCM(n_clusters, fuzzifier=2.0, k=1.0, max_iter=100, epsilon=1e-05, init_method='fcm', random_seed=None, distance_fn=None, patience=None, restore_best=False, plot_steps=False, show_confidence=True, show_pca_variance=True, save_plot_path=None, callbacks=None)[source]¶
JAX-based Possibilistic C-Means clustering with GPU acceleration.
PCM is a clustering algorithm that assigns typicality values to data points, representing the degree to which they belong to each cluster. Unlike FCM, the typicality of a point to one cluster is independent of its typicality to other clusters.
- Parameters:
fuzzifier (float, default=2.0) – Fuzzification parameter (\(m > 1\)). Higher values result in fuzzier clusters.
k (float, default=1.0) – Parameter for \(\gamma\) computation. Typical values are in [0.01, 1.0]. Lower values make the algorithm more sensitive to outliers.
init_method ({'fcm', 'random'}, default='fcm') – Initialization method: ‘fcm’ (initialize from FCM results; recommended) or ‘random’ (random initialization).
n_clusters (int) – Number of clusters (must be >= 2).
max_iter (int) – Maximum number of iterations.
epsilon (float) – Convergence threshold.
random_seed (int, optional) – Random seed for reproducibility. Default: None.
distance_fn (callable, optional) – Pairwise distance function. Default: squared Euclidean.
patience (int, optional) – Epochs with no improvement before early stopping. Default: None.
restore_best (bool) – If True, restore centroids from the lowest-objective epoch. Default: False.
plot_steps (bool) – Whether to visualize clustering progress. Default: False.
show_confidence (bool) – Whether to show confidence in visualization. Default: True.
show_pca_variance (bool) – Whether to show PCA variance in visualization. Default: True.
save_plot_path (str, optional) – Path to save final plot.
callbacks (list, optional) – List of Callback objects for monitoring/visualization.
- centroids_¶
Cluster centroids after fitting.
- Type:
ndarray of shape (n_clusters, n_features)
- T_¶
Typicality matrix after fitting.
- Type:
ndarray of shape (n_samples, n_clusters)
- gamma_¶
Scale parameters for each cluster.
- Type:
ndarray of shape (n_clusters,)
Examples
>>> import jax.numpy as jnp
>>> from prosemble.models import PCM
>>>
>>> # Generate sample data
>>> X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8], [1, 0.6], [9, 11]])
>>>
>>> # Fit PCM model
>>> model = PCM(n_clusters=2, fuzzifier=2.0, k=1.0)
>>> model.fit(X)
>>>
>>> # Get cluster assignments
>>> labels = model.predict(X)
>>>
>>> # Get typicality values
>>> typicalities = model.predict_proba(X)
Notes
PCM is less sensitive to outliers than FCM because typicality values are computed independently for each cluster.
The parameter \(k\) controls the sensitivity to outliers. Smaller values make the algorithm more sensitive.
Initialization from FCM (init_method=’fcm’) is recommended as it provides better starting points than random initialization.
All computations are JIT-compiled and can run on GPU if available.
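The independence of typicalities noted above can be illustrated with the classical possibilistic update of Krishnapuram and Keller, \(t_{ij} = 1 / (1 + (d_{ij}^2 / \gamma_j)^{1/(m-1)})\). This is a sketch under the assumption that PCM follows that standard rule. The function name is hypothetical, and \(\gamma\) is supplied by hand rather than computed from k.

```python
import numpy as np

def pcm_typicality(d2, gamma, m=2.0):
    """Typicality of each sample to each cluster, computed per cluster."""
    return 1.0 / (1.0 + (d2 / gamma[None, :]) ** (1.0 / (m - 1.0)))

# Squared distances of three samples to two centroids.
# The last sample is an outlier that is far from both clusters.
d2 = np.array([[0.1, 9.0],
               [0.2, 8.0],
               [50.0, 60.0]])
gamma = np.array([1.0, 1.0])

T = pcm_typicality(d2, gamma)
print(T.round(3))
print(T.sum(axis=1).round(3))  # rows are not forced to sum to 1
```

The outlier receives low typicality for both clusters. FCM's row-sum-to-1 constraint would instead force its memberships to total 1, which is the mechanism behind PCM's reduced sensitivity to outliers.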
- class prosemble.models.FPCM(n_clusters, fuzzifier=2.0, eta=2.0, max_iter=100, epsilon=1e-05, init_method='fcm', random_seed=42, distance_fn=None, patience=None, restore_best=False, plot_steps=False, show_confidence=True, show_pca_variance=True, save_plot_path=None, callbacks=None)[source]¶
Fuzzy Possibilistic C-Means clustering with JAX.
FPCM maintains TWO matrices: \(U\) (fuzzy membership) and \(T\) (typicality). \(U\) has row-sum-to-1 constraint (standard FCM), while \(T\) has column-sum-to-1 constraint per the original Pal, Pal & Bezdek (1997) formulation.
Algorithm:
Initialize \(U\) and \(T\) (randomly or using FCM)
Update centroids using combined fuzzy and typicality weights
Update \(U\) using FCM rule with fuzzifier \(m\) (row-normalized)
Update \(T\) with column-normalization
Repeat until convergence
Objective function:
\[J = \sum_i \sum_j \left[u_{ij}^m + t_{ij}^\eta\right] \|x_i - v_j\|^2\]
- Reference:
Pal, N. R., Pal, K., & Bezdek, J. C. (1997). A mixed c-means clustering model. FUZZ-IEEE.
- Parameters:
fuzzifier (float, default=2.0) – Fuzziness parameter for \(U\) matrix (must be > 1.0).
eta (float, default=2.0) – Fuzziness parameter for \(T\) matrix (must be > 1.0).
init_method ({'random', 'fcm'}, default='fcm') – Method for initializing \(U\) and \(T\) matrices.
n_clusters (int) – Number of clusters (must be >= 2).
max_iter (int) – Maximum number of iterations.
epsilon (float) – Convergence threshold.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Pairwise distance function. Default: squared Euclidean.
patience (int, optional) – Epochs with no improvement before early stopping. Default: None.
restore_best (bool) – If True, restore centroids from the lowest-objective epoch. Default: False.
plot_steps (bool) – Whether to visualize clustering progress. Default: False.
show_confidence (bool) – Whether to show confidence in visualization. Default: True.
show_pca_variance (bool) – Whether to show PCA variance in visualization. Default: True.
save_plot_path (str, optional) – Path to save final plot.
callbacks (list, optional) – List of Callback objects for monitoring/visualization.
- centroids_¶
Final cluster centroids
- Type:
array, shape (n_clusters, n_features)
- U_¶
Final fuzzy membership matrix
- Type:
array, shape (n_samples, n_clusters)
- T_¶
Final possibilistic typicality matrix
- Type:
array, shape (n_samples, n_clusters)
- objective_history_¶
Objective value at each iteration
- Type:
array
Examples
>>> import jax.numpy as jnp
>>> from prosemble.models import FPCM
>>> X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8]])
>>> model = FPCM(n_clusters=2, fuzzifier=2.0, eta=2.0, random_seed=42)
>>> model.fit(X)
>>> labels = model.predict(X)
>>> U = model.predict_proba(X)
>>> T = model.get_typicality(X)
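The two normalization constraints in the FPCM description (rows of \(U\) sum to 1, columns of \(T\) sum to 1) can be sketched in NumPy as below. The function names and the fixed centroids are assumptions for illustration. The update formulas follow the usual Pal, Pal and Bezdek form and may differ in detail from the library code.

```python
import numpy as np

def fpcm_memberships(d2, m=2.0, eta=2.0, eps=1e-10):
    """Return (U, T) for squared distances d2 of shape (n_samples, n_clusters)."""
    d2 = np.maximum(d2, eps)
    u = d2 ** (-1.0 / (m - 1.0))
    U = u / u.sum(axis=1, keepdims=True)   # normalize across clusters (rows)
    t = d2 ** (-1.0 / (eta - 1.0))
    T = t / t.sum(axis=0, keepdims=True)   # normalize across samples (columns)
    return U, T

def fpcm_objective(d2, U, T, m=2.0, eta=2.0):
    """J = sum_i sum_j (u_ij^m + t_ij^eta) * d_ij^2."""
    return ((U ** m + T ** eta) * d2).sum()

X = np.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8]], dtype=float)
V = np.array([[1.25, 1.9], [6.5, 8.0]])    # hand-picked centroids
d2 = ((X[:, None, :] - V[None, :, :]) ** 2).sum(axis=-1)

U, T = fpcm_memberships(d2)
print(U.sum(axis=1), T.sum(axis=0), fpcm_objective(d2, U, T))
```

Because \(T\) is normalized over samples, its entries shrink as the dataset grows, which is a known limitation of FPCM and one motivation for PFCM.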
- fit(X, initial_centroids=None, resume=False)[source]¶
Fit FPCM model to data.
- Parameters:
- Returns:
Fitted model
- Return type:
self
- Raises:
ValueError – If n_samples < n_clusters
- predict(X)[source]¶
Predict cluster labels for new data.
- Parameters:
X (Array | ndarray | bool | number) – Input data, shape (n_samples, n_features)
- Returns:
Cluster labels, shape (n_samples,)
- Return type:
labels
- Raises:
ValueError – If model has not been fitted
- class prosemble.models.PFCM(n_clusters, fuzzifier=2.0, eta=2.0, a=1.0, b=1.0, k=1.0, max_iter=100, epsilon=1e-05, init_method='fcm', random_seed=42, distance_fn=None, patience=None, restore_best=False, plot_steps=False, show_confidence=True, show_pca_variance=True, save_plot_path=None, callbacks=None)[source]¶
JAX implementation of Possibilistic Fuzzy C-Means clustering.
PFCM combines fuzzy membership and typicality for robust clustering.
- Parameters:
fuzzifier (float, default=2.0) – Fuzzification parameter for membership (\(m\)). Must be > 1.
eta (float, default=2.0) – Fuzzification parameter for typicality (\(\eta\)). Must be > 1.
a (float, default=1.0) – Weight for fuzzy membership term. Must be >= 0.
b (float, default=1.0) – Weight for typicality term. Must be >= 0.
k (float, default=1.0) – Parameter for \(\gamma\) computation. Must be > 0.
init_method (str, default='fcm') – Initialization method: ‘fcm’ or ‘random’.
n_clusters (int) – Number of clusters (must be >= 2).
max_iter (int) – Maximum number of iterations.
epsilon (float) – Convergence threshold.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Pairwise distance function. Default: squared Euclidean.
patience (int, optional) – Epochs with no improvement before early stopping. Default: None.
restore_best (bool) – If True, restore centroids from the lowest-objective epoch. Default: False.
plot_steps (bool) – Whether to visualize clustering progress. Default: False.
show_confidence (bool) – Whether to show confidence in visualization. Default: True.
show_pca_variance (bool) – Whether to show PCA variance in visualization. Default: True.
save_plot_path (str, optional) – Path to save final plot.
callbacks (list, optional) – List of Callback objects for monitoring/visualization.
- centroids_¶
Final cluster centroids
- Type:
array
- U_¶
Final fuzzy membership matrix
- Type:
array
- T_¶
Final typicality matrix
- Type:
array
- gamma_¶
Final scale parameters
- Type:
array
- class prosemble.models.AFCM(n_clusters, fuzzifier=2.0, a=1.0, b=1.0, k=1.0, max_iter=100, epsilon=1e-05, init_method='fcm', random_seed=42, distance_fn=None, patience=None, restore_best=False, plot_steps=False, show_confidence=True, show_pca_variance=True, save_plot_path=None, callbacks=None)[source]¶
Adaptive Fuzzy C-Means clustering with JAX.
AFCM is an adaptive variant that combines fuzzy and possibilistic approaches with specific parameter combinations.
Key features:
- Centroids use \(a \cdot U^m + b \cdot T\) (\(T\) to power 1, not \(m\)!)
- \(\gamma\) computed with Euclidean distance (not squared)
- Exponential \(T\) update with parameter \(b\)
- Standard FCM \(U\) update
Algorithm:
Initialize \(U\) using FCM
Compute \(\gamma\) parameters using Euclidean distance
Update \(T\) using exponential update
Update \(U\) using standard FCM rule
Update centroids using combined fuzzy-possibilistic weights
Repeat until convergence
Objective function:
\[J = \sum_i \sum_j \left[d_{ij}^2 \cdot (a \cdot u_{ij}^m + b \cdot t_{ij})\right] + \sum_j \left[\gamma_j \cdot \sum_i (t_{ij} \log t_{ij} - t_{ij})\right]\]
- Parameters:
fuzzifier (float, default=2.0) – Fuzziness parameter (must be > 1.0).
a (float, default=1.0) – Weight for fuzzy membership term (must be > 0).
b (float, default=1.0) – Weight for typicality term (must be > 0).
k (float, default=1.0) – Scaling parameter for \(\gamma\) (must be > 0).
init_method ({'fcm'}, default='fcm') – Initialization method.
n_clusters (int) – Number of clusters (must be >= 2).
max_iter (int) – Maximum number of iterations.
epsilon (float) – Convergence threshold.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Pairwise distance function. Default: squared Euclidean.
patience (int, optional) – Epochs with no improvement before early stopping. Default: None.
restore_best (bool) – If True, restore centroids from the lowest-objective epoch. Default: False.
plot_steps (bool) – Whether to visualize clustering progress. Default: False.
show_confidence (bool) – Whether to show confidence in visualization. Default: True.
show_pca_variance (bool) – Whether to show PCA variance in visualization. Default: True.
save_plot_path (str, optional) – Path to save final plot.
callbacks (list, optional) – List of Callback objects for monitoring/visualization.
- centroids_¶
Final cluster centroids
- Type:
array, shape (n_clusters, n_features)
- U_¶
Final fuzzy membership matrix
- Type:
array, shape (n_samples, n_clusters)
- T_¶
Final typicality matrix
- Type:
array, shape (n_samples, n_clusters)
- gamma_¶
Final scale parameters
- Type:
array, shape (n_clusters,)
- objective_history_¶
Objective values at each iteration
- Type:
array
Examples
>>> import jax.numpy as jnp
>>> from prosemble.models import AFCM
>>> X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8]])
>>> model = AFCM(n_clusters=2, fuzzifier=2.0, a=1.0, b=1.0, k=1.0, random_seed=42)
>>> model.fit(X)
>>> labels = model.predict(X)
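The AFCM centroid rule listed under "Key features" weights each sample by \(a \cdot u_{ij}^m + b \cdot t_{ij}\). Setting the gradient of the objective above with respect to \(v_j\) to zero gives a weighted mean, sketched here in NumPy. The function name is hypothetical, and \(U\) and \(T\) are supplied by hand rather than taken from a fitted model.

```python
import numpy as np

def afcm_centroids(X, U, T, a=1.0, b=1.0, m=2.0):
    """Weighted means with weights a * U^m + b * T (T enters to the power 1)."""
    W = a * U ** m + b * T                       # (n_samples, n_clusters)
    return (W.T @ X) / W.sum(axis=0)[:, None]    # (n_clusters, n_features)

X = np.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8]], dtype=float)
U = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
T = np.array([[0.9, 0.0], [0.9, 0.0], [0.0, 0.5], [0.0, 0.1]])

# With b = 0 only the fuzzy term remains, giving plain cluster means.
print(afcm_centroids(X, U, T, b=0.0))
# With b > 0 the low-typicality sample [8, 8] pulls its centroid less.
print(afcm_centroids(X, U, T, b=1.0))
```

In the second call the second centroid moves toward [5, 8], because that sample has the higher typicality.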
- class prosemble.models.HCM(n_clusters, max_iter=100, epsilon=1e-05, init_method='random', random_seed=42, distance_fn=None, patience=None, restore_best=False, plot_steps=False, show_confidence=True, show_pca_variance=True, save_plot_path=None, callbacks=None)[source]¶
Hard C-Means (K-Means) clustering with JAX.
HCM assigns each data point to exactly one cluster (hard assignment) based on the nearest centroid. This is the classic K-Means algorithm.
Algorithm:
Initialize centroids randomly or from data
Assign each point to nearest centroid
Update centroids as mean of assigned points
Repeat until convergence
Objective function:
\[J = \sum_i \|x_i - v_{l_i}\|^2\]
- Parameters:
init_method ({'random', 'kmeans++'}, default='random') – Method for initializing centroids.
n_clusters (int) – Number of clusters (must be >= 2).
max_iter (int) – Maximum number of iterations.
epsilon (float) – Convergence threshold.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Pairwise distance function. Default: squared Euclidean.
patience (int, optional) – Epochs with no improvement before early stopping. Default: None.
restore_best (bool) – If True, restore centroids from the lowest-objective epoch. Default: False.
plot_steps (bool) – Whether to visualize clustering progress. Default: False.
show_confidence (bool) – Whether to show confidence in visualization. Default: True.
show_pca_variance (bool) – Whether to show PCA variance in visualization. Default: True.
save_plot_path (str, optional) – Path to save final plot.
callbacks (list, optional) – List of Callback objects for monitoring/visualization.
- centroids_¶
Final cluster centroids
- Type:
array, shape (n_clusters, n_features)
- labels_¶
Hard cluster assignments for training data
- Type:
array, shape (n_samples,)
- objective_history_¶
Objective value at each iteration
- Type:
array
Examples
>>> import jax.numpy as jnp
>>> from prosemble.models import HCM
>>> X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8]])
>>> model = HCM(n_clusters=2, random_seed=42)
>>> model.fit(X)
>>> labels = model.predict(X)
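The four algorithm steps listed for HCM correspond to the classic Lloyd iteration. A minimal NumPy sketch follows. The function hcm_fit is hypothetical and uses a Python loop for readability, whereas the library version is JIT-compiled with JAX.

```python
import numpy as np

def hcm_fit(X, n_clusters, max_iter=100, epsilon=1e-5, seed=42):
    """Plain K-Means: assign to nearest centroid, then recompute means."""
    rng = np.random.default_rng(seed)
    # Initialize centroids from distinct data points (fancy indexing copies).
    centroids = X[rng.choice(len(X), size=n_clusters, replace=False)]
    prev_obj = np.inf
    for _ in range(max_iter):
        d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
        labels = d2.argmin(axis=1)                     # hard assignment
        obj = d2[np.arange(len(X)), labels].sum()      # J = sum_i ||x_i - v_{l_i}||^2
        for j in range(n_clusters):
            members = X[labels == j]
            if len(members) > 0:                       # keep empty clusters in place
                centroids[j] = members.mean(axis=0)
        if abs(prev_obj - obj) < epsilon:              # objective stopped changing
            break
        prev_obj = obj
    return centroids, labels, obj

X = np.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8]], dtype=float)
centroids, labels, obj = hcm_fit(X, n_clusters=2)
print(labels, round(obj, 3))
```

On this four-point dataset the two left points and the two right points end up in separate clusters, with an objective of 4.645.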
- class prosemble.models.IPCM(n_clusters, fuzzifier=2.0, tipifier=2.0, k=1.0, max_iter=100, epsilon=1e-05, init_method='fcm', random_seed=42, distance_fn=None, patience=None, restore_best=False, plot_steps=False, show_confidence=True, show_pca_variance=True, save_plot_path=None, callbacks=None)[source]¶
Improved Possibilistic C-Means clustering with JAX.
IPCM uses a two-phase approach to improve clustering performance:
- Phase 0: Initialize \(\gamma\) using fuzzy membership only
- Phase 1: Refine \(\gamma\) using both membership and typicality
Key differences from PCM:
- Uses product of \(U^{m_f}\) and \(T^{m_p}\) in centroid computation
- Modified \(U\) update that depends on \(T\)
- Two-phase \(\gamma\) computation
Algorithm (Phase 0):
Initialize \(U\) using FCM, \(T = 0\)
Compute \(\gamma\) parameters from fuzzy membership
Update typicality matrix \(T\)
Update membership matrix \(U\)
Update centroids using combined U and T weights
Repeat until convergence
Algorithm (Phase 1):
Recompute \(\gamma\) using both \(U\) and \(T\)
Continue iterations with new gamma
Objective function:
\[J = \sum_i \sum_j u_{ij}^{m_f} \cdot t_{ij}^{m_p} \cdot d_{ij}^2 + \sum_j \gamma_j \sum_i (1 - t_{ij})^{m_p} \cdot u_{ij}^{m_f}\]
- Parameters:
fuzzifier (float, default=2.0) – Fuzziness parameter for \(U\) matrix (\(m_f\), must be > 1.0).
tipifier (float, default=2.0) – Possibilistic parameter for \(T\) matrix (\(m_p\), must be > 1.0).
k (float, default=1.0) – Scaling parameter for \(\gamma\) in phase 1 (must be > 0).
init_method ({'fcm'}, default='fcm') – Method for initializing \(U\) matrix (must use FCM).
n_clusters (int) – Number of clusters (must be >= 2).
max_iter (int) – Maximum number of iterations.
epsilon (float) – Convergence threshold.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Pairwise distance function. Default: squared Euclidean.
patience (int, optional) – Epochs with no improvement before early stopping. Default: None.
restore_best (bool) – If True, restore centroids from the lowest-objective epoch. Default: False.
plot_steps (bool) – Whether to visualize clustering progress. Default: False.
show_confidence (bool) – Whether to show confidence in visualization. Default: True.
show_pca_variance (bool) – Whether to show PCA variance in visualization. Default: True.
save_plot_path (str, optional) – Path to save final plot.
callbacks (list, optional) – List of Callback objects for monitoring/visualization.
- centroids_¶
Final cluster centroids
- Type:
array, shape (n_clusters, n_features)
- U_¶
Final fuzzy membership matrix
- Type:
array, shape (n_samples, n_clusters)
- T_¶
Final typicality matrix
- Type:
array, shape (n_samples, n_clusters)
- gamma_¶
Final scale parameters
- Type:
array, shape (n_clusters,)
- objective_history_¶
Objective value at each iteration
- Type:
array
Examples
>>> import jax.numpy as jnp
>>> from prosemble.models import IPCM
>>> X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8]])
>>> model = IPCM(n_clusters=2, fuzzifier=2.0, tipifier=2.0, k=1.0, random_seed=42)
>>> model.fit(X)
>>> labels = model.predict(X)
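The two-phase \(\gamma\) computation can be illustrated with a weighted mean of squared distances: weights \(u_{ij}^{m_f}\) in phase 0, and weights \(u_{ij}^{m_f} t_{ij}^{m_p}\) in phase 1. This is a sketch under the assumption that \(\gamma_j\) takes the familiar PCM form \(k \sum_i w_{ij} d_{ij}^2 / \sum_i w_{ij}\). The exact formula used by the library is not given in this reference, and all function names and sample values are hypothetical.

```python
import numpy as np

def weighted_scale(d2, weights, k=1.0):
    """gamma_j = k * sum_i w_ij * d_ij^2 / sum_i w_ij (assumed form)."""
    return k * (weights * d2).sum(axis=0) / weights.sum(axis=0)

def ipcm_objective(d2, U, T, gamma, m_f=2.0, m_p=2.0):
    """Evaluate the IPCM objective stated in the class description."""
    w = U ** m_f
    fit_term = (w * T ** m_p * d2).sum()
    penalty = (gamma[None, :] * ((1.0 - T) ** m_p * w)).sum()
    return fit_term + penalty

m_f, m_p = 2.0, 2.0
d2 = np.array([[0.1, 40.0], [0.2, 35.0], [38.0, 0.3], [42.0, 0.1], [90.0, 80.0]])
U = np.array([[0.99, 0.01], [0.99, 0.01], [0.01, 0.99], [0.01, 0.99], [0.47, 0.53]])

gamma0 = weighted_scale(d2, U ** m_f)                        # phase 0: membership only
T = 1.0 / (1.0 + (d2 / gamma0[None, :]) ** (1.0 / (m_p - 1.0)))
gamma1 = weighted_scale(d2, U ** m_f * T ** m_p)             # phase 1: membership and typicality

print(gamma0, gamma1, ipcm_objective(d2, U, T, gamma1))
```

Because typicality decreases with distance, the phase 1 weights discount far-away samples such as the last row, so gamma1 is never larger than gamma0 for the same k.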
- class prosemble.models.IPCM2(n_clusters, fuzzifier=2.0, tipifier=2.0, max_iter=100, epsilon=1e-05, init_method='fcm', random_seed=42, distance_fn=None, patience=None, restore_best=False, plot_steps=False, show_confidence=True, show_pca_variance=True, save_plot_path=None, callbacks=None)[source]¶
Improved Possibilistic C-Means 2 clustering with JAX.
IPCM2 is a variant of IPCM with key differences:
- Uses exponential \(T\) update: \(t_{ij} = \exp(-d_{ij}^2 / \gamma_j)\)
- Centroids use \(U^{m_f} \cdot T\) (\(T\) without power!)
- Modified \(U\) update with exponential distance
- Different objective function
Algorithm (Phase 0):
Initialize \(U\) using FCM, \(T = 0\)
Compute \(\gamma\) parameters from fuzzy membership
Update \(T\) using exponential update
Update \(U\) with modified distance
Update centroids using combined U and T weights
Repeat until convergence
Algorithm (Phase 1):
Recompute \(\gamma\) using both \(U\) and \(T\)
Continue iterations with new gamma
Objective function:
\[J = \sum_i \sum_j u_{ij}^{m_f} \cdot t_{ij} \cdot d_{ij}^2 + \sum_j \gamma_j \sum_i (t_{ij} \log t_{ij} - t_{ij} + 1) \cdot u_{ij}^{m_f}\]
- Parameters:
fuzzifier (float, default=2.0) – Fuzziness parameter for \(U\) matrix (\(m_f\), must be > 1.0).
tipifier (float, default=2.0) – Possibilistic parameter for \(T\) matrix (\(m_p\), must be > 1.0).
init_method ({'fcm'}, default='fcm') – Method for initializing \(U\) matrix.
n_clusters (int) – Number of clusters (must be >= 2).
max_iter (int) – Maximum number of iterations.
epsilon (float) – Convergence threshold.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Pairwise distance function. Default: squared Euclidean.
patience (int, optional) – Epochs with no improvement before early stopping. Default: None.
restore_best (bool) – If True, restore centroids from the lowest-objective epoch. Default: False.
plot_steps (bool) – Whether to visualize clustering progress. Default: False.
show_confidence (bool) – Whether to show confidence in visualization. Default: True.
show_pca_variance (bool) – Whether to show PCA variance in visualization. Default: True.
save_plot_path (str, optional) – Path to save final plot.
callbacks (list, optional) – List of Callback objects for monitoring/visualization.
- centroids_¶
Final cluster centroids
- Type:
array, shape (n_clusters, n_features)
- U_¶
Final fuzzy membership matrix
- Type:
array, shape (n_samples, n_clusters)
- T_¶
Final typicality matrix
- Type:
array, shape (n_samples, n_clusters)
- gamma_¶
Final scale parameters
- Type:
array, shape (n_clusters,)
- objective_history_¶
Objective values at each iteration
- Type:
array
Examples
>>> import jax.numpy as jnp
>>> from prosemble.models import IPCM2
>>> X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8]])
>>> model = IPCM2(n_clusters=2, random_seed=42)
>>> model.fit(X)
>>> labels = model.predict(X)
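The exponential update \(t_{ij} = \exp(-d_{ij}^2 / \gamma_j)\) is the stationary point of the IPCM2 objective with respect to \(t_{ij}\): differentiating gives \(u_{ij}^{m_f}(d_{ij}^2 + \gamma_j \log t_{ij}) = 0\). The NumPy sketch below checks this numerically on hand-picked values. The function names and data are hypothetical and not taken from the library.

```python
import numpy as np

def ipcm2_typicality(d2, gamma):
    """Exponential typicality update: t_ij = exp(-d_ij^2 / gamma_j)."""
    return np.exp(-d2 / gamma[None, :])

def ipcm2_objective(d2, U, T, gamma, m_f=2.0):
    """Evaluate the IPCM2 objective stated in the class description."""
    w = U ** m_f
    fit_term = (w * T * d2).sum()
    penalty = (gamma[None, :] * ((T * np.log(T) - T + 1.0) * w)).sum()
    return fit_term + penalty

d2 = np.array([[0.1, 40.0], [0.2, 35.0], [38.0, 0.3], [42.0, 0.1], [90.0, 80.0]])
U = np.array([[0.99, 0.01], [0.99, 0.01], [0.01, 0.99], [0.01, 0.99], [0.47, 0.53]])
gamma = np.array([5.0, 5.0])   # fixed by hand for this illustration

T = ipcm2_typicality(d2, gamma)
J_opt = ipcm2_objective(d2, U, T, gamma)
J_up = ipcm2_objective(d2, U, 1.1 * T, gamma)
J_down = ipcm2_objective(d2, U, 0.9 * T, gamma)
print(J_opt, J_up, J_down)
```

Scaling \(T\) up or down by 10 percent increases the objective, which is consistent with the closed-form update being the minimizer for fixed \(U\), centroids, and \(\gamma\).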
Kernel Clustering¶
- class prosemble.models.KFCM(n_clusters, fuzzifier=2.0, sigma=1.0, max_iter=100, epsilon=1e-05, init_method='random', random_seed=42, distance_fn=None, patience=None, restore_best=False, plot_steps=False, show_confidence=True, show_pca_variance=True, save_plot_path=None, callbacks=None)[source]¶
Kernel Fuzzy C-Means clustering with JAX.
KFCM uses a Gaussian kernel to map data into a high-dimensional feature space where clustering is performed. This allows handling non-linearly separable data.
Kernel:
\[K(x, y) = \exp\left(-\frac{\|x - y\|^2}{\sigma^2}\right)\]
Kernel distance in feature space:
\[\|\varphi(x) - \varphi(y)\|^2 = 2(1 - K(x, y))\]
Algorithm:
Initialize \(U\) randomly
Update centroids (kernel-weighted)
Update \(U\) using kernel distance
Repeat until convergence
Objective function:
\[J = 2 \sum_i \sum_j u_{ij}^m (1 - K(x_i, v_j))\]
- Parameters:
fuzzifier (float, default=2.0) – Fuzziness parameter (must be > 1.0).
sigma (float, default=1.0) – Kernel bandwidth parameter (must be > 0).
init_method ({'random'}, default='random') – Initialization method.
n_clusters (int) – Number of clusters (must be >= 2).
max_iter (int) – Maximum number of iterations.
epsilon (float) – Convergence threshold.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Pairwise distance function. Default: squared Euclidean.
patience (int, optional) – Epochs with no improvement before early stopping. Default: None.
restore_best (bool) – If True, restore centroids from the lowest-objective epoch. Default: False.
plot_steps (bool) – Whether to visualize clustering progress. Default: False.
show_confidence (bool) – Whether to show confidence in visualization. Default: True.
show_pca_variance (bool) – Whether to show PCA variance in visualization. Default: True.
save_plot_path (str, optional) – Path to save final plot.
callbacks (list, optional) – List of Callback objects for monitoring/visualization.
- centroids_¶
Final cluster centroids
- Type:
array, shape (n_clusters, n_features)
- U_¶
Final fuzzy membership matrix
- Type:
array, shape (n_samples, n_clusters)
- objective_history_¶
Objective values at each iteration
- Type:
array
Examples
>>> import jax.numpy as jnp
>>> from prosemble.models import KFCM
>>> X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8]])
>>> model = KFCM(n_clusters=2, sigma=1.0, random_seed=42)
>>> model.fit(X)
>>> labels = model.predict(X)
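The Gaussian kernel, the induced feature-space distance, and the KFCM objective given above translate directly into NumPy. The sketch below uses hypothetical function names and hand-picked centroids and memberships. It evaluates the formulas only and does not reproduce the library's training loop.

```python
import numpy as np

def gaussian_kernel(X, V, sigma=1.0):
    """K(x, v) = exp(-||x - v||^2 / sigma^2) for all pairs of rows."""
    d2 = ((X[:, None, :] - V[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-d2 / sigma ** 2)

def kernel_distance(X, V, sigma=1.0):
    """||phi(x) - phi(v)||^2 = 2 * (1 - K(x, v))."""
    return 2.0 * (1.0 - gaussian_kernel(X, V, sigma))

def kfcm_objective(X, V, U, m=2.0, sigma=1.0):
    """J = 2 * sum_i sum_j u_ij^m * (1 - K(x_i, v_j))."""
    return 2.0 * (U ** m * (1.0 - gaussian_kernel(X, V, sigma))).sum()

X = np.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8]], dtype=float)
V = np.array([[1.25, 1.9], [6.5, 8.0]])
U = np.array([[0.9, 0.1], [0.9, 0.1], [0.2, 0.8], [0.1, 0.9]])

print(kernel_distance(X, V, sigma=2.0).round(3))
print(kfcm_objective(X, V, U, sigma=2.0))
```

The kernel distance is bounded by 2, so a very distant point contributes no more to the objective than a moderately distant one. The bandwidth sigma controls where that saturation sets in.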
- class prosemble.models.KPCM(n_clusters, fuzzifier=2.0, k=1.0, sigma=1.0, max_iter=100, epsilon=1e-05, init_method='kfcm', random_seed=42, distance_fn=None, patience=None, restore_best=False, plot_steps=False, show_confidence=True, show_pca_variance=True, save_plot_path=None, callbacks=None)[source]¶
Kernel Possibilistic C-Means clustering with JAX.
KPCM extends PCM to kernel space using Gaussian kernel, allowing handling of non-linearly separable data while maintaining possibilistic properties.
Kernel:
\[K(x, y) = \exp\left(-\frac{\|x - y\|^2}{\sigma^2}\right)\]
Kernel distance:
\[d_K(x, v) = 2(1 - K(x, v))\]
Algorithm:
Initialize using KFCM
Compute \(\gamma\) parameters
Update typicality matrix \(T\)
Update centroids (kernel-weighted)
Repeat until convergence
Objective function:
\[J = \sum_i \sum_j t_{ij}^m \cdot d_K(x_i, v_j) + \sum_j \gamma_j \sum_i (1 - t_{ij})^m\]
- Parameters:
fuzzifier (float, default=2.0) – Fuzziness parameter (must be > 1.0).
k (float, default=1.0) – Scaling parameter for \(\gamma\) (must be > 0).
sigma (float, default=1.0) – Kernel bandwidth parameter (must be > 0).
init_method ({'kfcm'}, default='kfcm') – Initialization method.
n_clusters (int) – Number of clusters (must be >= 2).
max_iter (int) – Maximum number of iterations.
epsilon (float) – Convergence threshold.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Pairwise distance function. Default: squared Euclidean.
patience (int, optional) – Epochs with no improvement before early stopping. Default: None.
restore_best (bool) – If True, restore centroids from the lowest-objective epoch. Default: False.
plot_steps (bool) – Whether to visualize clustering progress. Default: False.
show_confidence (bool) – Whether to show confidence in visualization. Default: True.
show_pca_variance (bool) – Whether to show PCA variance in visualization. Default: True.
save_plot_path (str, optional) – Path to save final plot.
callbacks (list, optional) – List of Callback objects for monitoring/visualization.
- centroids_¶
Final cluster centroids
- Type:
array, shape (n_clusters, n_features)
- T_¶
Final typicality matrix
- Type:
array, shape (n_samples, n_clusters)
- gamma_¶
Final scale parameters
- Type:
array, shape (n_clusters,)
- objective_history_¶
Objective values at each iteration
- Type:
array
Examples
>>> import jax.numpy as jnp
>>> from prosemble.models import KPCM
>>> X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8]])
>>> model = KPCM(n_clusters=2, sigma=1.0, random_seed=42)
>>> model.fit(X)
>>> labels = model.predict(X)
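Minimizing the KPCM objective above with respect to \(t_{ij}\) yields \(t_{ij} = 1 / (1 + (d_K(x_i, v_j) / \gamma_j)^{1/(m-1)})\), the PCM update with the kernel distance in place of the squared Euclidean one. The following NumPy sketch assumes that form. The function names, centroids, and \(\gamma\) values are hypothetical.

```python
import numpy as np

def kernel_distance(X, V, sigma=1.0):
    """d_K(x, v) = 2 * (1 - exp(-||x - v||^2 / sigma^2))."""
    d2 = ((X[:, None, :] - V[None, :, :]) ** 2).sum(axis=-1)
    return 2.0 * (1.0 - np.exp(-d2 / sigma ** 2))

def kpcm_typicality(dk, gamma, m=2.0):
    """Stationary point of the objective with respect to t_ij."""
    return 1.0 / (1.0 + (dk / gamma[None, :]) ** (1.0 / (m - 1.0)))

def kpcm_objective(dk, T, gamma, m=2.0):
    """J = sum t^m * d_K + sum_j gamma_j * sum_i (1 - t)^m."""
    return (T ** m * dk).sum() + (gamma[None, :] * (1.0 - T) ** m).sum()

X = np.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8]], dtype=float)
V = np.array([[1.25, 1.9], [6.5, 8.0]])
gamma = np.array([0.5, 0.5])   # fixed by hand for this illustration

dk = kernel_distance(X, V, sigma=2.0)
T = kpcm_typicality(dk, gamma)
J_opt = kpcm_objective(dk, T, gamma)
J_up = kpcm_objective(dk, np.clip(T + 0.05, 0.0, 1.0), gamma)
J_down = kpcm_objective(dk, np.clip(T - 0.05, 0.0, 1.0), gamma)
print(T.round(3), J_opt, J_up, J_down)
```

Shifting the typicalities in either direction does not lower the objective, as expected for the minimizer of a function that is convex in each \(t_{ij}\).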
- class prosemble.models.KFPCM(n_clusters, fuzzifier=2.0, eta=2.0, sigma=1.0, max_iter=100, epsilon=1e-05, init_method='kfcm', random_seed=42, distance_fn=None, patience=None, restore_best=False, plot_steps=False, show_confidence=True, show_pca_variance=True, save_plot_path=None, callbacks=None)[source]¶
Kernel Fuzzy Possibilistic C-Means with JAX.
KFPCM maintains two matrices (\(U\) and \(T\)) in kernel space. \(U\) has row-sum-to-1 constraint (standard FCM), while \(T\) has column-sum-to-1 constraint per the original Pal, Pal & Bezdek (1997) FPCM formulation.
- Parameters:
fuzzifier (float, default=2.0) – Fuzziness parameter for \(U\) matrix (must be > 1.0).
eta (float, default=2.0) – Fuzziness parameter for \(T\) matrix (must be > 1.0).
sigma (float, default=1.0) – Kernel bandwidth parameter (must be > 0).
init_method ({'kfcm', 'random'}, default='kfcm') – Initialization method.
n_clusters (int) – Number of clusters (must be >= 2).
max_iter (int) – Maximum number of iterations.
epsilon (float) – Convergence threshold.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Pairwise distance function. Default: squared Euclidean.
patience (int, optional) – Epochs with no improvement before early stopping. Default: None.
restore_best (bool) – If True, restore centroids from the lowest-objective epoch. Default: False.
plot_steps (bool) – Whether to visualize clustering progress. Default: False.
show_confidence (bool) – Whether to show confidence in visualization. Default: True.
show_pca_variance (bool) – Whether to show PCA variance in visualization. Default: True.
save_plot_path (str, optional) – Path to save final plot.
callbacks (list, optional) – List of Callback objects for monitoring/visualization.
- class prosemble.models.KPFCM(n_clusters, fuzzifier=2.0, eta=2.0, a=1.0, b=1.0, k=1.0, sigma=1.0, max_iter=100, epsilon=1e-05, init_method='kfcm', random_seed=42, distance_fn=None, patience=None, restore_best=False, plot_steps=False, show_confidence=True, show_pca_variance=True, save_plot_path=None, callbacks=None)[source]¶
Kernel Possibilistic Fuzzy C-Means with JAX.
KPFCM combines fuzzy membership (\(U\)) and typicality (\(T\)) in kernel space with weights \(a\) and \(b\).
- Parameters:
fuzzifier (float, default=2.0) – Fuzzification parameter for membership (must be > 1.0).
eta (float, default=2.0) – Fuzzification parameter for typicality (must be > 1.0).
a (float, default=1.0) – Weight for fuzzy membership term (must be > 0).
b (float, default=1.0) – Weight for typicality term (must be > 0).
k (float, default=1.0) – Scaling parameter for \(\gamma\) (must be > 0).
sigma (float, default=1.0) – Kernel bandwidth parameter (must be > 0).
init_method ({'kfcm'}, default='kfcm') – Initialization method.
n_clusters (int) – Number of clusters (must be >= 2).
max_iter (int) – Maximum number of iterations.
epsilon (float) – Convergence threshold.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Pairwise distance function. Default: squared Euclidean.
patience (int, optional) – Epochs with no improvement before early stopping. Default: None.
restore_best (bool) – If True, restore centroids from the lowest-objective epoch. Default: False.
plot_steps (bool) – Whether to visualize clustering progress. Default: False.
show_confidence (bool) – Whether to show confidence in visualization. Default: True.
show_pca_variance (bool) – Whether to show PCA variance in visualization. Default: True.
save_plot_path (str, optional) – Path to save final plot.
callbacks (list, optional) – List of Callback objects for monitoring/visualization.
- class prosemble.models.KAFCM(n_clusters, fuzzifier=2.0, a=1.0, b=1.0, k=1.0, sigma=1.0, max_iter=100, epsilon=1e-05, init_method='kfcm', random_seed=42, distance_fn=None, patience=None, restore_best=False, plot_steps=False, show_confidence=True, show_pca_variance=True, save_plot_path=None, callbacks=None)[source]¶
Kernel Adaptive Fuzzy C-Means clustering with JAX.
KAFCM extends AFCM to kernel space, combining fuzzy and possibilistic approaches with kernel-based distance measures.
Kernel distance: \(d_K(x, v) = 2(1 - K(x, v))\)
Algorithm:
1. Initialize \(U\) using KFCM
2. Compute \(\gamma\) parameters using kernel distance
3. Update \(T\) using exponential kernel update
4. Update \(U\) using standard KFCM rule
5. Update centroids (kernel-weighted with combined weights)
6. Repeat until convergence
- Parameters:
fuzzifier (float, default=2.0) – Fuzziness parameter (must be > 1.0).
a (float, default=1.0) – Weight for fuzzy membership (must be > 0).
b (float, default=1.0) – Weight for typicality (must be > 0).
k (float, default=1.0) – Scaling parameter for \(\gamma\) (must be > 0).
sigma (float, default=1.0) – Kernel bandwidth parameter (must be > 0).
init_method ({'kfcm'}, default='kfcm') – Initialization method.
n_clusters (int) – Number of clusters (must be >= 2).
max_iter (int) – Maximum number of iterations.
epsilon (float) – Convergence threshold.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Pairwise distance function. Default: squared Euclidean.
patience (int, optional) – Epochs with no improvement before early stopping. Default: None.
restore_best (bool) – If True, restore centroids from the lowest-objective epoch. Default: False.
plot_steps (bool) – Whether to visualize clustering progress. Default: False.
show_confidence (bool) – Whether to show confidence in visualization. Default: True.
show_pca_variance (bool) – Whether to show PCA variance in visualization. Default: True.
save_plot_path (str, optional) – Path to save final plot.
callbacks (list, optional) – List of Callback objects for monitoring/visualization.
Examples
>>> import jax.numpy as jnp
>>> from prosemble.models import KAFCM
>>> X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8]])
>>> model = KAFCM(n_clusters=2, sigma=1.0, random_seed=42)
>>> model.fit(X)
>>> labels = model.predict(X)
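The kernel distance \(d_K(x, v) = 2(1 - K(x, v))\) used by KAFCM can be illustrated in a few lines. This is a minimal sketch that assumes a Gaussian (RBF) kernel of the form \(\exp(-\|x - v\|^2 / (2\sigma^2))\); the reference does not state which kernel form the library uses, and the helper names gaussian_kernel and kernel_distance are hypothetical.

```python
import math

def gaussian_kernel(x, v, sigma=1.0):
    """Gaussian (RBF) kernel. Assumed form, not confirmed by the reference."""
    sq_dist = sum((xi - vi) ** 2 for xi, vi in zip(x, v))
    return math.exp(-sq_dist / (2.0 * sigma ** 2))

def kernel_distance(x, v, sigma=1.0):
    """Kernel-induced distance d_K(x, v) = 2 * (1 - K(x, v))."""
    return 2.0 * (1.0 - gaussian_kernel(x, v, sigma))

same = kernel_distance([1.0, 2.0], [1.0, 2.0])   # identical points
far = kernel_distance([1.0, 2.0], [8.0, 8.0])    # distant points
```

Under this assumption the distance is 0 for identical points and saturates at 2 for points that are far apart relative to sigma, which is why the bandwidth has to be chosen with the data scale in mind.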
- class prosemble.models.KIPCM(n_clusters, fuzzifier=2.0, tipifier=2.0, k=1.0, sigma=1.0, max_iter=100, epsilon=1e-05, init_method='kfcm', random_seed=42, distance_fn=None, patience=None, restore_best=False, plot_steps=False, show_confidence=True, show_pca_variance=True, save_plot_path=None, callbacks=None)[source]¶
Kernel Improved Possibilistic C-Means with JAX.
KIPCM uses a two-phase approach in kernel space with product-based centroids.
- Parameters:
fuzzifier (float, default=2.0) – Fuzziness parameter for \(U\) matrix (must be > 1.0).
tipifier (float, default=2.0) – Possibilistic parameter for \(T\) matrix (must be > 1.0).
k (float, default=1.0) – Scaling parameter for \(\gamma\) in phase 1 (must be > 0).
sigma (float, default=1.0) – Kernel bandwidth parameter (must be > 0).
init_method ({'kfcm'}, default='kfcm') – Initialization method.
n_clusters (int) – Number of clusters (must be >= 2).
max_iter (int) – Maximum number of iterations.
epsilon (float) – Convergence threshold.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Pairwise distance function. Default: squared Euclidean.
patience (int, optional) – Epochs with no improvement before early stopping. Default: None.
restore_best (bool) – If True, restore centroids from the lowest-objective epoch. Default: False.
plot_steps (bool) – Whether to visualize clustering progress. Default: False.
show_confidence (bool) – Whether to show confidence in visualization. Default: True.
show_pca_variance (bool) – Whether to show PCA variance in visualization. Default: True.
save_plot_path (str, optional) – Path to save final plot.
callbacks (list, optional) – List of Callback objects for monitoring/visualization.
- class prosemble.models.KIPCM2(n_clusters, fuzzifier=2.0, tipifier=2.0, sigma=1.0, max_iter=100, epsilon=1e-05, init_method='kfcm', random_seed=42, distance_fn=None, patience=None, restore_best=False, plot_steps=False, show_confidence=True, show_pca_variance=True, save_plot_path=None, callbacks=None)[source]¶
Kernel Improved Possibilistic C-Means 2 with JAX.
KIPCM2 uses an exponential \(T\) update and a modified objective in kernel space.
- Parameters:
fuzzifier (float, default=2.0) – Fuzziness parameter for \(U\) matrix (must be > 1.0).
tipifier (float, default=2.0) – Possibilistic parameter for \(T\) matrix (must be > 1.0).
sigma (float, default=1.0) – Kernel bandwidth parameter (must be > 0).
init_method ({'kfcm'}, default='kfcm') – Initialization method.
n_clusters (int) – Number of clusters (must be >= 2).
max_iter (int) – Maximum number of iterations.
epsilon (float) – Convergence threshold.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Pairwise distance function. Default: squared Euclidean.
patience (int, optional) – Epochs with no improvement before early stopping. Default: None.
restore_best (bool) – If True, restore centroids from the lowest-objective epoch. Default: False.
plot_steps (bool) – Whether to visualize clustering progress. Default: False.
show_confidence (bool) – Whether to show confidence in visualization. Default: True.
show_pca_variance (bool) – Whether to show PCA variance in visualization. Default: True.
save_plot_path (str, optional) – Path to save final plot.
callbacks (list, optional) – List of Callback objects for monitoring/visualization.
Topology-Preserving Models¶
- class prosemble.models.NeuralGas(n_prototypes, lr_init=0.5, lr_final=0.01, lambda_init=None, lambda_final=0.01, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, callbacks=None, use_scan=True, patience=None, restore_best=False)[source]¶
Neural Gas.
Updates all prototypes according to their distance rank:
\[h(\text{rank}, \lambda) = \exp(-\text{rank} / \lambda)\]
\[w_k \leftarrow w_k + \varepsilon \cdot h(\text{rank}_k, \lambda) \cdot (x - w_k)\]
Both \(\varepsilon\) and \(\lambda\) decay exponentially during training.
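The rank-based update can be sketched with NumPy as follows. This is an illustration of the rule for a single sample, not the library's implementation: the function names are hypothetical, squared Euclidean distance is assumed for ranking, and the geometric interpolation in exponential_decay is one common way to realize an exponential schedule between an initial and a final value.

```python
import numpy as np

def neural_gas_step(prototypes, x, lr, lam):
    """Apply one rank-based Neural Gas update for a single sample x."""
    dists = np.sum((prototypes - x) ** 2, axis=1)   # squared Euclidean (assumed)
    ranks = np.argsort(np.argsort(dists))           # rank 0 = closest prototype
    h = np.exp(-ranks / lam)                        # h(rank, lambda)
    return prototypes + lr * h[:, None] * (x - prototypes)

def exponential_decay(start, end, step, n_steps):
    """Interpolate geometrically from start (step 0) to end (step n_steps)."""
    return start * (end / start) ** (step / n_steps)

W = np.array([[0.0, 0.0], [1.0, 1.0], [4.0, 4.0]])
x = np.array([0.5, 0.0])
W_new = neural_gas_step(W, x, lr=0.5, lam=1.0)
```

The closest prototype moves by the full fraction lr of its offset to x, while every further rank is damped by another factor of exp(-1 / lam).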
- Parameters:
lr_init (float) – Initial learning rate.
lr_final (float) – Final learning rate.
lambda_init (float, optional) – Initial neighborhood range. Default: n_prototypes / 2.
lambda_final (float) – Final neighborhood range.
n_prototypes (int) – Number of prototypes/nodes.
max_iter (int) – Maximum training iterations.
lr (float) – Initial learning rate.
epsilon (float) – Convergence threshold.
random_seed (int) – Random seed.
distance_fn (callable, optional) – Distance function.
callbacks (list, optional) – Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping.
patience (int, optional) – Epochs with no improvement before early stopping. Default: None.
restore_best (bool) – If True, restore parameters from the lowest-loss epoch. Default: False.
- predict(X)¶
Assign each sample to closest prototype (BMU).
- transform(X)¶
Return distance matrix to all prototypes.
- class prosemble.models.GrowingNeuralGas(max_nodes=100, lr_winner=0.1, lr_neighbor=0.01, max_age=50, insert_interval=100, error_decay=0.995, n_prototypes=2, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, callbacks=None, use_scan=True, patience=None, restore_best=False)[source]¶
Growing Neural Gas.
Starts with 2 nodes and grows by inserting nodes near the highest-error units. Connections between nodes have ages; old connections are removed.
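The edge-aging rule can be sketched as below. The sketch follows the classic Growing Neural Gas formulation (edges at the winner age by one, the winner/runner-up edge is refreshed, edges older than max_age are dropped); whether this class applies exactly these steps is an assumption, and the dictionary representation and the name age_edges are hypothetical.

```python
def age_edges(edges, winner, runner_up, max_age):
    """Return updated edge ages after one sample, following classic GNG.

    `edges` maps a sorted node pair (i, j) to its age.
    """
    updated = {}
    for (i, j), age in edges.items():
        # Edges touching the winner grow one step older.
        if winner in (i, j):
            age += 1
        updated[(i, j)] = age
    # The winner/runner-up edge is created or refreshed with age 0.
    updated[tuple(sorted((winner, runner_up)))] = 0
    # Edges older than max_age are removed.
    return {pair: age for pair, age in updated.items() if age <= max_age}

edges = {(0, 1): 49, (0, 2): 50, (1, 2): 3}
edges = age_edges(edges, winner=0, runner_up=1, max_age=50)
```

In this example the edge (0, 2) exceeds max_age and disappears, the edge (0, 1) is reset because it connects the winner and the runner-up, and the edge (1, 2) is untouched.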
- Parameters:
max_nodes (int) – Maximum number of nodes.
lr_winner (float) – Learning rate for the winning node.
lr_neighbor (float) – Learning rate for neighbors of the winner.
max_age (int) – Maximum age before an edge is removed.
insert_interval (int) – Insert a new node every this many steps.
error_decay (float) – Error decay factor applied to all nodes.
n_prototypes (int) – Number of prototypes/nodes.
max_iter (int) – Maximum training iterations.
lr (float) – Initial learning rate.
epsilon (float) – Convergence threshold.
random_seed (int) – Random seed.
distance_fn (callable, optional) – Distance function.
callbacks (list, optional) – Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping.
patience (int, optional) – Epochs with no improvement before early stopping. Default: None.
restore_best (bool) – If True, restore parameters from the lowest-loss epoch. Default: False.
- predict(X)¶
Assign each sample to closest prototype (BMU).
- class prosemble.models.KohonenSOM(grid_height=10, grid_width=10, sigma_init=None, sigma_final=0.5, lr_init=0.5, lr_final=0.01, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, callbacks=None, use_scan=True, patience=None, restore_best=False)[source]¶
Standard Kohonen Self-Organizing Map.
Uses squared Euclidean distance for BMU selection, Gaussian neighborhood function, exponential decay for sigma and learning rate, and batch updates.
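A minimal sketch of the Gaussian neighborhood and the exponential decay described above, written with NumPy. The exact functional forms are assumptions: the neighborhood is taken as \(\exp(-d^2 / (2\sigma^2))\) with \(d\) measured between grid positions, and the decay interpolates geometrically between the initial and final values. All function names are hypothetical.

```python
import numpy as np

def grid_coordinates(grid_height, grid_width):
    """Return an array of shape (grid_height * grid_width, 2) with (row, col) per node."""
    rows, cols = np.meshgrid(np.arange(grid_height), np.arange(grid_width), indexing="ij")
    return np.stack([rows.ravel(), cols.ravel()], axis=1).astype(float)

def gaussian_neighborhood(coords, bmu_index, sigma):
    """Gaussian weight of every node relative to the BMU, measured on the grid."""
    sq_grid_dist = np.sum((coords - coords[bmu_index]) ** 2, axis=1)
    return np.exp(-sq_grid_dist / (2.0 * sigma ** 2))

def exponential_decay(start, end, step, n_steps):
    """Interpolate geometrically from start (step 0) to end (step n_steps)."""
    return start * (end / start) ** (step / n_steps)

coords = grid_coordinates(3, 3)
# sigma_init = max(grid_height, grid_width) / 2 = 1.5, as in the documented default.
sigma = exponential_decay(1.5, 0.5, step=0, n_steps=100)
h = gaussian_neighborhood(coords, bmu_index=4, sigma=sigma)
```

Here node 4 is the center of the 3×3 grid, so it receives weight 1, its direct neighbors receive a smaller weight, and the corners the smallest.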
- Parameters:
grid_height (int) – Height of the 2D grid.
grid_width (int) – Width of the 2D grid.
sigma_init (float, optional) – Initial neighborhood radius. Default: max(grid_height, grid_width) / 2.
sigma_final (float) – Final neighborhood radius.
lr_init (float) – Initial learning rate.
lr_final (float) – Final learning rate.
max_iter (int) – Maximum training iterations.
lr (float) – Initial learning rate.
epsilon (float) – Convergence threshold.
random_seed (int) – Random seed.
distance_fn (callable, optional) – Distance function.
callbacks (list, optional) – Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping.
patience (int, optional) – Epochs with no improvement before early stopping. Default: None.
restore_best (bool) – If True, restore parameters from the lowest-loss epoch. Default: False.
- predict(X)¶
Assign each sample to closest prototype (BMU).
- class prosemble.models.HeskesSOM(grid_height=10, grid_width=10, sigma_init=None, sigma_final=0.5, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, callbacks=None, use_scan=True, patience=None, restore_best=False)[source]¶
Heskes Self-Organizing Map.
Uses a modified BMU definition that considers the neighborhood structure, and a pure batch update (weighted average of data). Guarantees monotonic decrease of the Heskes energy function.
Differences from KohonenSOM:
BMU is chosen to minimize neighborhood-weighted distance sum, not raw distance to closest prototype.
Prototypes are updated via weighted average (no learning rate).
Energy is guaranteed to decrease monotonically.
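The differences listed above can be sketched with NumPy. The code is an illustration under stated assumptions, not the library's implementation: neighborhood is a precomputed (n_nodes, n_nodes) matrix of grid neighborhood weights, squared Euclidean distance is used, and the names heskes_bmu and batch_update are hypothetical.

```python
import numpy as np

def heskes_bmu(x, prototypes, neighborhood):
    """Pick the node c minimizing sum_j neighborhood[c, j] * ||x - w_j||^2."""
    sq_dists = np.sum((prototypes - x) ** 2, axis=1)   # shape (n_nodes,)
    return int(np.argmin(neighborhood @ sq_dists))

def batch_update(X, prototypes, neighborhood):
    """Set each prototype to the neighborhood-weighted average of the data."""
    bmus = np.array([heskes_bmu(x, prototypes, neighborhood) for x in X])
    weights = neighborhood[bmus]                       # shape (n_samples, n_nodes)
    return (weights.T @ X) / weights.sum(axis=0)[:, None]

# A 1D chain of three nodes with a Gaussian neighborhood (sigma = 1.0).
idx = np.arange(3)
neighborhood = np.exp(-(idx[:, None] - idx[None, :]) ** 2 / (2.0 * 1.0 ** 2))
prototypes = np.array([[0.0], [1.0], [2.0]])
X = np.array([[0.1], [0.9], [2.2], [1.8]])

bmu = heskes_bmu(X[0], prototypes, neighborhood)
new_prototypes = batch_update(X, prototypes, neighborhood)
```

With an identity neighborhood matrix the BMU rule reduces to the nearest prototype and the batch update reduces to a k-means style mean, which is a convenient sanity check.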
- Parameters:
grid_height (int) – Height of the 2D grid.
grid_width (int) – Width of the 2D grid.
sigma_init (float, optional) – Initial neighborhood radius. Default: max(grid_height, grid_width) / 2.
sigma_final (float) – Final neighborhood radius.
max_iter (int) – Maximum training iterations.
lr (float) – Initial learning rate.
epsilon (float) – Convergence threshold.
random_seed (int) – Random seed.
distance_fn (callable, optional) – Distance function.
callbacks (list, optional) – Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping.
patience (int, optional) – Epochs with no improvement before early stopping. Default: None.
restore_best (bool) – If True, restore parameters from the lowest-loss epoch. Default: False.
- predict(X)¶
Assign each sample to closest prototype (BMU).
Riemannian Models¶
- class prosemble.models.RiemannianSRNG(manifold, beta=10.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, tau=0.95, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=False, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
Supervised Riemannian Neural Gas.
Combines three key ideas:
GLVQ loss: \((d^+ - d^-) / (d^+ + d^-)\) for margin-based classification
Neural Gas cooperation: all same-class prototypes participate in the loss, weighted by rank via \(\exp(-\text{rank} / \gamma)\)
Geodesic distance: \(d(x, w)\) computed via the manifold’s intrinsic metric (matrix logarithm + Frobenius norm)
Prototypes live on the manifold and are updated via projected gradient descent: optax computes Euclidean gradients, then manifold.project() maps prototypes back to the manifold after each step.
The neighborhood range \(\gamma\) decays during training from \(\gamma_{\text{init}}\) to \(\gamma_{\text{final}}\). When \(\gamma \to 0\), RiemannianSRNG recovers a Riemannian GLVQ.
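The decay of \(\gamma\) and its effect on the rank weights \(\exp(-\text{rank} / \gamma)\) can be sketched in plain Python. The sketch assumes a constant per-step multiplicative factor chosen so that \(\gamma\) reaches gamma_final after max_iter steps, and it normalizes the weights to sum to 1; both choices, as well as the function names, are assumptions for illustration.

```python
import math

def gamma_schedule(gamma_init, gamma_final, max_iter):
    """Yield gamma for steps 0..max_iter using a constant multiplicative decay."""
    decay = (gamma_final / gamma_init) ** (1.0 / max_iter)
    gamma = gamma_init
    for _ in range(max_iter + 1):
        yield gamma
        gamma *= decay

def rank_weights(distances, gamma):
    """Normalized weights exp(-rank / gamma) for same-class prototype distances."""
    order = sorted(range(len(distances)), key=lambda k: distances[k])
    raw = [0.0] * len(distances)
    for rank, k in enumerate(order):
        raw[k] = math.exp(-rank / gamma)
    total = sum(raw)
    return [w / total for w in raw]   # normalization is an assumption

gammas = list(gamma_schedule(gamma_init=1.0, gamma_final=0.01, max_iter=100))
wide = rank_weights([0.4, 0.1, 0.9], gamma=gammas[0])     # early training
narrow = rank_weights([0.4, 0.1, 0.9], gamma=gammas[-1])  # late training
```

Early in training all same-class prototypes contribute noticeably, while at the final \(\gamma\) almost all weight sits on the closest prototype, which matches the GLVQ limit described above.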
- Parameters:
manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance defining the geometry.
beta (float) – Transfer function steepness parameter for sigmoid shaping.
gamma_init (float, optional) – Initial neighborhood range for NG cooperation. Default: max prototypes per class / 2.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma. Default: computed from max_iter so gamma reaches gamma_final.
tau (float) – Injectivity radius safety factor for manifold projection. Default: 0.95.
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True, use jax.lax.scan for training (faster, JIT-compiled). If False (default), use a Python for-loop with true early stopping.
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments for the learning rate scheduler. Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Default: ‘stratified_random’.
patience (int, optional) – Number of consecutive epochs with no improvement before stopping. Default: None.
restore_best (bool) – If True, restore parameters that achieved the lowest loss. Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps. Default: None.
ema_decay (float, optional) – Exponential moving average decay for parameters. Default: None.
freeze_params (list of str, optional) – List of parameter group names to freeze. Default: None.
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Default: None.
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. Default: None.
- predict(X)[source]¶
Predict class labels using geodesic distance.
- Parameters:
X (array-like of shape (n_samples, n_features_flat))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when initial_prototypes have a different number than what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- class prosemble.models.RiemannianSMNG(manifold, latent_dim=None, beta=10.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, tau=0.95, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=False, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
Supervised Riemannian Matrix Neural Gas.
Extends RiemannianSRNG with a global metric adaptation matrix \(\Omega\) applied in the tangent space. The distance is:
\[d(x, w_k) = \|\Omega \cdot \text{Log}_{w_k}(x)_{\text{flat}}\|^2\]
where \(\text{Log}_{w_k}(x)\) is the logarithmic map at prototype \(w_k\), flattened to a vector.
The learned relevance matrix \(\Lambda = \Omega^T \Omega\) captures feature correlations in the tangent space.
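The relationship between the distance and the relevance matrix can be checked numerically. In this sketch the matrix omega and the vector v are random stand-ins (v plays the role of the flattened \(\text{Log}_{w_k}(x)\)); no manifold code is involved and none of the variable names come from the library.

```python
import numpy as np

rng = np.random.default_rng(0)
d_flat, latent_dim = 4, 2
omega = rng.normal(size=(latent_dim, d_flat))   # hypothetical learned Omega
v = rng.normal(size=d_flat)                     # stand-in for Log_{w_k}(x), flattened

dist = float(np.sum((omega @ v) ** 2))          # ||Omega v||^2
relevance = omega.T @ omega                     # Lambda = Omega^T Omega
dist_via_lambda = float(v @ relevance @ v)      # v^T Lambda v
```

Both expressions give the same value, and \(\Lambda\) is symmetric positive semidefinite by construction, so the distance is never negative even when latent_dim is smaller than the flattened dimension.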
- Parameters:
manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance.
latent_dim (int, optional) – Projection dimensionality for omega. Default: n_features (square).
beta (float) – Transfer function steepness.
gamma_init (float, optional) – Initial neighborhood range. Default: max prototypes per class / 2.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step decay factor for gamma.
tau (float) – Injectivity radius safety factor. Default: 0.95.
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
optimizer (str or optax optimizer, optional) – Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping.
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True, use jax.lax.scan. Default: False.
batch_size (int, optional) – Mini-batch size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule.
lr_scheduler_kwargs (dict, optional) – Keyword arguments for the learning rate scheduler.
prototypes_initializer (str or callable, optional) – How to initialize prototypes.
patience (int, optional) – Epochs with no improvement before stopping.
restore_best (bool) – Restore best parameters. Default: False.
class_weight (dict or 'balanced', optional) – Class weights.
gradient_accumulation_steps (int, optional) – Gradient accumulation steps.
ema_decay (float, optional) – EMA decay for parameters.
freeze_params (list of str, optional) – Parameter groups to freeze.
lookahead (dict, optional) – Lookahead optimizer config.
mixed_precision (str or None, optional) – Mixed precision dtype.
- predict(X)[source]¶
Predict using tangent-space omega metric.
- Parameters:
X (array-like of shape (n_samples, n_features_flat))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- relevance_matrix()[source]¶
Return learned relevance matrix Lambda = Omega^T Omega.
- Return type:
array of shape (d_flat, d_flat)
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when initial_prototypes have a different number than what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- class prosemble.models.RiemannianSLNG(manifold, latent_dim=None, beta=10.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, tau=0.95, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=False, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
Supervised Riemannian Localized Matrix Neural Gas.
Extends RiemannianSRNG with per-prototype metric adaptation. Each prototype \(w_k\) has its own matrix \(\Omega_k\) applied in the tangent space:
\[d(x, w_k) = \|\Omega_k \cdot \text{Log}_{w_k}(x)_{\text{flat}}\|^2\]
Since each \(\Omega_k\) operates on tangent vectors at \(w_k\) (all in the same tangent space \(T_{w_k}M\)), this is geometrically well-defined.
- Parameters:
manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance.
latent_dim (int, optional) – Projection dimensionality for each omega. Default: n_features.
beta (float) – Transfer function steepness.
gamma_init (float, optional) – Initial neighborhood range. Default: max prototypes per class / 2.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step decay factor for gamma.
tau (float) – Injectivity radius safety factor. Default: 0.95.
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
optimizer (str or optax optimizer, optional) – Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping.
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True, use jax.lax.scan. Default: False.
batch_size (int, optional) – Mini-batch size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule.
lr_scheduler_kwargs (dict, optional) – Keyword arguments for the learning rate scheduler.
prototypes_initializer (str or callable, optional) – How to initialize prototypes.
patience (int, optional) – Epochs with no improvement before stopping.
restore_best (bool) – Restore best parameters. Default: False.
class_weight (dict or 'balanced', optional) – Class weights.
gradient_accumulation_steps (int, optional) – Gradient accumulation steps.
ema_decay (float, optional) – EMA decay for parameters.
freeze_params (list of str, optional) – Parameter groups to freeze.
lookahead (dict, optional) – Lookahead optimizer config.
mixed_precision (str or None, optional) – Mixed precision dtype.
- predict(X)[source]¶
Predict using per-prototype tangent-space omega metric.
- Parameters:
X (array-like of shape (n_samples, n_features_flat))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when initial_prototypes have a different number than what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- class prosemble.models.RiemannianSTNG(manifold, subspace_dim=None, beta=10.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, tau=0.95, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=False, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]¶
Supervised Riemannian Tangent Neural Gas.
Extends RiemannianSRNG with per-prototype tangent subspace projection. Each prototype \(w_k\) has an orthonormal basis \(\Omega_k\) defining an invariant subspace; the distance measures how far the tangent vector lies outside this subspace:
\[d(x, w_k) = \|(I - \Omega_k \Omega_k^T) \cdot \text{Log}_{w_k}(x)_{\text{flat}}\|^2\]
The subspace bases are re-orthogonalized after each gradient step.
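The tangent subspace distance can be sketched with NumPy. The vectors below are random stand-ins for flattened tangent vectors, and the use of a reduced QR decomposition for re-orthogonalization is an assumption; the reference only states that the bases are re-orthogonalized, not how. Function names are hypothetical.

```python
import numpy as np

def orthonormalize(basis):
    """Re-orthogonalize the columns of `basis` with a reduced QR decomposition."""
    q, _ = np.linalg.qr(basis)
    return q

def tangent_subspace_distance(v, omega):
    """Squared norm of the part of v outside span(omega): ||(I - Omega Omega^T) v||^2."""
    residual = v - omega @ (omega.T @ v)
    return float(residual @ residual)

rng = np.random.default_rng(0)
omega = orthonormalize(rng.normal(size=(4, 2)))   # d_flat = 4, subspace_dim = 2
inside = omega @ np.array([0.3, -1.2])            # lies in the subspace
outside = rng.normal(size=4)                      # generic tangent vector

d_inside = tangent_subspace_distance(inside, omega)
d_outside = tangent_subspace_distance(outside, omega)
```

A vector inside the subspace has distance close to zero, so variation along the learned directions is ignored, while only the component orthogonal to the subspace contributes to the distance.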
- Parameters:
manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance.
subspace_dim (int, optional) – Dimensionality of each prototype’s tangent subspace. Default: n_features // 2.
beta (float) – Transfer function steepness.
gamma_init (float, optional) – Initial neighborhood range. Default: max prototypes per class / 2.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step decay factor for gamma.
tau (float) – Injectivity radius safety factor. Default: 0.95.
n_prototypes_per_class (int) – Number of prototypes per class.
max_iter (int) – Maximum training iterations.
lr (float) – Learning rate.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
optimizer (str or optax optimizer, optional) – Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping.
margin (float) – Margin for loss computation.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True, use jax.lax.scan. Default: False.
batch_size (int, optional) – Mini-batch size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule.
lr_scheduler_kwargs (dict, optional) – Keyword arguments for the learning rate scheduler.
prototypes_initializer (str or callable, optional) – How to initialize prototypes.
patience (int, optional) – Epochs with no improvement before stopping.
restore_best (bool) – Restore best parameters. Default: False.
class_weight (dict or 'balanced', optional) – Class weights.
gradient_accumulation_steps (int, optional) – Gradient accumulation steps.
ema_decay (float, optional) – EMA decay for parameters.
freeze_params (list of str, optional) – Parameter groups to freeze.
lookahead (dict, optional) – Lookahead optimizer config.
mixed_precision (str or None, optional) – Mixed precision dtype.
- predict(X)[source]¶
Predict using tangent subspace distance.
- Parameters:
X (array-like of shape (n_samples, n_features_flat))
- Returns:
labels
- Return type:
array of shape (n_samples,)
- fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when initial_prototypes have a different number than what n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.
- Return type:
self
- class prosemble.models.RiemannianNeuralGas(manifold, n_prototypes, lr_init=0.3, lr_final=0.01, lambda_init=None, lambda_final=0.01, tau=0.9, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, callbacks=None, patience=None, restore_best=False)[source]¶
Riemannian Neural Gas.
Learns prototypes on a Riemannian manifold using geodesic distances and exponential/logarithmic maps for updates.
- Parameters:
manifold (Manifold) – Manifold instance (SO, SPD, or Grassmannian from prosemble.core.manifolds). Must satisfy the Manifold protocol.
n_prototypes (int) – Number of prototypes/nodes.
lr_init (float) – Initial learning rate. Default: 0.3.
lr_final (float) – Final learning rate. Default: 0.01.
lambda_init (float, optional) – Initial neighborhood range. Default: n_prototypes / 2.
lambda_final (float) – Final neighborhood range. Default: 0.01.
tau (float) – Safety factor for injectivity radius bound. Default: 0.9.
max_iter (int) – Maximum training iterations.
lr (float) – Initial learning rate.
epsilon (float) – Convergence threshold.
random_seed (int) – Random seed.
distance_fn (callable, optional) – Distance function.
callbacks (list, optional) – Callback objects.
patience (int, optional) – Epochs with no improvement before early stopping. Default: None.
restore_best (bool) – If True, restore parameters from the lowest-loss epoch. Default: False.
- fit(X)[source]¶
Fit Riemannian Neural Gas.
- Parameters:
X (array of shape (n_samples, *point_shape)) – Data points on the manifold.
- Return type:
self
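The lr_init/lr_final and lambda_init/lambda_final pairs suggest an annealing schedule. The sketch below uses the exponential interpolation and rank-based neighborhood weights of classical Neural Gas. Whether RiemannianNeuralGas uses exactly this schedule is an assumption, and the helper names anneal and rank_weights are hypothetical. On a manifold, each weight would scale a step taken through the logarithmic and exponential maps rather than a straight-line update.

```python
import math


def anneal(start, end, t, t_max):
    """Exponential interpolation from `start` (t=0) to `end` (t=t_max)."""
    return start * (end / start) ** (t / t_max)


def rank_weights(distances, lam):
    """Neighborhood weights exp(-rank / lam); rank 0 is the closest prototype."""
    order = sorted(range(len(distances)), key=lambda i: distances[i])
    ranks = [0] * len(distances)
    for rank, idx in enumerate(order):
        ranks[idx] = rank
    return [math.exp(-r / lam) for r in ranks]


n_prototypes, max_iter = 4, 100
lambda_init = n_prototypes / 2  # documented default for lambda_init

lr_start = anneal(0.3, 0.01, 0, max_iter)        # learning rate at the first step
lr_end = anneal(0.3, 0.01, max_iter, max_iter)   # learning rate at the last step

# Geodesic distances from one sample to each prototype (made-up values).
weights = rank_weights([0.9, 0.1, 0.5, 2.0], anneal(lambda_init, 0.01, 0, max_iter))
```

The closest prototype always receives weight 1. As lambda shrinks toward lambda_final, the remaining weights fall toward zero and the update approaches a winner-takes-all step.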
Utility Models¶
- class prosemble.models.KMeansPlusPlus(n_clusters=3, max_iter=100, epsilon=1e-05, random_seed=None, plot_steps=False)[source]¶
K-means++ clustering with JAX (Hard C-Means with smart initialization)
K-means++ is an algorithm for choosing initial cluster centers with better convergence properties than random initialization.
Algorithm:
Choose first center uniformly at random from data points
For each data point x, compute D(x), the distance to nearest center
Choose next center with probability proportional to D(x)²
Repeat until k centers are chosen
Run standard k-means (HCM) with these initial centers
- Parameters:
n_clusters (int, default=3) – Number of clusters
max_iter (int, default=100) – Maximum number of iterations
epsilon (float, default=1e-5) – Convergence tolerance
random_seed (int, optional) – Random seed for reproducibility
plot_steps (bool, default=False) – Whether to enable visualization (requires LiveVisualizer)
- centroids_¶
Cluster centers
- Type:
array of shape (n_clusters, n_features)
- labels_¶
Labels of each point
- Type:
array of shape (n_samples,)
See also
HCM – Hard C-Means model used internally after initialization.
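The seeding steps listed under Algorithm can be sketched in plain Python. This is a minimal illustration that assumes squared Euclidean distance. The function kmeanspp_init is hypothetical and is not the library's JAX implementation.

```python
import random


def squared_distance(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kmeanspp_init(points, k, seed=0):
    """Pick k initial centers from `points` using D(x)^2 weighting."""
    rng = random.Random(seed)
    centers = [rng.choice(points)]  # first center: uniform at random
    while len(centers) < k:
        # D(x)^2: squared distance from each point to its nearest chosen center
        d2 = [min(squared_distance(p, c) for c in centers) for p in points]
        # next center: sampled with probability proportional to D(x)^2
        centers.append(rng.choices(points, weights=d2, k=1)[0])
    return centers


points = [(0.0, 0.0), (0.1, 0.0), (5.0, 5.0), (5.1, 5.0), (10.0, 0.0), (10.0, 0.1)]
centers = kmeanspp_init(points, k=3, seed=0)
```

Points already chosen have D(x)² = 0 and so cannot be drawn again. Points far from every existing center are the most likely next picks, which is what spreads the initial centers out before standard k-means takes over.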
- class prosemble.models.KNN(n_neighbors=5)[source]¶
K-Nearest Neighbors (KNN) with JAX
KNN is a simple, non-parametric classifier that predicts based on the k nearest training samples.
Algorithm:
For each test sample, compute distances to all training samples, find k nearest neighbors, predict label as mode (most common) of k neighbors’ labels, and compute probability as frequency of predicted label.
- Parameters:
n_neighbors (int, default=5) – Number of neighbors to use
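The prediction rule described above can be written out as a short sketch. It assumes squared Euclidean distance and breaks ties by training order. The function knn_predict is a hypothetical stand-in, not the class's actual method.

```python
from collections import Counter


def knn_predict(X_train, y_train, x, n_neighbors=5):
    """Return (label, probability) for a single query point x."""
    # Squared Euclidean distance from x to every training sample
    dists = [sum((a - b) ** 2 for a, b in zip(row, x)) for row in X_train]
    # Indices of the k nearest training samples
    nearest = sorted(range(len(X_train)), key=lambda i: dists[i])[:n_neighbors]
    # Mode of the neighbors' labels, and its frequency among the neighbors
    votes = Counter(y_train[i] for i in nearest)
    label, count = votes.most_common(1)[0]
    return label, count / len(nearest)


X_train = [(0, 0), (0, 1), (1, 0), (5, 5), (5, 6), (6, 5)]
y_train = [0, 0, 0, 1, 1, 1]
label, prob = knn_predict(X_train, y_train, (0.5, 0.5), n_neighbors=5)
```

Here three of the five nearest neighbors carry label 0, so the prediction is class 0 with probability 3/5.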
- class prosemble.models.NPC(n_classes=3, max_iter=10, tol=0.8, random_state=None)[source]¶
Noise Possibilistic C-Means (NPC) with JAX
NPC is a supervised prototype-based classifier that iteratively optimizes its prototypes against an accuracy metric. It uses softmin for probability estimation.
Algorithm:
Initialize prototypes (one per class)
Predict labels using nearest prototype
Compute accuracy
If accuracy >= threshold or max_iter reached, stop
Otherwise, recompute prototypes and repeat
Softmin function: \(\text{softmin}(x_i) = \exp(-x_i) / \sum_j \exp(-x_j)\)
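As a worked instance of the softmin formula, the sketch below treats each x_i as the distance from a sample to one class prototype. That interpretation is an assumption. Subtracting the minimum before exponentiating leaves the result unchanged and only guards against underflow.

```python
import math


def softmin(distances):
    """softmin(x_i) = exp(-x_i) / sum_j exp(-x_j), computed stably."""
    m = min(distances)
    exps = [math.exp(-(d - m)) for d in distances]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmin([0.5, 2.0, 4.0])
```

The smallest distance receives the largest probability, and the probabilities sum to one.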
- Parameters:
- prosemble.models.Kmeans¶
alias of KMeansPlusPlus
- class prosemble.models.SOM(grid_size=None, max_iter=None, learning_rate=0.5, sigma=1.0, random_state=None)[source]¶
Self-Organizing Maps (SOM) with JAX
SOM is an unsupervised learning algorithm that creates a low-dimensional (typically 2D) representation of high-dimensional data while preserving topological relationships.
Algorithm:
Initialize grid of neurons with random weights
For each iteration, select random sample from data, find Best Matching Unit (BMU), update BMU and its neighbors towards the sample, and decay learning rate and neighborhood range.
- Parameters:
grid_size (int, optional) – Size of the SOM grid (grid_size x grid_size). If None, computed as int(sqrt(5 * sqrt(n_samples)))
max_iter (int, optional) – Number of training iterations. If None, set to 500 * grid_size^2
learning_rate (float, default=0.5) – Initial learning rate
sigma (float, default=1.0) – Initial neighborhood radius
random_state (int, optional) – Random seed for reproducibility
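The documented defaults for grid_size and max_iter can be checked by hand. The helper default_som_size below is hypothetical and simply evaluates the two formulas given in the parameter list.

```python
import math


def default_som_size(n_samples):
    """Evaluate the documented defaults for grid_size and max_iter."""
    grid_size = int(math.sqrt(5 * math.sqrt(n_samples)))
    max_iter = 500 * grid_size ** 2
    return grid_size, max_iter


grid_size, max_iter = default_som_size(150)
```

For 150 samples, sqrt(150) ≈ 12.25, so 5 · 12.25 ≈ 61.2 and its square root is about 7.8, which truncates to a 7 x 7 grid. That grid gives 500 · 49 = 24500 training iterations.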
- class prosemble.models.BGPC(n_clusters=3, max_iter=100, tol=0.0001, alpha_init=1.0, beta_init=0.1, beta_final=10.0, init='fcm', random_state=None)[source]¶
Bayesian Graded Possibilistic C-Means (BGPC) with JAX
BGPC uses exponential weighting with time-decaying alpha and beta parameters.
Algorithm:
Compute membership weights using exponential distance
Normalize memberships using partition function Z
Update centroids as weighted mean of data
Update beta and alpha with decay schedules
Repeat until convergence
- Parameters:
n_clusters (int) – Number of clusters
max_iter (int, default=100) – Maximum number of iterations
tol (float, default=1e-4) – Convergence tolerance
alpha_init (float, default=1.0) – Initial alpha parameter
beta_init (float, default=0.1) – Initial beta parameter (starting value for decay)
beta_final (float, default=10.0) – Final beta parameter (ending value for decay)
init (str, default='fcm') – Initialization method: ‘random’, ‘fcm’, or ‘kmeans++’
random_state (int, optional) – Random seed for reproducibility
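The membership and centroid steps can be sketched as follows. The exact formulas are not given here, so this sketch assumes unnormalized weights exp(-beta · d²) and a partition function Z = (sum of weights)^alpha, in the style of graded possibilistic clustering. The library's own definitions may differ, and both function names are hypothetical.

```python
import math


def bgpc_memberships(point, centroids, alpha, beta):
    """Memberships of one point under the assumed exponential weighting."""
    weights = [
        math.exp(-beta * sum((p - c) ** 2 for p, c in zip(point, centroid)))
        for centroid in centroids
    ]
    z = sum(weights) ** alpha  # assumed partition function Z
    return [w / z for w in weights]


def update_centroid(points, memberships):
    """Centroid as the membership-weighted mean of the data."""
    total = sum(memberships)
    dim = len(points[0])
    return tuple(
        sum(u * p[d] for u, p in zip(memberships, points)) / total
        for d in range(dim)
    )


centroids = [(0.0, 1.0), (3.0, 0.0)]
soft = bgpc_memberships((0.0, 0.0), centroids, alpha=1.0, beta=0.1)
sharp = bgpc_memberships((0.0, 0.0), centroids, alpha=1.0, beta=10.0)
center = update_centroid([(0.0, 0.0), (2.0, 0.0)], [1.0, 1.0])
```

Under these assumptions, alpha = 1 makes the memberships of a point sum to one, and moving beta from beta_init toward beta_final sharpens the assignment to the nearest centroid.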