callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), training stops after the first non-improving step (an epsilon-based
convergence check). True early stopping requires use_scan=False.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (their gradients are zeroed).
For example, [‘backbone’] freezes the backbone so that only the prototypes are trained.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’ (int, default 6): synchronize the slow weights every k steps
- ‘slow_step_size’ (float, default 0.5): interpolation factor for the slow-weight update
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed-precision training: ‘float16’ or ‘bfloat16’.
Master weights stay in float32, while the forward and backward passes run in the lower
precision; on GPUs with native support this can give up to roughly a 2x speedup and roughly
halve activation memory. With ‘float16’, static loss scaling is applied to prevent
gradient underflow. Default: None (disabled).
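Because jax.lax.scan compiles a fixed number of iterations, the use_scan=True path cannot exit early; the Python loop behind use_scan=False can. The following is a minimal sketch of how patience and restore_best might interact in such a loop. It assumes that "improvement" means the loss drops by more than a tolerance epsilon and that patience=None behaves like patience=1; fit_loop, update, loss and epsilon are hypothetical names, not part of the library.

```python
from typing import Any, Callable, List, Optional, Tuple


def fit_loop(
    update: Callable[[Any], Any],
    loss: Callable[[Any], float],
    params: Any,
    max_iter: int,
    patience: Optional[int] = None,
    epsilon: float = 1e-6,
    restore_best: bool = False,
) -> Tuple[Any, List[float]]:
    """Python-loop training with true early stopping (hypothetical helper)."""
    # Assumption: patience=None behaves like patience=1.
    limit = 1 if patience is None else patience
    best_params, best_loss = params, loss(params)
    history: List[float] = []
    bad_epochs = 0
    for _ in range(max_iter):
        params = update(params)
        current = loss(params)
        history.append(current)
        if current < best_loss - epsilon:
            # Improved by more than epsilon: remember it and reset the counter.
            best_params, best_loss = params, current
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs >= limit:
                break  # leave the loop: no compute is spent on later epochs
    return (best_params if restore_best else params), history


# Toy run: "params" is just an epoch index into a fixed table of losses.
losses = [5.0, 3.0, 2.0, 2.5, 2.4, 2.6, 1.0]
final, history = fit_loop(lambda p: p + 1, lambda p: losses[p], 0,
                          max_iter=6, patience=2, restore_best=True)
print(final, history)  # 2 [3.0, 2.0, 2.5, 2.4]
```

With restore_best=False the same run would return the parameters from the last executed epoch (index 4) instead of the best one (index 2).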
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters that achieved the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point,
with a freshly initialized optimizer. Cannot be combined with initial_prototypes.
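How sample_weight and class_weight are combined is not specified here. A common convention (scikit-learn estimators, for instance, multiply sample_weight by the class weight) is sketched below, together with a normalized weighted mean of per-sample losses. Both the multiplication and the normalization are assumptions, and effective_weights and weighted_mean are hypothetical helpers.

```python
from typing import Dict, Hashable, List, Optional, Sequence


def effective_weights(
    y: Sequence[Hashable],
    sample_weight: Optional[Sequence[float]] = None,
    class_weight: Optional[Dict[Hashable, float]] = None,
) -> List[float]:
    """Per-sample weights, assuming sample and class weights multiply."""
    weights = [1.0] * len(y) if sample_weight is None else [float(w) for w in sample_weight]
    if len(weights) != len(y):
        raise ValueError("sample_weight must have shape (n_samples,)")
    if class_weight is not None:
        # Classes missing from the dict keep weight 1.0 (an assumption).
        weights = [w * class_weight.get(label, 1.0) for w, label in zip(weights, y)]
    return weights


def weighted_mean(per_sample_loss: Sequence[float], weights: Sequence[float]) -> float:
    """Normalized weighted average of per-sample losses."""
    total = sum(loss_i * w_i for loss_i, w_i in zip(per_sample_loss, weights))
    return total / sum(weights)


w = effective_weights([0, 1, 1], sample_weight=[1.0, 2.0, 1.0], class_weight={1: 3.0})
print(w)                                # [1.0, 6.0, 3.0]
print(weighted_mean([0.5, 1.0, 2.0], w))  # 1.25
```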
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), training stops after the first non-improving step (an epsilon-based
convergence check). True early stopping requires use_scan=False.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (their gradients are zeroed).
For example, [‘backbone’] freezes the backbone so that only the prototypes are trained.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’ (int, default 6): synchronize the slow weights every k steps
- ‘slow_step_size’ (float, default 0.5): interpolation factor for the slow-weight update
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed-precision training: ‘float16’ or ‘bfloat16’.
Master weights stay in float32, while the forward and backward passes run in the lower
precision; on GPUs with native support this can give up to roughly a 2x speedup and roughly
halve activation memory. With ‘float16’, static loss scaling is applied to prevent
gradient underflow. Default: None (disabled).
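To illustrate one of the supported lr_scheduler strings, here is a plain-Python sketch of a warmup-then-cosine schedule. It mirrors the shape of optax.warmup_cosine_decay_schedule, where decay_steps counts the warmup steps too; whether ‘warmup_cosine_decay’ maps onto that optax function, and which lr_scheduler_kwargs it expects, are assumptions.

```python
import math


def warmup_cosine_decay(step: int, init_value: float, peak_value: float,
                        warmup_steps: int, decay_steps: int,
                        end_value: float = 0.0) -> float:
    """Linear warmup to peak_value, then cosine decay to end_value.

    decay_steps includes the warmup steps, as in optax.warmup_cosine_decay_schedule.
    """
    if warmup_steps > 0 and step < warmup_steps:
        # Linear ramp from init_value to peak_value.
        return init_value + (peak_value - init_value) * step / warmup_steps
    cosine_steps = max(decay_steps - warmup_steps, 1)
    # Fraction of the cosine phase completed, clamped to [0, 1].
    progress = min(max(step - warmup_steps, 0), cosine_steps) / cosine_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return end_value + (peak_value - end_value) * cosine


for step in (0, 5, 10, 60, 110, 200):
    print(step, warmup_cosine_decay(step, 0.0, 0.1, warmup_steps=10,
                                    decay_steps=110, end_value=0.001))
```

After decay_steps the rate stays at end_value, so a schedule that is too short for max_iter simply trains at the floor rate for the remaining steps.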
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters that achieved the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point,
with a freshly initialized optimizer. Cannot be combined with initial_prototypes.
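The prototypes_initializer parameter accepts a callable with the signature (X, y, n_per_class, key) -> (protos, labels). Below is a hypothetical class-mean initializer written against that signature. It uses plain lists to stay dependency-free; the library presumably expects NumPy or JAX arrays and a jax.random key, which is an assumption here, and key is unused because the placement is deterministic.

```python
from typing import Any, Hashable, List, Sequence, Tuple


def class_mean_init(X: Sequence[Sequence[float]], y: Sequence[Hashable],
                    n_per_class: int, key: Any = None
                    ) -> Tuple[List[List[float]], List[Hashable]]:
    """Place n_per_class copies of each class mean; `key` is unused here."""
    protos: List[List[float]] = []
    labels: List[Hashable] = []
    for c in sorted(set(y)):
        rows = [x for x, label in zip(X, y) if label == c]
        # Column-wise mean of the samples belonging to class c.
        mean = [sum(col) / len(rows) for col in zip(*rows)]
        for _ in range(n_per_class):
            protos.append(list(mean))
            labels.append(c)
    return protos, labels


X = [[0.0, 0.0], [2.0, 2.0], [10.0, 10.0]]
print(class_mean_init(X, [0, 0, 1], n_per_class=2))
```

With n_per_class > 1, identical copies may never separate if ties always go to the same prototype, so a real initializer would typically use key to add a small random perturbation.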
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), training stops after the first non-improving step (an epsilon-based
convergence check). True early stopping requires use_scan=False.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (their gradients are zeroed).
For example, [‘backbone’] freezes the backbone so that only the prototypes are trained.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’ (int, default 6): synchronize the slow weights every k steps
- ‘slow_step_size’ (float, default 0.5): interpolation factor for the slow-weight update
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed-precision training: ‘float16’ or ‘bfloat16’.
Master weights stay in float32, while the forward and backward passes run in the lower
precision; on GPUs with native support this can give up to roughly a 2x speedup and roughly
halve activation memory. With ‘float16’, static loss scaling is applied to prevent
gradient underflow. Default: None (disabled).
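The description of class_weight=‘balanced’ only says the weights are inversely proportional to class frequencies. The sketch below assumes scikit-learn's ‘balanced’ rule, weight[c] = n_samples / (n_classes * count[c]), which has the convenient property that the weighted sample count equals n_samples; the library's exact formula may differ.

```python
from collections import Counter
from typing import Dict, Hashable, Sequence


def balanced_class_weights(y: Sequence[Hashable]) -> Dict[Hashable, float]:
    """weight[c] = n_samples / (n_classes * count[c]) (scikit-learn's 'balanced' rule)."""
    counts = Counter(y)
    n_samples, n_classes = len(y), len(counts)
    return {label: n_samples / (n_classes * count) for label, count in counts.items()}


weights = balanced_class_weights([0, 0, 0, 1])
print(weights)  # the rare class 1 receives the larger weight
```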
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters that achieved the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point,
with a freshly initialized optimizer. Cannot be combined with initial_prototypes.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), training stops after the first non-improving step (an epsilon-based
convergence check). True early stopping requires use_scan=False.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (their gradients are zeroed).
For example, [‘backbone’] freezes the backbone so that only the prototypes are trained.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’ (int, default 6): synchronize the slow weights every k steps
- ‘slow_step_size’ (float, default 0.5): interpolation factor for the slow-weight update
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed-precision training: ‘float16’ or ‘bfloat16’.
Master weights stay in float32, while the forward and backward passes run in the lower
precision; on GPUs with native support this can give up to roughly a 2x speedup and roughly
halve activation memory. With ‘float16’, static loss scaling is applied to prevent
gradient underflow. Default: None (disabled).
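The interplay of batch_size and gradient_accumulation_steps can be sketched as follows: each epoch reshuffles the sample indices and slices them into mini-batches, and every gradient_accumulation_steps micro-batch gradients are averaged into one update, so the effective batch size is batch_size * gradient_accumulation_steps (4 * 2 = 8 below). Gradients are scalars for brevity; keeping the last short batch and dropping an incomplete trailing accumulation group are assumptions, and both helper names are hypothetical.

```python
import random
from typing import Iterator, List


def shuffled_minibatches(n_samples: int, batch_size: int,
                         rng: random.Random) -> Iterator[List[int]]:
    """Yield index lists covering one epoch; the last batch may be smaller."""
    indices = list(range(n_samples))
    rng.shuffle(indices)  # reshuffle once per epoch
    for start in range(0, n_samples, batch_size):
        yield indices[start:start + batch_size]


def accumulate(micro_grads: List[float], accumulation_steps: int) -> List[float]:
    """Average every `accumulation_steps` micro-batch gradients into one update.

    A trailing incomplete group produces no update (an assumption).
    """
    updates: List[float] = []
    buffer: List[float] = []
    for g in micro_grads:
        buffer.append(g)
        if len(buffer) == accumulation_steps:
            updates.append(sum(buffer) / accumulation_steps)
            buffer = []
    return updates


batches = list(shuffled_minibatches(10, batch_size=4, rng=random.Random(0)))
print([len(b) for b in batches])                                 # [4, 4, 2]
print(accumulate([1.0, 2.0, 3.0, 4.0, 5.0], accumulation_steps=2))  # [1.5, 3.5]
```

When the micro-batches have equal size, averaging their gradients gives the same result as the gradient of the mean loss over the combined batch, which is why accumulation can stand in for a larger batch_size under memory limits.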
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters that achieved the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point,
with a freshly initialized optimizer. Cannot be combined with initial_prototypes.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), training stops after the first non-improving step (an epsilon-based
convergence check). True early stopping requires use_scan=False.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (their gradients are zeroed).
For example, [‘backbone’] freezes the backbone so that only the prototypes are trained.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’ (int, default 6): synchronize the slow weights every k steps
- ‘slow_step_size’ (float, default 0.5): interpolation factor for the slow-weight update
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed-precision training: ‘float16’ or ‘bfloat16’.
Master weights stay in float32, while the forward and backward passes run in the lower
precision; on GPUs with native support this can give up to roughly a 2x speedup and roughly
halve activation memory. With ‘float16’, static loss scaling is applied to prevent
gradient underflow. Default: None (disabled).
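The ema_decay option keeps an exponential moving average of the parameters, updated as ema = decay * ema + (1 - decay) * params after each step; it roughly averages over the last 1 / (1 - decay) steps, i.e. about 1000 steps for 0.999. The sketch assumes the average starts from the initial parameters and applies no bias correction, which may differ from the library; ema_update is a hypothetical helper.

```python
from typing import Dict, List


def ema_update(ema: Dict[str, float], params: Dict[str, float],
               decay: float) -> Dict[str, float]:
    """One EMA step: ema <- decay * ema + (1 - decay) * params."""
    return {name: decay * ema[name] + (1.0 - decay) * value
            for name, value in params.items()}


# Hypothetical parameter trajectory; the EMA starts from the initial params.
trajectory: List[Dict[str, float]] = [{"w": 0.0}, {"w": 1.0}, {"w": 1.0}]
ema = dict(trajectory[0])
for params in trajectory[1:]:
    ema = ema_update(ema, params, decay=0.9)
print(ema)  # after training, these smoothed values would replace the raw parameters
```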
Predict class labels via winner-takes-all competition.
Parameters:
X (array-like of shape (n_samples, n_features)) – Input samples.
Returns:
labels
Return type:
array of shape (n_samples,)
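In winner-takes-all prediction, each sample receives the label of its closest prototype. A minimal sketch, assuming squared Euclidean distance (the models accept a configurable distance_fn, so the actual metric may differ); wta_predict is a hypothetical name.

```python
from typing import Hashable, List, Sequence


def wta_predict(X: Sequence[Sequence[float]],
                prototypes: Sequence[Sequence[float]],
                prototype_labels: Sequence[Hashable]) -> List[Hashable]:
    """Label each sample with the label of its nearest prototype."""
    predictions = []
    for x in X:
        # Squared Euclidean distance to every prototype (assumed distance_fn).
        distances = [sum((a - b) ** 2 for a, b in zip(x, p)) for p in prototypes]
        # Ties go to the prototype with the lowest index.
        winner = min(range(len(prototypes)), key=distances.__getitem__)
        predictions.append(prototype_labels[winner])
    return predictions


prototypes = [[0.0, 0.0], [5.0, 5.0]]
print(wta_predict([[1.0, 1.0], [4.0, 4.0], [0.0, -1.0]], prototypes, [0, 1]))  # [0, 1, 0]
```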
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters that achieved the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point,
with a freshly initialized optimizer. Cannot be combined with initial_prototypes.
Full list of base parameters (optimizer, distance_fn, lr_scheduler, callbacks, patience, etc.).
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters that achieved the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point,
with a freshly initialized optimizer. Cannot be combined with initial_prototypes.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), training stops after the first non-improving step (an epsilon-based
convergence check). True early stopping requires use_scan=False.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (their gradients are zeroed).
For example, [‘backbone’] freezes the backbone so that only the prototypes are trained.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’ (int, default 6): synchronize the slow weights every k steps
- ‘slow_step_size’ (float, default 0.5): interpolation factor for the slow-weight update
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed-precision training: ‘float16’ or ‘bfloat16’.
Master weights stay in float32, while the forward and backward passes run in the lower
precision; on GPUs with native support this can give up to roughly a 2x speedup and roughly
halve activation memory. With ‘float16’, static loss scaling is applied to prevent
gradient underflow. Default: None (disabled).
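The lookahead keys match the argument names of optax.lookahead, which the wrapper presumably builds on (an assumption). The idea: an inner "fast" optimizer takes sync_period steps, then the "slow" weights move slow_step_size of the way toward the fast weights and the fast weights restart from there. The sketch wraps plain SGD on a scalar parameter; lookahead_sgd is a hypothetical name.

```python
from typing import Callable, Tuple


def lookahead_sgd(grad: Callable[[float], float], w0: float, lr: float,
                  steps: int, sync_period: int = 6,
                  slow_step_size: float = 0.5) -> Tuple[float, float]:
    """Lookahead around plain SGD on a scalar parameter."""
    slow = fast = w0
    for t in range(1, steps + 1):
        fast -= lr * grad(fast)  # inner ("fast") optimizer step
        if t % sync_period == 0:
            # Move the slow weights part of the way toward the fast weights,
            # then restart the fast weights from the new slow weights.
            slow += slow_step_size * (fast - slow)
            fast = slow
    return slow, fast


# Minimize f(w) = w**2, whose gradient is 2 * w.
print(lookahead_sgd(lambda w: 2.0 * w, w0=1.0, lr=0.1, steps=6))
```

With slow_step_size=1.0 the slow weights simply copy the fast ones, so the wrapper reduces to the inner optimizer.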
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), training stops after the first non-improving step (an epsilon-based
convergence check). True early stopping requires use_scan=False.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (their gradients are zeroed).
For example, [‘backbone’] freezes the backbone so that only the prototypes are trained.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’ (int, default 6): synchronize the slow weights every k steps
- ‘slow_step_size’ (float, default 0.5): interpolation factor for the slow-weight update
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed-precision training: ‘float16’ or ‘bfloat16’.
Master weights stay in float32, while the forward and backward passes run in the lower
precision; on GPUs with native support this can give up to roughly a 2x speedup and roughly
halve activation memory. With ‘float16’, static loss scaling is applied to prevent
gradient underflow. Default: None (disabled).
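Why float16 needs loss scaling can be seen with the standard library alone: struct's "e" format rounds a value to IEEE half precision. A gradient of 1e-8 is below float16's smallest subnormal (about 6e-8) and rounds to zero, but multiplying by a static scale first keeps it representable, and dividing by the scale in full precision recovers it. In real training the loss, not the gradient, is scaled and the chain rule carries the factor through; the scale value and helper name below are hypothetical. bfloat16 shares float32's exponent range, which is why it generally does not need loss scaling.

```python
import struct


def to_float16(x: float) -> float:
    """Round a Python float to IEEE 754 half precision and back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


LOSS_SCALE = 2.0 ** 10  # hypothetical static scale factor

true_grad = 1e-8                              # too small for float16
naive = to_float16(true_grad)                 # rounds to 0.0: the update is lost
scaled = to_float16(true_grad * LOSS_SCALE)   # representable after scaling
recovered = scaled / LOSS_SCALE               # unscale in full precision
print(naive, recovered)
```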
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), training stops after the first non-improving step (an epsilon-based
convergence check). True early stopping requires use_scan=False.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (their gradients are zeroed).
For example, [‘backbone’] freezes the backbone so that only the prototypes are trained.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’ (int, default 6): synchronize the slow weights every k steps
- ‘slow_step_size’ (float, default 0.5): interpolation factor for the slow-weight update
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed-precision training: ‘float16’ or ‘bfloat16’.
Master weights stay in float32, while the forward and backward passes run in the lower
precision; on GPUs with native support this can give up to roughly a 2x speedup and roughly
halve activation memory. With ‘float16’, static loss scaling is applied to prevent
gradient underflow. Default: None (disabled).
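Freezing by zeroing gradients can be sketched as a mask over named parameter groups. Rejecting unknown group names is an assumption about the library's behavior, and mask_frozen is a hypothetical helper. Note that with decoupled weight decay (as in AdamW) a zeroed gradient does not by itself stop a parameter from changing; in optax, combining optax.multi_transform with optax.set_to_zero is the usual way to skip updates entirely.

```python
from typing import Dict, Iterable, List, Optional


def mask_frozen(grads: Dict[str, List[float]],
                freeze_params: Optional[Iterable[str]] = None) -> Dict[str, List[float]]:
    """Zero the gradients of every frozen parameter group."""
    frozen = set(freeze_params or ())
    unknown = frozen.difference(grads)  # names that match no parameter group
    if unknown:
        raise ValueError(f"unknown parameter groups: {sorted(unknown)}")
    return {name: [0.0] * len(g) if name in frozen else list(g)
            for name, g in grads.items()}


grads = {"backbone": [0.3, -1.2], "prototypes": [0.5]}
print(mask_frozen(grads, ["backbone"]))  # only the prototypes keep their gradient
```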
rejection_confidence (float, optional) – Minimum class probability for a confident prediction (0 to 1).
Samples below this threshold are rejected (labeled -1) when using
predict_with_rejection(). Default: None (no rejection).
n_prototypes_per_class (int) – Number of prototypes per class.
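The rejection rule for rejection_confidence can be sketched as an argmax over class probabilities that falls back to -1 when the top probability is strictly below the threshold. How the model computes those probabilities is not covered here; the probabilities below are made up, and reject_low_confidence is a hypothetical stand-in, not the library's predict_with_rejection(). The -1 sentinel is ambiguous if -1 is itself a class label.

```python
from typing import Hashable, List, Optional, Sequence


def reject_low_confidence(probas: Sequence[Sequence[float]],
                          classes: Sequence[Hashable],
                          rejection_confidence: Optional[float] = None) -> List[Hashable]:
    """Argmax prediction, replaced by -1 when the top probability is below the threshold."""
    predictions: List[Hashable] = []
    for row in probas:
        best = max(range(len(row)), key=row.__getitem__)
        if rejection_confidence is not None and row[best] < rejection_confidence:
            predictions.append(-1)  # rejected: not confident enough
        else:
            predictions.append(classes[best])
    return predictions


probas = [[0.9, 0.1], [0.55, 0.45], [0.2, 0.8]]
print(reject_low_confidence(probas, [0, 1], rejection_confidence=0.6))  # [0, -1, 1]
```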
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – synchronize the slow weights every sync_period steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor for moving the slow weights toward the fast weights
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward/backward pass runs in the lower
precision, which can give up to ~2x speedup and roughly half the activation memory on GPUs with fast low-precision arithmetic. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
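Lookahead wraps an inner ("fast") optimizer. The fast weights take sync_period ordinary steps. The slow weights then move a fraction slow_step_size of the way toward them, and the fast weights restart from the updated slow weights. Here is a NumPy sketch of that scheme. It assumes plain SGD as the inner optimizer, whereas the library wraps whichever optimizer is configured. lookahead_sgd is a hypothetical helper:

```python
import numpy as np

def lookahead_sgd(grad_fn, w0, lr=0.1, sync_period=6, slow_step_size=0.5, n_steps=300):
    slow = np.asarray(w0, dtype=float).copy()
    fast = slow.copy()
    for step in range(1, n_steps + 1):
        fast = fast - lr * grad_fn(fast)                     # inner (fast) optimizer: plain SGD
        if step % sync_period == 0:
            slow = slow + slow_step_size * (fast - slow)     # interpolate slow toward fast
            fast = slow.copy()                               # fast weights restart from slow
    return slow

# Quadratic bowl f(w) = 0.5 * ||w - 3||^2, whose gradient is w - 3.
w = lookahead_sgd(lambda w: w - 3.0, np.zeros(2))
print(w)  # close to [3. 3.]
```

Smaller slow_step_size values make the slow weights more conservative. The fast weights still explore freely between synchronizations.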
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted params) as the starting point,
with a fresh optimizer. Cannot be combined with initial_prototypes.
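With use_scan=False, patience and restore_best behave like classic early stopping on the monitored loss. When validation_data is passed to fit, that loss is the validation loss. The following framework-free sketch shows such a loop. step_fn, val_loss_fn and the tol threshold are hypothetical names, and the library's exact improvement test may differ. patience=None is treated as stopping after a single non-improving step, as documented above:

```python
import copy

def train_with_early_stopping(step_fn, val_loss_fn, params, max_iter=100,
                              patience=3, tol=1e-4, restore_best=True):
    limit = 1 if patience is None else patience   # None: stop after one non-improving step
    best_loss = float("inf")
    best_params = copy.deepcopy(params)
    stale_epochs, n_epochs = 0, 0
    for _ in range(max_iter):
        n_epochs += 1
        params = step_fn(params)
        loss = val_loss_fn(params)
        if loss < best_loss - tol:                 # improvement beyond the tolerance
            best_loss, best_params = loss, copy.deepcopy(params)
            stale_epochs = 0
        else:
            stale_epochs += 1
            if stale_epochs >= limit:
                break
    return (best_params if restore_best else params), best_loss, n_epochs

# Toy run: "params" is an epoch counter and the validation curve is a lookup table.
curve = {1: 5.0, 2: 4.0, 3: 3.0, 4: 3.5, 5: 3.6, 6: 3.7, 7: 1.0}
print(train_with_early_stopping(lambda p: p + 1, curve.__getitem__, 0, max_iter=7))
# (3, 3.0, 6): stops after three stale epochs and restores the epoch-3 params
```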
rejection_confidence (float, optional) – Minimum class probability for a confident prediction (0 to 1).
Samples below this threshold are rejected (labeled -1) when using
predict_with_rejection(). Default is None (no rejection).
n_prototypes_per_class (int) – Number of prototypes per class.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – synchronize the slow weights every sync_period steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor for moving the slow weights toward the fast weights
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward/backward pass runs in the lower
precision, which can give up to ~2x speedup and roughly half the activation memory on GPUs with fast low-precision arithmetic. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
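The class_weight='balanced' option assigns weights inversely proportional to class frequencies. The sketch below assumes the scikit-learn convention w_c = n_samples / (n_classes * n_c). The library may normalize differently. Both helpers are hypothetical:

```python
import numpy as np

def balanced_class_weights(y):
    # w_c = n_samples / (n_classes * count_c), the scikit-learn "balanced" convention.
    classes, counts = np.unique(y, return_counts=True)
    weights = len(y) / (len(classes) * counts)
    return dict(zip(classes.tolist(), weights.tolist()))

def per_sample_weights(y, class_weight):
    # Expand a {label: weight} dict into one weight per training sample.
    return np.array([class_weight[label] for label in np.asarray(y).tolist()])

y = np.array([0, 0, 0, 1])
cw = balanced_class_weights(y)        # class 0 -> 2/3, class 1 -> 2.0
sw = per_sample_weights(y, cw)
```

Under this convention, the per-sample weights sum to n_samples. The overall loss scale therefore stays comparable to unweighted training.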
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted params) as the starting point,
with a fresh optimizer. Cannot be combined with initial_prototypes.
latent_dim (int, optional) – Dimensionality of the latent space. If None, uses input dim.
rejection_confidence (float, optional) – Minimum class probability for a confident prediction (0 to 1).
Samples below this threshold are rejected (labeled -1) when using
predict_with_rejection(). Default is None (no rejection).
n_prototypes_per_class (int) – Number of prototypes per class.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – synchronize the slow weights every sync_period steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor for moving the slow weights toward the fast weights
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward/backward pass runs in the lower
precision, which can give up to ~2x speedup and roughly half the activation memory on GPUs with fast low-precision arithmetic. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
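Static loss scaling matters because float16 cannot represent very small gradients. The smallest subnormal is about 6e-8. The following NumPy simulation multiplies the loss (and therefore every gradient) by a fixed factor before the float16 backward pass, then divides it out again in float32. The scale 2**16 and the helper name are assumptions for illustration. The library's actual scale factor is not documented here:

```python
import numpy as np

LOSS_SCALE = 2.0 ** 16  # illustrative static scale, not the library's documented value

def scaled_fp16_grad(true_grad, scale=LOSS_SCALE):
    # Simulated float16 backward pass: scaling the loss scales every gradient.
    g16 = (np.float32(true_grad) * np.float32(scale)).astype(np.float16)
    # Unscale in float32 before updating the float32 master weights.
    return g16.astype(np.float32) / np.float32(scale)

tiny = np.float32(1e-8)
unscaled = np.float16(tiny)           # underflows to 0.0 in float16
recovered = scaled_fp16_grad(tiny)    # about 1e-8, kept alive by the loss scale
print(unscaled)                       # 0.0
```

bfloat16 shares float32's exponent range. It is much less prone to this underflow, which is why loss scaling is described above for float16 only.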
Predict class probabilities using \(\Omega\)-projected distances.
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted params) as the starting point,
with a fresh optimizer. Cannot be combined with initial_prototypes.
latent_dim (int, optional) – Latent space dimensionality per prototype. If None, uses input dim.
rejection_confidence (float, optional) – Minimum class probability for a confident prediction (0 to 1).
Samples below this threshold are rejected (labeled -1) when using
predict_with_rejection(). Default is None (no rejection).
n_prototypes_per_class (int) – Number of prototypes per class.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – synchronize the slow weights every sync_period steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor for moving the slow weights toward the fast weights
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward/backward pass runs in the lower
precision, which can give up to ~2x speedup and roughly half the activation memory on GPUs with fast low-precision arithmetic. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
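The relation "effective batch size = batch_size * gradient_accumulation_steps" holds because averaging the gradients of k equally sized micro-batches of a mean loss reproduces the gradient of one batch k times larger. Below is a NumPy check on a least-squares loss. grad_mse and accumulated_grad are hypothetical helpers, not library functions:

```python
import numpy as np

def grad_mse(w, X, y):
    # Gradient of mean(0.5 * (X @ w - y) ** 2) over the rows of X.
    return X.T @ (X @ w - y) / len(y)

def accumulated_grad(w, X, y, batch_size, accumulation_steps):
    # Average the gradients of consecutive micro-batches; apply one update afterwards.
    total = np.zeros_like(w)
    for i in range(accumulation_steps):
        batch = slice(i * batch_size, (i + 1) * batch_size)
        total += grad_mse(w, X[batch], y[batch])
    return total / accumulation_steps

rng = np.random.default_rng(0)
X, y, w = rng.normal(size=(32, 3)), rng.normal(size=32), np.zeros(3)
g_acc = accumulated_grad(w, X, y, batch_size=8, accumulation_steps=4)
g_full = grad_mse(w, X, y)   # same gradient as one batch of 8 * 4 = 32 samples
```

The match is exact only for equally sized micro-batches. The memory cost stays at one micro-batch of size batch_size.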
Predict class probabilities using local \(\Omega\)-projected distances.
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted params) as the starting point,
with a fresh optimizer. Cannot be combined with initial_prototypes.
- Neural Gas cooperation: all prototypes are weighted by rank via
\(\exp(-\text{rank} / \gamma)\)
- Euclidean distance
The NG neighborhood modulates RSLVQ’s Gaussian probabilities,
emphasizing nearby prototypes. \(\gamma\) decays during training from
\(\gamma_{\text{init}}\) to \(\gamma_{\text{final}}\).
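The rank weighting \(\exp(-\text{rank} / \gamma)\) is sketched below in NumPy. It assumes the closest prototype gets rank 0 and that ranks are computed per sample over the given prototypes. How the weights then modulate the RSLVQ probabilities is not reproduced. ng_rank_weights is a hypothetical helper:

```python
import numpy as np

def ng_rank_weights(distances, gamma):
    """distances: (n_samples, n_prototypes) -> NG weights exp(-rank / gamma)."""
    # argsort of argsort converts distances to ranks (0 = closest prototype).
    ranks = np.argsort(np.argsort(distances, axis=1), axis=1)
    return np.exp(-ranks / gamma)

d = np.array([[0.5, 0.1, 2.0]])               # ranks are [1, 0, 2]
h_wide = ng_rank_weights(d, gamma=10.0)       # all prototypes cooperate
h_narrow = ng_rank_weights(d, gamma=0.01)     # effectively winner-take-all
```

As gamma decays toward gamma_final, the weights of all but the closest prototype vanish. This is how the cooperation fades out during training.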
Parameters:
sigma (float) – Bandwidth for RSLVQ Gaussian mixture probability computation.
gamma_init (float, optional) – Initial neighborhood range for NG cooperation.
Default: half the maximum number of prototypes per class.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma.
Default: computed from max_iter so that gamma decays from gamma_init to gamma_final over training.
rejection_confidence (float, optional) – Minimum class probability for a confident prediction (0 to 1).
Samples below this threshold are rejected (labeled -1).
n_prototypes_per_class (int) – Number of prototypes per class.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – synchronize the slow weights every sync_period steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor for moving the slow weights toward the fast weights
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward/backward pass runs in the lower
precision, which can give up to ~2x speedup and roughly half the activation memory on GPUs with fast low-precision arithmetic. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
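The sigma parameter is the bandwidth of the RSLVQ Gaussian mixture. Below is a sketch of the standard RSLVQ class posterior. It assumes equal mixture priors, squared Euclidean distances, and component posteriors proportional to \(\exp(-d_j / 2\sigma^2)\), summed over the prototypes of each class. The library's exact normalization and its NG modulation are not shown. rslvq_class_probs is a hypothetical helper:

```python
import numpy as np

def rslvq_class_probs(X, prototypes, proto_labels, n_classes, sigma):
    # Squared Euclidean distances, shape (n_samples, n_prototypes).
    d = ((X[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)
    logits = -d / (2.0 * sigma ** 2)
    logits -= logits.max(axis=1, keepdims=True)   # numerical stability
    p = np.exp(logits)
    p /= p.sum(axis=1, keepdims=True)             # p(prototype j | x), equal priors
    onehot = np.eye(n_classes)[proto_labels]      # (n_prototypes, n_classes)
    return p @ onehot                             # p(class c | x)

protos = np.array([[0.0, 0.0], [1.0, 0.0], [4.0, 4.0]])
labels = np.array([0, 0, 1])
probs = rslvq_class_probs(np.array([[0.5, 0.0]]), protos, labels, 2, sigma=1.0)
```

A larger sigma flattens the mixture, so more distant prototypes keep non-negligible probability. A smaller sigma approaches hard nearest-prototype assignment.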
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted params) as the starting point,
with a fresh optimizer. Cannot be combined with initial_prototypes.
- Neural Gas cooperation: all prototypes are weighted by rank via
\(\exp(-\text{rank} / \gamma)\)
- Global \(\Omega\) matrix for metric adaptation:
\[d(x, w) = (x - w)^T \Omega^T \Omega (x - w)\]
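The distance \(d(x, w) = (x - w)^T \Omega^T \Omega (x - w)\) equals the squared Euclidean norm of \(\Omega (x - w)\). The NumPy sketch below computes it for all sample/prototype pairs. It assumes \(\Omega\) has shape (latent_dim, n_features), so latent_dim < n_features projects to a lower-dimensional space. omega_distance is a hypothetical helper:

```python
import numpy as np

def omega_distance(X, W, omega):
    """Squared Omega-distances, shape (n_samples, n_prototypes).

    omega has shape (latent_dim, n_features); d = ||Omega (x - w)||^2.
    """
    diff = X[:, None, :] - W[None, :, :]   # (n_samples, n_prototypes, n_features)
    proj = diff @ omega.T                  # (n_samples, n_prototypes, latent_dim)
    return (proj ** 2).sum(axis=-1)

x = np.array([[1.0, 2.0]])
w = np.zeros((1, 2))
d_euclid = omega_distance(x, w, np.eye(2))                # identity Omega: 1 + 4 = 5
d_proj = omega_distance(x, w, np.array([[1.0, 0.0]]))     # latent_dim=1 keeps feature 0 only
```

Because only \(\Omega^T \Omega\) enters the distance, the learned metric is always positive semi-definite, whatever values \(\Omega\) takes.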
Parameters:
sigma (float) – Bandwidth for RSLVQ Gaussian mixture probability computation.
latent_dim (int, optional) – Dimensionality of the \(\Omega\) projection space. If None, uses input dim.
gamma_init (float, optional) – Initial neighborhood range for NG cooperation.
Default: half the maximum number of prototypes per class.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma.
Default: computed from max_iter so that gamma decays from gamma_init to gamma_final over training.
rejection_confidence (float, optional) – Minimum class probability for a confident prediction (0 to 1).
Samples below this threshold are rejected (labeled -1).
n_prototypes_per_class (int) – Number of prototypes per class.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – synchronize the slow weights every sync_period steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor for moving the slow weights toward the fast weights
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward/backward pass runs in the lower
precision, which can give up to ~2x speedup and roughly half the activation memory on GPUs with fast low-precision arithmetic. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
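The gamma_decay default follows from the description "per-step multiplicative decay, computed from max_iter so gamma reaches gamma_final". If \(\gamma_t = \gamma_{\text{init}} \cdot \text{decay}^t\) must equal \(\gamma_{\text{final}}\) at \(t = \text{max\_iter}\), then \(\text{decay} = (\gamma_{\text{final}} / \gamma_{\text{init}})^{1/\text{max\_iter}}\). Below is a sketch of that calculation. It assumes one decay step per iteration, and the helper is hypothetical:

```python
def default_gamma_schedule(n_prototypes_per_class, max_iter, gamma_final=0.01):
    # Accept either a single int or a per-class sequence of prototype counts.
    counts = ([n_prototypes_per_class] if isinstance(n_prototypes_per_class, int)
              else list(n_prototypes_per_class))
    gamma_init = max(counts) / 2.0                      # documented default
    # Solve gamma_init * decay ** max_iter == gamma_final for decay.
    gamma_decay = (gamma_final / gamma_init) ** (1.0 / max_iter)
    return gamma_init, gamma_decay

gamma_init, gamma_decay = default_gamma_schedule(4, max_iter=100)
gammas = [gamma_init * gamma_decay ** t for t in range(101)]   # 2.0 down to 0.01
```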
Predict class probabilities using \(\Omega\)-projected distances.
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted params) as the starting point,
with a fresh optimizer. Cannot be combined with initial_prototypes.
Localized Matrix Robust Soft LVQ with Neural Gas Cooperation.
Each prototype \(k\) has its own \(\Omega_k\) matrix. Combined with RSLVQ
probabilistic loss and NG rank-based neighborhood cooperation.
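With a separate \(\Omega_k\) per prototype, the distance to prototype \(k\) is \(\|\Omega_k (x - w_k)\|^2\). The following NumPy sketch evaluates all of them in one einsum. It assumes the local matrices are stacked with shape (n_prototypes, latent_dim, n_features). local_omega_distances is a hypothetical helper:

```python
import numpy as np

def local_omega_distances(X, W, omegas):
    """omegas: (n_prototypes, latent_dim, n_features), one Omega_k per prototype."""
    diff = X[:, None, :] - W[None, :, :]                # (n_samples, n_prototypes, n_features)
    proj = np.einsum("kld,nkd->nkl", omegas, diff)      # Omega_k (x - w_k)
    return (proj ** 2).sum(axis=-1)                     # (n_samples, n_prototypes)

omegas = np.stack([np.eye(2), 2.0 * np.eye(2)])         # prototype 1 stretches space by 2
X = np.array([[1.0, 1.0]])
W = np.zeros((2, 2))
d_local = local_omega_distances(X, W, omegas)           # [[2.0, 8.0]]
```

Each prototype can therefore stretch or shrink its own neighborhood independently. Parameter count grows linearly with the number of prototypes.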
Parameters:
sigma (float) – Bandwidth for RSLVQ Gaussian mixture probability computation.
latent_dim (int, optional) – Latent space dimensionality per prototype. If None, uses input dim.
gamma_init (float, optional) – Initial neighborhood range for NG cooperation.
Default: half the maximum number of prototypes per class.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma.
Default: computed from max_iter so that gamma decays from gamma_init to gamma_final over training.
rejection_confidence (float, optional) – Minimum class probability for a confident prediction (0 to 1).
Samples below this threshold are rejected (labeled -1).
n_prototypes_per_class (int) – Number of prototypes per class.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – synchronize the slow weights every sync_period steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor for moving the slow weights toward the fast weights
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward/backward pass runs in the lower
precision, which can give up to ~2x speedup and roughly half the activation memory on GPUs with fast low-precision arithmetic. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
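The rejection_confidence rule can be sketched on top of any class-probability matrix. Predict the most probable class, and return -1 when its probability falls below the threshold. The NumPy code below illustrates that rule. It is a sketch, not the library's predict_with_rejection implementation, and it assumes numeric class labels so that -1 can share the output array:

```python
import numpy as np

def predict_with_rejection_sketch(probs, classes, rejection_confidence):
    # probs: (n_samples, n_classes); classes: label of each probability column.
    best = probs.argmax(axis=1)
    labels = np.asarray(classes)[best]
    # Reject (label -1) when the winning probability is below the threshold.
    return np.where(probs.max(axis=1) < rejection_confidence, -1, labels)

probs = np.array([[0.90, 0.10],
                  [0.55, 0.45]])
pred = predict_with_rejection_sketch(probs, classes=[3, 7], rejection_confidence=0.6)
# -> [3, -1]: the second sample is not confident enough
```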
Predict class probabilities using local \(\Omega_k\)-projected distances.
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted params) as the starting point,
with a fresh optimizer. Cannot be combined with initial_prototypes.
Full list of base parameters (optimizer, distance_fn, lr_scheduler, callbacks, patience, etc.).
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point,
with a freshly initialized optimizer. Cannot be combined with initial_prototypes.
Full list of base parameters (optimizer, distance_fn, lr_scheduler, callbacks, patience, etc.).
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point,
with a freshly initialized optimizer. Cannot be combined with initial_prototypes.
An MLP backbone maps inputs into a latent space. Prototypes reside
directly in that latent space. The GLVQ loss trains both the backbone
and the prototypes jointly via gradient descent.
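In its standard form, the GLVQ loss compares, for each sample, the distance d+ to the nearest prototype of the correct class with the distance d- to the nearest prototype of any other class, via the relative difference mu = (d+ - d-) / (d+ + d-), passed through a sigmoid. A NumPy sketch, assuming beta is the sigmoid steepness (consistent with its description as a transfer function parameter) and that the loss is averaged over samples; in this model the distances would be computed between backbone outputs and latent-space prototypes.

```python
import numpy as np


def glvq_loss(dist, y, proto_labels, beta=1.0):
    """Mean GLVQ loss with a sigmoid transfer function (illustrative sketch).

    dist: (n, K) distances, e.g. squared Euclidean distances in the latent space.
    Assumes d_plus + d_minus > 0 for every sample.
    """
    same = proto_labels[None, :] == y[:, None]                # (n, K) correct-class mask
    d_plus = np.where(same, dist, np.inf).min(axis=1)         # nearest correct prototype
    d_minus = np.where(~same, dist, np.inf).min(axis=1)       # nearest wrong prototype
    mu = (d_plus - d_minus) / (d_plus + d_minus)              # in [-1, 1]; negative means correct
    return np.mean(1.0 / (1.0 + np.exp(-beta * mu)))          # sigmoid(beta * mu)


dist = np.array([[1.0, 3.0]])  # one sample, two prototypes
loss = glvq_loss(dist, y=np.array([0]), proto_labels=np.array([0, 1]), beta=2.0)
# mu = (1 - 3) / (1 + 3) = -0.5, so loss = sigmoid(-1.0)
```

Because gradients flow through the distances, minimizing this quantity moves both the prototypes and, in the backbone variant, the embedding weights.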
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – synchronize the slow weights every sync_period steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor toward the fast weights
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward/backward pass runs in lower
precision, which on supported GPUs can give up to ~2x speed and roughly halve activation memory. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
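Lookahead keeps a set of slow weights, lets the inner optimizer take sync_period fast steps, and then moves the slow weights a fraction slow_step_size of the way toward the fast ones. The sketch assumes plain gradient descent as the inner optimizer and the original Lookahead update rule; the function name and the toy quadratic objective are illustrative, and the library's wrapper may differ in details such as how inner optimizer state is handled at each sync.

```python
import numpy as np


def lookahead_sgd(grad_fn, x0, lr=0.1, sync_period=6, slow_step_size=0.5, n_steps=60):
    """Lookahead wrapped around plain gradient descent (sketch, not the library's optimizer)."""
    slow = np.asarray(x0, dtype=float)
    fast = slow.copy()
    for step in range(1, n_steps + 1):
        fast = fast - lr * grad_fn(fast)                      # inner (fast) optimizer step
        if step % sync_period == 0:
            slow = slow + slow_step_size * (fast - slow)      # move slow weights toward fast
            fast = slow.copy()                                # restart fast weights from slow
    return slow


# Minimize f(x) = x**2 (gradient 2x) starting from x = 1.
x_final = lookahead_sgd(lambda x: 2.0 * x, x0=[1.0])
# Each 6-step cycle shrinks the fast weights by 0.8**6; the slow weights move halfway,
# i.e. by a factor 1 - 0.5 * (1 - 0.8**6) = 0.631072 per sync.
```

The slow weights trail the fast ones, which damps oscillations at the cost of some speed; with slow_step_size=1.0 the wrapper reduces to the inner optimizer alone.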
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point,
with a freshly initialized optimizer. Cannot be combined with initial_prototypes.
Probabilistic LVQ with a learned nonlinear transformation.
Combines an MLP backbone (learned metric) with probabilistic
soft assignment via Gaussian mixtures. The loss is the negative
log-likelihood of the correct class.
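A common form of this loss, sketched here under stated assumptions: each prototype is an isotropic Gaussian with a shared bandwidth sigma (a hypothetical parameter, not listed in this reference) and equal mixing weights, so p(y | x) is the summed responsibility of the prototypes labelled y, computed on distances between backbone outputs and prototypes. The exact normalization used by the library may differ.

```python
import numpy as np


def _logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along one axis."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def plvq_nll(dist, y, proto_labels, sigma=1.0):
    """Mean negative log-likelihood of the correct class (equal-weight Gaussian mixture).

    dist: (n, K) squared distances in the latent space; every class in y needs a prototype.
    """
    logits = -dist / (2.0 * sigma ** 2)                      # log of unnormalized component densities
    correct = np.where(proto_labels[None, :] == y[:, None], logits, -np.inf)
    log_p_correct = _logsumexp(correct, axis=1) - _logsumexp(logits, axis=1)
    return -np.mean(log_p_correct)


loss = plvq_nll(np.array([[0.0, 2.0]]), y=np.array([0]), proto_labels=np.array([0, 1]))
# logits = [0, -1], so loss = log(1 + e**-1)
```

Working in log space avoids underflow when distances are large relative to sigma, which matters once the backbone spreads the latent space out.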
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – synchronize the slow weights every sync_period steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor toward the fast weights
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward/backward pass runs in lower
precision, which on supported GPUs can give up to ~2x speed and roughly halve activation memory. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
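For class_weight='balanced', the description only states that weights are inversely proportional to class frequencies. A sketch assuming the scikit-learn convention n_samples / (n_classes * count_c); the helper name is hypothetical.

```python
from collections import Counter


def balanced_class_weights(y):
    """n_samples / (n_classes * count_c) per class (assumed scikit-learn-style convention)."""
    counts = Counter(y)
    n_samples, n_classes = len(y), len(counts)
    return {c: n_samples / (n_classes * n) for c, n in sorted(counts.items())}


weights = balanced_class_weights([0, 0, 0, 1])
# the rarer class 1 gets 2.0, class 0 gets about 0.667
```

Under this convention every class contributes the same total weight to the loss: here 3 × 2/3 = 2 for class 0 and 1 × 2.0 = 2 for class 1.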
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point,
with a freshly initialized optimizer. Cannot be combined with initial_prototypes.
Siamese GLVQ — GLVQ with a learned embedding network.
Both inputs and prototypes are transformed through the same MLP
backbone before computing squared Euclidean distances.
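The shared-backbone distance can be sketched in NumPy. Assumptions: the backbone is a plain MLP whose hidden layers use the chosen activation (tanh here) and whose last layer is linear, and prototypes live in input space so they pass through the same weights as the inputs. The helper names and the random weights are illustrative.

```python
import numpy as np


def mlp(params, x, activation=np.tanh):
    """Tiny MLP backbone: hidden layers use `activation`, the last layer is linear (assumed)."""
    for W, b in params[:-1]:
        x = activation(x @ W + b)
    W, b = params[-1]
    return x @ W + b


def siamese_sq_dist(params, X, prototypes):
    """Squared Euclidean distances between embedded inputs and embedded prototypes."""
    zx = mlp(params, X)            # (n, latent_dim)
    zw = mlp(params, prototypes)   # (K, latent_dim), same weights as the input path
    return ((zx[:, None, :] - zw[None, :, :]) ** 2).sum(axis=-1)


rng = np.random.default_rng(0)
params = [(rng.normal(size=(4, 8)), np.zeros(8)),   # hidden layer: 4 features -> 8 units
          (rng.normal(size=(8, 2)), np.zeros(2))]   # output layer: latent_dim = 2
X = rng.normal(size=(5, 4))
D = siamese_sq_dist(params, X, prototypes=X[:3])    # prototypes copied from the first 3 inputs
```

In a JAX implementation, both_path_gradients=False would typically correspond to wrapping the prototype path (zw) in jax.lax.stop_gradient; whether the library does exactly this is an assumption.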
Parameters:
hidden_sizes (list of int) – Hidden layer sizes for the backbone MLP.
latent_dim (int) – Dimension of the embedding space.
activation (str) – Activation function for the backbone MLP. Supported values:
‘sigmoid’, ‘relu’, ‘tanh’, ‘leaky_relu’, ‘selu’.
beta (float) – Transfer function parameter for the GLVQ loss.
bb_lr (float, optional) – Separate learning rate for the backbone network. If None,
the prototype learning rate is used. Default: None.
both_path_gradients (bool) – If True, compute gradients through both input and prototype
paths. If False, prototype path gradients are stopped.
Default: True.
n_prototypes_per_class (int) – Number of prototypes per class.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – synchronize the slow weights every sync_period steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor toward the fast weights
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward/backward pass runs in lower
precision, which on supported GPUs can give up to ~2x speed and roughly halve activation memory. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
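The patience rule can be sketched as a pure function over a loss history. Assumptions: an epoch counts as an improvement only if the loss drops by more than a small epsilon (the value 1e-6 is hypothetical), and patience=None behaves like a patience of 1, matching the "single non-improving step" description. As documented, this only stops training early when use_scan=False.

```python
def early_stop_epoch(losses, patience=None, epsilon=1e-6):
    """Return the index of the epoch at which training would stop (sketch of the patience rule)."""
    best = float("inf")
    wait = 0
    limit = 1 if patience is None else patience  # None: stop after one non-improving epoch
    for epoch, loss in enumerate(losses):
        if loss < best - epsilon:
            best, wait = loss, 0                 # improvement: reset the counter
        else:
            wait += 1
            if wait >= limit:
                return epoch
    return len(losses) - 1                       # ran to the end without triggering


losses = [1.0, 0.8, 0.8, 0.79, 0.79, 0.79, 0.79]
stop_default = early_stop_epoch(losses)              # stops at the first plateau (epoch 2)
stop_patient = early_stop_epoch(losses, patience=3)  # survives the plateau, stops at epoch 6
```

A larger patience tolerates short plateaus such as the one at epoch 2, which the default setting treats as convergence.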
Predict class probabilities via softmin of distances.
Parameters:
X (array-like of shape (n_samples, n_features))
Returns:
proba
Return type:
array of shape (n_samples, n_classes)
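A softmin over distances can be sketched as a softmax of negated distances. Assumptions: each class is represented by its closest prototype and the temperature is 1; the library might instead aggregate over all prototypes of a class, so treat the function as illustrative.

```python
import numpy as np


def softmin_proba(dist, proto_labels, n_classes):
    """Class probabilities from a softmin over each class's closest-prototype distance."""
    class_d = np.stack([dist[:, proto_labels == c].min(axis=1)
                        for c in range(n_classes)], axis=1)       # (n, n_classes)
    logits = -class_d - np.max(-class_d, axis=1, keepdims=True)   # shift for numerical stability
    e = np.exp(logits)
    return e / e.sum(axis=1, keepdims=True)


dist = np.array([[1.0, 4.0, 2.0]])      # one sample, three prototypes
proto_labels = np.array([0, 1, 1])      # class 1 has two prototypes; its nearest is at 2.0
proba = softmin_proba(dist, proto_labels, n_classes=2)
# proba = softmax([-1, -2]), roughly [0.731, 0.269]
```

Each row sums to 1, and the class with the nearest prototype always receives the highest probability, so the argmax agrees with nearest-prototype prediction.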
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point,
with a freshly initialized optimizer. Cannot be combined with initial_prototypes.
Siamese GMLVQ — GMLVQ with a learned embedding network.
Both inputs and prototypes are transformed through the same MLP,
then distances are computed using a learned \(\Omega\) matrix in the
latent space:
\[d = \|\Omega(f(x) - f(w))\|^2\]
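The distance \(d = \|\Omega(f(x) - f(w))\|^2\) can be sketched in NumPy. Here zx and zw stand in for the backbone embeddings f(x) and f(w) (random arrays, since any backbone could supply them), and Omega has shape (omega_dim, latent_dim) as described for omega_dim. The function name is illustrative.

```python
import numpy as np


def omega_latent_dist(zx, zw, omega):
    """d[i, k] = ||Omega (zx_i - zw_k)||^2 with Omega of shape (omega_dim, latent_dim)."""
    diff = zx[:, None, :] - zw[None, :, :]   # (n, K, latent_dim)
    proj = diff @ omega.T                    # (n, K, omega_dim)
    return (proj ** 2).sum(axis=-1)          # (n, K)


rng = np.random.default_rng(1)
zx = rng.normal(size=(4, 3))      # embedded inputs, latent_dim = 3
zw = rng.normal(size=(2, 3))      # embedded prototypes
omega = rng.normal(size=(2, 3))   # omega_dim = 2 < latent_dim: a rank-2 metric
d = omega_latent_dist(zx, zw, omega)
```

Equivalently, d = (f(x) - f(w))ᵀ Λ (f(x) - f(w)) with Λ = ΩᵀΩ. Λ is positive semidefinite, so the distance is never negative, and choosing omega_dim < latent_dim restricts Λ to low rank.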
Parameters:
hidden_sizes (list of int) – Hidden layer sizes for the backbone MLP.
latent_dim (int) – Dimension of the backbone output (embedding space).
omega_dim (int, optional) – Omega mapping dimension (number of rows in Omega). If None,
uses latent_dim (square matrix). Default: None.
activation (str) – Activation function for the backbone MLP. Supported values:
‘sigmoid’, ‘relu’, ‘tanh’, ‘leaky_relu’, ‘selu’.
beta (float) – Transfer function parameter for the GLVQ loss.
bb_lr (float, optional) – Separate learning rate for the backbone network. If None,
the prototype learning rate is used. Default: None.
both_path_gradients (bool) – If True, compute gradients through both input and prototype
paths. If False, prototype path gradients are stopped.
Default: True.
n_prototypes_per_class (int) – Number of prototypes per class.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – synchronize the slow weights every sync_period steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor toward the fast weights
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward/backward pass runs in lower
precision, which on supported GPUs can give up to ~2x speed and roughly halve activation memory. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
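A custom prototypes_initializer only has to follow the documented signature (X, y, n_per_class, key) -> (protos, labels). The sketch below is a hypothetical "class median" initializer, which is more robust to outliers than a class mean. It assumes NumPy arrays are acceptable return values and ignores `key`, because it is deterministic.

```python
import numpy as np


def class_median_init(X, y, n_per_class, key):
    """Hypothetical initializer: every prototype of a class starts at that class's median.

    `key` matches the documented signature but is unused, since nothing here is random.
    """
    X, y = np.asarray(X, dtype=float), np.asarray(y)
    classes = np.unique(y)
    protos = np.concatenate([
        np.repeat(np.median(X[y == c], axis=0, keepdims=True), n_per_class, axis=0)
        for c in classes
    ])
    labels = np.repeat(classes, n_per_class)   # labels aligned with the rows of protos
    return protos, labels


X = [[0.0, 0.0], [1.0, 1.0], [50.0, 50.0], [10.0, 10.0]]
y = [0, 0, 0, 1]
protos, labels = class_median_init(X, y, n_per_class=1, key=None)
# class 0 starts at (1, 1) despite the outlier at (50, 50); its mean would be about (17, 17)
```

With n_per_class > 1 this places identical copies of each prototype; a practical version would likely perturb them, for example using the provided key.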
Predict class labels via winner-takes-all competition.
Parameters:
X (array-like of shape (n_samples, n_features))
Returns:
labels
Return type:
array of shape (n_samples,)
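Winner-takes-all prediction assigns each sample the label of its nearest prototype. A NumPy sketch on a precomputed distance matrix (the function name is illustrative):

```python
import numpy as np


def wta_predict(dist, proto_labels):
    """Label of the nearest prototype for each sample; ties go to the first prototype."""
    return proto_labels[np.argmin(dist, axis=1)]


dist = np.array([[0.5, 0.2, 3.0],
                 [2.0, 1.5, 0.1]])
proto_labels = np.array([0, 0, 1])   # two prototypes for class 0, one for class 1
pred = wta_predict(dist, proto_labels)
# nearest prototypes are 1 and 2, so pred == [0, 1]
```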
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point,
with a freshly initialized optimizer. Cannot be combined with initial_prototypes.
Siamese GTLVQ — GTLVQ with a learned embedding network.
Both inputs and prototypes are transformed through the same MLP,
then tangent distances are computed in the latent space using
per-prototype subspace bases.
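Assuming the usual tangent distance definition, the distance from an embedded input z to prototype k is the squared norm of the part of z - w_k that its orthonormal basis U_k cannot explain: \(\|(I - U_k U_k^\top)(z - w_k)\|^2\). A NumPy sketch with bases stored as an array of shape (K, latent_dim, subspace_dim); the function name and layout are illustrative.

```python
import numpy as np


def tangent_sq_dist(zx, zw, bases):
    """Squared distance from each z to the affine subspace w_k + span(U_k).

    zx: (n, D) embedded inputs, zw: (K, D) embedded prototypes,
    bases: (K, D, S) with orthonormal columns per prototype.
    """
    diff = zx[:, None, :] - zw[None, :, :]                      # (n, K, D)
    coeffs = np.einsum("nkd,kds->nks", diff, bases)             # coordinates in each basis
    residual = diff - np.einsum("kds,nks->nkd", bases, coeffs)  # remove the tangent component
    return (residual ** 2).sum(axis=-1)                         # (n, K)


zw = np.zeros((1, 3))                          # one prototype at the origin
bases = np.array([[[1.0], [0.0], [0.0]]])      # its tangent direction is the x-axis
zx = np.array([[5.0, 0.0, 0.0], [1.0, 2.0, 0.0]])
d = tangent_sq_dist(zx, zw, bases)
# movement along the x-axis is free: d == [[0], [4]]
```

An orthonormal basis for experimentation can be obtained with np.linalg.qr(A)[0] for a random A of shape (D, S); how the library keeps its learned bases orthonormal during training is not described here.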
Parameters:
hidden_sizes (list of int) – Hidden layer sizes for the backbone MLP.
latent_dim (int) – Dimension of the backbone output (embedding space).
subspace_dim (int) – Tangent subspace dimension per prototype. Each prototype gets a
learned orthonormal basis of this rank in latent space.
activation (str) – Activation function for the backbone MLP. Supported values:
‘sigmoid’, ‘relu’, ‘tanh’, ‘leaky_relu’, ‘selu’.
beta (float) – Transfer function parameter for the GLVQ loss.
bb_lr (float, optional) – Separate learning rate for the backbone network. If None,
the prototype learning rate is used. Default: None.
both_path_gradients (bool) – If True, compute gradients through both input and prototype
paths. If False, prototype path gradients are stopped.
Default: True.
n_prototypes_per_class (int) – Number of prototypes per class.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – synchronize the slow weights every sync_period steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor toward the fast weights
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward/backward pass runs in lower
precision, which on supported GPUs can give up to ~2x speed and roughly halve activation memory. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
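Why float16 needs loss scaling can be shown with NumPy's float16 type: a gradient of 1e-8 is below float16's smallest subnormal (about 6e-8) and rounds to zero, whereas the same value multiplied by a scale factor survives and can be divided back out in float32. The scale of 2**16 is a hypothetical choice; the library's actual static scale is not documented here.

```python
import numpy as np

LOSS_SCALE = np.float32(2.0 ** 16)  # hypothetical static scale factor


def fp16_grad_roundtrip(grad, scale=LOSS_SCALE):
    """Hold a scaled gradient in float16, then unscale it in float32."""
    g16 = (np.float32(grad) * scale).astype(np.float16)  # what a float16 backward pass would hold
    return g16.astype(np.float32) / scale                 # unscaled before the float32 update


tiny_grad = 1e-8
unscaled = np.float16(tiny_grad)              # without scaling: underflows to 0.0
recovered = fp16_grad_roundtrip(tiny_grad)    # with scaling: close to 1e-8
```

The drawback of a static scale is overflow: float16 tops out at 65504, so with a scale of 2**16 any gradient component above roughly 1.0 becomes inf. bfloat16 keeps float32's exponent range, which is presumably why scaling is described only for float16.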
Predict class labels via winner-takes-all competition.
Parameters:
X (array-like of shape (n_samples, n_features))
Returns:
labels
Return type:
array of shape (n_samples,)
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point,
with a freshly initialized optimizer. Cannot be combined with initial_prototypes.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – synchronize the slow weights every sync_period steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor toward the fast weights
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward/backward pass runs in lower
precision, which on supported GPUs can give up to ~2x speed and roughly halve activation memory. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
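As an illustration of what a string such as 'warmup_cosine_decay' in lr_scheduler usually means, here is a pure-Python sketch: a linear warmup from 0 to a peak learning rate, then a half-cosine decay to an end value. The argument names are hypothetical (the keys actually accepted through lr_scheduler_kwargs are not listed here), and decay_steps is assumed to count the warmup steps as well.

```python
import math


def warmup_cosine_decay(step, peak_lr, warmup_steps, decay_steps, end_lr=0.0):
    """Linear warmup to peak_lr, then cosine decay to end_lr at step decay_steps (sketch)."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps))
    return end_lr + 0.5 * (peak_lr - end_lr) * (1.0 + math.cos(math.pi * progress))


schedule = [warmup_cosine_decay(s, peak_lr=1.0, warmup_steps=10, decay_steps=100)
            for s in range(0, 101, 5)]
# rises to 1.0 at step 10, is 0.5 halfway through the decay (step 55), reaches 0.0 at step 100
```

Warmup keeps early updates small while the prototypes and any backbone weights are still far from a good configuration; the cosine tail then anneals the step size smoothly instead of dropping it abruptly.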
Predict class probabilities via softmin of distances.
Parameters:
X (array-like of shape (n_samples, n_features))
Returns:
proba
Return type:
array of shape (n_samples, n_classes)
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point,
with a freshly initialized optimizer. Cannot be combined with initial_prototypes.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – synchronize the slow and fast weights every sync_period steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor for moving the slow weights toward the fast weights
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward and backward passes run in lower
precision, which can give up to ~2x speedup and roughly half the memory use on GPU. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
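Static loss scaling for float16 can be shown without a GPU. The sketch below uses the half-precision format code 'e' of Python's struct module to round numbers to float16. A gradient of 1e-8 lies well below the smallest positive float16 value (about 6e-8), so it flushes to zero. It survives if the loss, and therefore the gradient, is multiplied by a fixed scale before the cast and divided by that scale afterwards. The scale 2**15 and the helper names are illustrative assumptions, not the library's internals.

```python
import struct

def to_float16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]

LOSS_SCALE = 2.0 ** 15  # hypothetical static scale (a fixed power of two)

def fp16_gradient(true_grad: float, scale: float = 1.0) -> float:
    """Simulate a gradient produced in float16 and unscaled in higher precision."""
    # Multiplying the loss by `scale` multiplies every gradient by `scale` too.
    grad_fp16 = to_float16(true_grad * scale)
    # Unscale in full precision, as the float32 master-weight update would.
    return grad_fp16 / scale

tiny = 1e-8
unscaled = fp16_gradient(tiny)               # flushes to zero in float16
scaled = fp16_gradient(tiny, LOSS_SCALE)     # survives, with small rounding error
print(unscaled, scaled)
```

A static scale has a downside: overflow. A scaled gradient above the float16 maximum of 65504 cannot be represented (here `struct.pack` raises OverflowError), so the scale has to be chosen conservatively.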
Predict class labels via winner-takes-all competition: each sample receives the label of its closest prototype.
Parameters:
X (array-like of shape (n_samples, n_features))
Returns:
labels
Return type:
array of shape (n_samples,)
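Winner-takes-all prediction is a nearest-prototype rule. The pure-Python sketch below assumes squared Euclidean distance, the documented default distance_fn. The function name `wta_predict` is hypothetical and not part of the library.

```python
from typing import List, Sequence

def squared_euclidean(x: Sequence[float], w: Sequence[float]) -> float:
    return sum((xi - wi) ** 2 for xi, wi in zip(x, w))

def wta_predict(X: List[Sequence[float]],
                prototypes: List[Sequence[float]],
                prototype_labels: List[int]) -> List[int]:
    """Assign each sample the label of its closest prototype."""
    labels = []
    for x in X:
        distances = [squared_euclidean(x, w) for w in prototypes]
        winner = min(range(len(prototypes)), key=distances.__getitem__)
        labels.append(prototype_labels[winner])
    return labels

prototypes = [[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]]
prototype_labels = [0, 0, 1]   # class 0 has two prototypes, class 1 has one
print(wta_predict([[0.2, 0.1], [4.0, 4.5], [1.2, 0.9]], prototypes, prototype_labels))
# [0, 1, 0]
```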
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses current prototypes (and other fitted params) as starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – synchronize the slow and fast weights every sync_period steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor for moving the slow weights toward the fast weights
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward and backward passes run in lower
precision, which can give up to ~2x speedup and roughly half the memory use on GPU. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
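The interplay of epsilon, patience and restore_best under use_scan=False can be sketched as a plain Python loop. This sketch makes three assumptions. First, an epoch counts as an improvement when the monitored loss falls more than epsilon below the best value seen so far. Second, patience=None behaves like a patience of one, following the description above. Third, `train_with_early_stopping` and its `step_fn` interface are hypothetical. The library's exact bookkeeping may differ.

```python
from typing import Callable, List, Optional, Tuple, TypeVar

P = TypeVar("P")  # any parameter container

def train_with_early_stopping(
    step_fn: Callable[[P], Tuple[P, float]],
    init_params: P,
    max_iter: int,
    epsilon: float = 1e-6,
    patience: Optional[int] = None,
    restore_best: bool = False,
) -> Tuple[P, List[float]]:
    """Run up to max_iter epochs; stop after `limit` consecutive non-improving epochs.

    step_fn(params) -> (new_params, loss_of_new_params) runs one epoch
    (a hypothetical interface used only for this illustration).
    """
    limit = 1 if patience is None else patience  # None: stop on the first non-improving epoch
    params, best_params = init_params, init_params
    best_loss = float("inf")
    bad_epochs = 0
    history: List[float] = []
    for _ in range(max_iter):
        params, loss = step_fn(params)
        history.append(loss)
        if best_loss - loss > epsilon:      # improvement larger than the threshold
            best_loss, best_params, bad_epochs = loss, params, 0
        else:
            bad_epochs += 1
            if bad_epochs >= limit:
                break                       # remaining epochs are never computed
    return (best_params if restore_best else params), history

# A scripted loss curve that improves for three epochs, then worsens.
# The "params" here simply count epochs.
losses = iter([4.0, 2.0, 1.0, 1.5, 1.7, 0.5])
final, hist = train_with_early_stopping(
    lambda p: (p + 1, next(losses)), 0, max_iter=6, patience=2, restore_best=True
)
print(final, hist)  # prints: 3 [4.0, 2.0, 1.0, 1.5, 1.7]
```

With use_scan=True the loop body becomes a fixed-length scan, so there is no `break`. All max_iter iterations run, which is the trade-off described for use_scan.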
Predict class labels via winner-takes-all competition: each sample receives the label of its closest prototype.
Parameters:
X (array-like of shape (n_samples, n_features))
Returns:
labels
Return type:
array of shape (n_samples,)
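The ema_decay parameter in the list above keeps a shadow copy of the parameters. After every optimizer step the shadow is updated as ema = d * ema + (1 - d) * params, and at the end of training it replaces the trained values. The sketch below works on a flat dict of floats. Two details are assumptions: the shadow starts from the initial parameters, and no bias correction is applied.

```python
from typing import Dict, List

Params = Dict[str, float]

def ema_update(ema: Params, params: Params, decay: float) -> Params:
    """One EMA step: ema <- decay * ema + (1 - decay) * params."""
    return {k: decay * ema[k] + (1.0 - decay) * params[k] for k in params}

def train_with_ema(param_trajectory: List[Params], decay: float) -> Params:
    """Track the EMA over a sequence of per-step parameters and return it."""
    ema = dict(param_trajectory[0])  # shadow starts at the initial parameters
    for params in param_trajectory[1:]:
        ema = ema_update(ema, params, decay)
    return ema  # these values would replace the raw parameters after training

# Parameter values that keep oscillating around 1.0, as noisy updates might.
trajectory = [{"w": 0.0}] + [{"w": 1.0 + (0.2 if i % 2 else -0.2)} for i in range(200)]
smoothed = train_with_ema(trajectory, decay=0.9)
print(trajectory[-1]["w"], smoothed["w"])
```

A decay of 0.9 keeps the demo short. The typical values in the docstring (0.999, 0.9999) average over far more steps.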
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses current prototypes (and other fitted params) as starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
Components detect patterns in the input (via similarity), then
reasoning matrices determine how each detection contributes
evidence for/against each class.
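Detection plus reasoning can be illustrated with the probability rule from the original classification-by-components formulation. Each component \(k\) has a detection \(d_k \in [0, 1]\). For each class \(c\) it also has a positive reasoning \(r^{+}_{kc}\), where detection supports the class, and a negative reasoning \(r^{-}_{kc}\), where absence supports it. The class score is \(p_c = \sum_k \big(d_k r^{+}_{kc} + (1 - d_k) r^{-}_{kc}\big) / \sum_k (r^{+}_{kc} + r^{-}_{kc})\). It is an assumption that this class parameterizes its reasoning matrices exactly this way. The Gaussian bandwidth `sigma` is a hypothetical stand-in for the default similarity_fn.

```python
import math
from typing import List, Sequence

def gaussian_similarity(x: Sequence[float], w: Sequence[float], sigma: float = 1.0) -> float:
    """Detection in (0, 1]: exp(-||x - w||^2 / (2 * sigma^2)); sigma is an assumed bandwidth."""
    sq_dist = sum((xi - wi) ** 2 for xi, wi in zip(x, w))
    return math.exp(-sq_dist / (2.0 * sigma ** 2))

def cbc_probabilities(x: Sequence[float],
                      components: List[Sequence[float]],
                      r_pos: List[List[float]],
                      r_neg: List[List[float]],
                      sigma: float = 1.0) -> List[float]:
    """r_pos[k][c] / r_neg[k][c]: positive / negative reasoning of component k for class c."""
    d = [gaussian_similarity(x, w, sigma) for w in components]
    n_components, n_classes = len(components), len(r_pos[0])
    probs = []
    for c in range(n_classes):
        # r_pos rewards detecting a component; r_neg rewards its absence.
        evidence = sum(d[k] * r_pos[k][c] + (1.0 - d[k]) * r_neg[k][c] for k in range(n_components))
        total = sum(r_pos[k][c] + r_neg[k][c] for k in range(n_components))
        probs.append(evidence / total)
    return probs

components = [[0.0, 0.0], [3.0, 3.0]]
# Class 0: component 0 must be present and component 1 absent; class 1: the reverse.
r_pos = [[1.0, 0.0], [0.0, 1.0]]
r_neg = [[0.0, 1.0], [1.0, 0.0]]
print(cbc_probabilities([0.0, 0.0], components, r_pos, r_neg))
```

Each \(p_c\) lies in [0, 1]. In this rule the class scores need not sum to one.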
Parameters:
n_components (int) – Number of components (analogous to prototypes, but classless).
components_initializer (callable, optional) – Initializer for component vectors. Signature:
(X,key,n_components)->components. Default: None
(selects random training samples).
reasonings_initializer (callable, optional) – Initializer for the reasoning matrix. Signature:
(n_components,n_classes,key)->reasonings. Default: None
(initializes near-uniform with small noise).
similarity_fn (callable, optional) – Similarity function for component detection. Default: None
(uses Gaussian similarity).
n_prototypes_per_class (int) – Number of prototypes per class.
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation.
Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘components’] to freeze the components and only train reasonings.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – synchronize the slow and fast weights every sync_period steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor for moving the slow weights toward the fast weights
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward and backward passes run in lower
precision, which can give up to ~2x speedup and roughly half the memory use on GPU. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
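gradient_accumulation_steps collects the gradients of several consecutive mini-batches and applies a single update. That update therefore reflects batch_size * gradient_accumulation_steps samples. The sketch below uses plain SGD on a scalar least-squares problem. Averaging rather than summing is an assumption, and so is the choice of SGD. optax offers a comparable wrapper, `optax.MultiSteps`, but whether the library uses it is not stated here.

```python
from typing import List, Tuple

Batch = List[Tuple[float, float]]

def mean_grad(w: float, batch: Batch) -> float:
    """Gradient of the mean of 0.5 * (w * x - y)**2 over one mini-batch."""
    return sum((w * x - y) * x for x, y in batch) / len(batch)

def sgd_with_accumulation(w: float, batches: List[Batch], lr: float, accumulation_steps: int) -> float:
    """Apply one SGD update per `accumulation_steps` consecutive mini-batches."""
    acc, count = 0.0, 0
    for batch in batches:
        acc += mean_grad(w, batch)   # w stays fixed while gradients accumulate
        count += 1
        if count == accumulation_steps:
            w -= lr * acc / accumulation_steps   # average of the accumulated gradients
            acc, count = 0.0, 0
    return w  # a trailing partial accumulation is discarded in this sketch

data = [(float(x), 2.0 * x) for x in range(1, 9)]
batches = [data[i:i + 2] for i in range(0, len(data), 2)]   # batch_size = 2
w_accum = sgd_with_accumulation(0.0, batches, lr=0.01, accumulation_steps=4)
w_full = 0.0 - 0.01 * mean_grad(0.0, data)                  # one step on all 8 samples
print(w_accum, w_full)
```

With equally sized mini-batches, the averaged update matches one full-batch step over the same 8 samples.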
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses current prototypes (and other fitted params) as starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
Both input images and component images pass through the same CNN
backbone. Detection similarity and reasoning matrices then produce
class probabilities.
Parameters:
input_shape (tuple) – Shape of input images as (height, width, channels).
channels (list of int) – CNN output channels per convolutional layer, e.g. [16, 32].
kernel_sizes (list of int) – Kernel sizes per convolutional layer, e.g. [3, 3].
latent_dim (int) – Dimension of the CNN embedding space.
n_components (int) – Number of components (classless prototypes).
epsilon (float) – Convergence threshold on loss change.
random_seed (int) – Random seed for reproducibility.
distance_fn (callable, optional) – Distance function (default: squared Euclidean).
optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation.
Default: ‘adam’.
transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train components.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – synchronize the slow and fast weights every sync_period steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor for moving the slow weights toward the fast weights
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward and backward passes run in lower
precision, which can give up to ~2x speedup and roughly half the memory use on GPU. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
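Freezing a parameter group means zeroing its gradients before the optimizer update. With freeze_params=['backbone'], for example, the backbone weights stay put while the other groups keep training. The sketch below works over a dict of named parameter groups. The group name 'reasonings' and the helper names are hypothetical. In optax, a similar effect is commonly obtained with `optax.multi_transform` combined with `optax.set_to_zero()`.

```python
from typing import Dict, Iterable, List

Groups = Dict[str, List[float]]

def mask_frozen(grads: Groups, freeze_params: Iterable[str]) -> Groups:
    """Replace the gradients of frozen groups with zeros of the same length."""
    frozen = set(freeze_params)
    return {name: ([0.0] * len(g) if name in frozen else list(g)) for name, g in grads.items()}

def sgd_step(params: Groups, grads: Groups, lr: float) -> Groups:
    """Plain SGD applied group by group."""
    return {name: [p - lr * g for p, g in zip(params[name], grads[name])] for name in params}

params = {"backbone": [0.5, -0.3], "components": [1.0, 2.0], "reasonings": [0.1]}
grads = {"backbone": [0.2, 0.4], "components": [0.5, -0.5], "reasonings": [1.0]}
new_params = sgd_step(params, mask_frozen(grads, ["backbone"]), lr=0.1)
print(new_params)  # backbone unchanged; components and reasonings updated
```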
Predict class probabilities via softmin of distances.
Parameters:
X (array-like of shape (n_samples, n_features))
Returns:
proba
Return type:
array of shape (n_samples, n_classes)
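A softmin over distances can be illustrated on the closest prototype of each class: \(p_c = \exp(-d_c) / \sum_{c'} \exp(-d_{c'})\), where \(d_c\) is that minimum distance. The per-class minimum and the unit temperature are assumptions. The library may aggregate distances per class differently or use a temperature. Subtracting the smallest distance before exponentiating improves numerical stability and leaves the normalized result unchanged.

```python
import math
from typing import List, Sequence

def softmin_proba(distances: Sequence[float], prototype_labels: Sequence[int], n_classes: int) -> List[float]:
    """Class probabilities from a softmin over the closest prototype of each class."""
    per_class = [min(d for d, lab in zip(distances, prototype_labels) if lab == c)
                 for c in range(n_classes)]
    shift = min(per_class)  # stabilizes exp() without changing the normalized result
    weights = [math.exp(-(d - shift)) for d in per_class]
    total = sum(weights)
    return [w / total for w in weights]

# Distances from one sample to four prototypes (two per class).
proba = softmin_proba([0.5, 2.0, 1.5, 3.0], prototype_labels=[0, 0, 1, 1], n_classes=2)
print(proba)
```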
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses current prototypes (and other fitted params) as starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
- Neural Gas cooperation: all same-class prototypes participate in the loss, weighted by rank via \(\exp(-\text{rank} / \gamma)\).
- Relevance weighting: per-feature relevances \(\lambda_j\) are learned during training.
The neighborhood range \(\gamma\) decays during training from
\(\gamma_{\text{init}}\) to \(\gamma_{\text{final}}\).
When \(\gamma \to 0\), SRNG recovers standard GRLVQ.
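The rank-based cooperation can be made concrete for a single sample. The same-class prototypes are sorted by relevance-weighted distance \(d_\lambda(x, w) = \sum_j \lambda_j (x_j - w_j)^2\), and the prototype of rank \(r\) (0 = closest) receives weight \(\exp(-r/\gamma)\). The normalization by \(\sum_r \exp(-r/\gamma)\) follows the original SRNG formulation and is an assumption about this implementation. The demo also shows the \(\gamma \to 0\) limit, in which all weight collapses onto the closest prototype as in GRLVQ.

```python
import math
from typing import List, Sequence

def relevance_distance(x: Sequence[float], w: Sequence[float], relevances: Sequence[float]) -> float:
    """Squared Euclidean distance with per-feature relevances lambda_j."""
    return sum(lam * (xi - wi) ** 2 for xi, wi, lam in zip(x, w, relevances))

def neural_gas_weights(x, same_class_prototypes, relevances, gamma: float) -> List[float]:
    """Normalized exp(-rank / gamma) weight for each same-class prototype."""
    dists = [relevance_distance(x, w, relevances) for w in same_class_prototypes]
    order = sorted(range(len(dists)), key=dists.__getitem__)
    ranks = [0] * len(dists)
    for rank, idx in enumerate(order):
        ranks[idx] = rank            # rank 0 = closest prototype
    raw = [math.exp(-r / gamma) for r in ranks]
    total = sum(raw)
    return [v / total for v in raw]

x = [1.0, 0.0]
protos = [[3.0, 0.0], [1.0, 5.0], [2.0, 0.0]]
relevances = [0.9, 0.1]  # feature 0 matters far more than feature 1
print(neural_gas_weights(x, protos, relevances, gamma=1.0))   # graded weights
print(neural_gas_weights(x, protos, relevances, gamma=0.01))  # nearly one-hot
```

The relevances change the ranking. In plain Euclidean terms, prototype [1.0, 5.0] is farther from x than [3.0, 0.0]. It ranks closer here because feature 1 carries little relevance.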
Parameters:
beta (float) – Transfer function steepness parameter for sigmoid shaping.
gamma_init (float, optional) – Initial neighborhood range for NG cooperation.
Default: half the maximum number of prototypes per class.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma.
Default: computed from max_iter so gamma reaches gamma_final.
n_prototypes_per_class (int) – Number of prototypes per class.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – synchronize the slow and fast weights every sync_period steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor for moving the slow weights toward the fast weights
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward and backward passes run in lower
precision, which can give up to ~2x speedup and roughly half the memory use on GPU. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
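The default gamma schedule can be reconstructed from the parameter descriptions. If \(\gamma\) is multiplied by gamma_decay once per step and must fall from gamma_init to gamma_final over max_iter steps, then gamma_decay = (gamma_final / gamma_init) ** (1 / max_iter). That formula is an assumption about how the default is computed; the library may count steps differently, for example with mini-batches. The default gamma_init of half the maximum number of prototypes per class comes from the docstring.

```python
from typing import List, Optional, Sequence

def default_gamma_init(prototypes_per_class: Sequence[int]) -> float:
    """Half of the largest number of prototypes any class has."""
    return max(prototypes_per_class) / 2.0

def gamma_schedule(max_iter: int, gamma_final: float = 0.01,
                   gamma_init: Optional[float] = None,
                   gamma_decay: Optional[float] = None,
                   prototypes_per_class: Sequence[int] = (3,)) -> List[float]:
    """Gamma values at steps 0..max_iter under multiplicative decay."""
    if gamma_init is None:
        gamma_init = default_gamma_init(prototypes_per_class)
    if gamma_decay is None:
        # Chosen so that gamma_init * gamma_decay ** max_iter == gamma_final.
        gamma_decay = (gamma_final / gamma_init) ** (1.0 / max_iter)
    return [gamma_init * gamma_decay ** t for t in range(max_iter + 1)]

gammas = gamma_schedule(max_iter=100, prototypes_per_class=[3, 4, 2])
print(gammas[0], gammas[-1])
```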
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses current prototypes (and other fitted params) as starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
- Neural Gas cooperation: all same-class prototypes participate in the loss, weighted by rank via \(\exp(-\text{rank} / \gamma)\).
- Global \(\Omega\) projection, \(d(x, w) = \|\Omega(x - w)\|^2\), which learns feature correlations and a discriminative subspace.
The neighborhood range \(\gamma\) decays during training from
\(\gamma_{\text{init}}\) to \(\gamma_{\text{final}}\).
When \(\gamma \to 0\), SMNG recovers standard GMLVQ.
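The global \(\Omega\) distance is a squared Euclidean distance taken after a learned linear map. With \(\Omega\) of shape (latent_dim, n_features), it equals the quadratic form \((x - w)^\top \Lambda (x - w)\) with \(\Lambda = \Omega^\top \Omega\). \(\Lambda\) is therefore positive semi-definite with rank at most latent_dim, which is how a small latent_dim restricts the metric to a subspace. The pure-Python sketch below writes out the matrix helpers to stay dependency-free; its helper names are illustrative.

```python
from typing import List, Sequence

Matrix = List[List[float]]

def matvec(M: Matrix, v: Sequence[float]) -> List[float]:
    return [sum(m * vi for m, vi in zip(row, v)) for row in M]

def omega_distance(x: Sequence[float], w: Sequence[float], omega: Matrix) -> float:
    """d(x, w) = ||Omega (x - w)||^2."""
    diff = [xi - wi for xi, wi in zip(x, w)]
    return sum(z * z for z in matvec(omega, diff))

def lambda_matrix(omega: Matrix) -> Matrix:
    """Lambda = Omega^T Omega, shape (n_features, n_features)."""
    n = len(omega[0])
    return [[sum(row[i] * row[j] for row in omega) for j in range(n)] for i in range(n)]

def quadratic_form(diff: Sequence[float], L: Matrix) -> float:
    n = len(diff)
    return sum(diff[i] * L[i][j] * diff[j] for i in range(n) for j in range(n))

omega = [[1.0, 1.0, 0.0]]  # latent_dim = 1: projects 3 features onto x0 + x1
x, w = [2.0, -1.0, 7.0], [0.0, 0.0, 0.0]
d = omega_distance(x, w, omega)
d_lambda = quadratic_form([xi - wi for xi, wi in zip(x, w)], lambda_matrix(omega))
print(d, d_lambda)  # both 1.0; feature 2 is ignored entirely
```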
Parameters:
latent_dim (int, optional) – Dimensionality of the \(\Omega\) projection space. If None, uses input dim.
beta (float) – Transfer function steepness parameter for sigmoid shaping.
gamma_init (float, optional) – Initial neighborhood range for NG cooperation.
Default: half the maximum number of prototypes per class.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma.
Default: computed from max_iter so gamma reaches gamma_final.
n_prototypes_per_class (int) – Number of prototypes per class.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – synchronize the slow and fast weights every sync_period steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor for moving the slow weights toward the fast weights
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward and backward passes run in lower
precision, which can give up to ~2x speedup and roughly half the memory use on GPU. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
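For class_weight='balanced' in the list above, a common convention is \(w_c = n_{\text{samples}} / (n_{\text{classes}} \cdot n_c)\). This is the rule scikit-learn uses. Rare classes are up-weighted, and on a perfectly balanced dataset every weight equals 1. Whether this library uses exactly this normalization is an assumption; the docstring states only that weights are inversely proportional to class frequencies.

```python
from collections import Counter
from typing import Dict, Hashable, List, Sequence

def balanced_class_weights(y: Sequence[Hashable]) -> Dict[Hashable, float]:
    """w_c = n_samples / (n_classes * count_c), the scikit-learn 'balanced' rule."""
    counts = Counter(y)
    n_samples, n_classes = len(y), len(counts)
    return {c: n_samples / (n_classes * n) for c, n in counts.items()}

def per_sample_weights(y: Sequence[Hashable], class_weight: Dict[Hashable, float]) -> List[float]:
    """Expand class weights into one weight per training sample."""
    return [class_weight[label] for label in y]

y = [0] * 6 + [1] * 3 + [2] * 1
weights = balanced_class_weights(y)
print(weights)
```

A convenient property of this normalization is that the per-sample weights sum to n_samples, so the overall loss scale is unchanged.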
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses current prototypes (and other fitted params) as starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
- Neural Gas cooperation: all same-class prototypes participate in the loss, weighted by rank via \(\exp(-\text{rank} / \gamma)\).
- Per-prototype \(\Omega_k\), \(d(x, w_k) = \|\Omega_k(x - w_k)\|^2\), which learns local metrics adapted to each prototype’s region.
The neighborhood range \(\gamma\) decays during training from
\(\gamma_{\text{init}}\) to \(\gamma_{\text{final}}\).
When \(\gamma \to 0\), SLNG recovers standard LGMLVQ.
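With per-prototype matrices, each prototype measures distance in its own metric, \(d(x, w_k) = \|\Omega_k(x - w_k)\|^2\), and prediction still picks the closest prototype. The toy example below uses hypothetical values and helper names. It shows that local metrics can change which prototype wins compared with plain Euclidean distance.

```python
from typing import List, Sequence

Matrix = List[List[float]]

def local_omega_distance(x: Sequence[float], w: Sequence[float], omega_k: Matrix) -> float:
    """d(x, w_k) = ||Omega_k (x - w_k)||^2 using the prototype's own matrix."""
    diff = [xi - wi for xi, wi in zip(x, w)]
    return sum(sum(o * di for o, di in zip(row, diff)) ** 2 for row in omega_k)

def predict_local(X, prototypes, omegas: List[Matrix], labels: List[int]) -> List[int]:
    """Nearest-prototype labels where prototype k uses omegas[k]."""
    preds = []
    for x in X:
        dists = [local_omega_distance(x, w, om) for w, om in zip(prototypes, omegas)]
        preds.append(labels[min(range(len(dists)), key=dists.__getitem__)])
    return preds

prototypes = [[0.0, 0.0], [3.0, 0.0]]
labels = [0, 1]
identity = [[1.0, 0.0], [0.0, 1.0]]
# Prototype 1 ignores feature 1 in its region (e.g. that feature is noise there).
omegas = [identity, [[1.0, 0.0], [0.0, 0.0]]]
x = [1.0, 4.0]
print(predict_local([x], prototypes, [identity, identity], labels),  # [0]
      predict_local([x], prototypes, omegas, labels))                # [1]
```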
Parameters:
latent_dim (int, optional) – Latent space dimensionality per prototype. If None, uses input dim.
beta (float) – Transfer function steepness parameter for sigmoid shaping.
gamma_init (float, optional) – Initial neighborhood range for NG cooperation.
Default: half the maximum number of prototypes per class.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma.
Default: computed from max_iter so gamma reaches gamma_final.
n_prototypes_per_class (int) – Number of prototypes per class.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – synchronize the slow and fast weights every sync_period steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor for moving the slow weights toward the fast weights
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; forward/backward pass runs in lower
precision for ~2x speed and ~half memory on GPU. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
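As a rough illustration of what ema_decay does, the sketch below keeps a shadow copy of a parameter vector and applies the usual EMA rule ema ← decay · ema + (1 − decay) · params after every step. Whether the library updates the EMA per step or per epoch, or applies bias correction, is not stated here. The helper name `ema_update` and the noisy toy trajectory are hypothetical, and the smaller decay of 0.99 only keeps the toy run short (0.999 or 0.9999 needs proportionally more steps).

```python
import numpy as np

def ema_update(ema_params, params, decay):
    """One EMA step: ema <- decay * ema + (1 - decay) * params (hypothetical helper)."""
    return decay * ema_params + (1.0 - decay) * params

# Toy trajectory: a 2-D parameter vector that jitters around [1, -1].
rng = np.random.default_rng(0)
target = np.array([1.0, -1.0])
params = np.zeros(2)
ema = params.copy()
decay = 0.99
for step in range(2000):
    params = target + 0.1 * rng.standard_normal(2)  # noisy "optimizer" iterate
    ema = ema_update(ema, params, decay)

# The smoothed copy sits much closer to the center of the jitter than the last iterate.
print(np.abs(ema - target).max(), np.abs(params - target).max())
```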
Predict using per-prototype \(\Omega_k\) distances.
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of
initial_prototypes differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses current prototypes (and other fitted params) as starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
measures distance in the orthogonal complement of each prototype’s
learned invariance subspace
The neighborhood range \(\gamma\) decays during training from
\(\gamma_{\text{init}}\) to \(\gamma_{\text{final}}\).
When \(\gamma \to 0\), STNG recovers standard GTLVQ.
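The parameter list says the default gamma_decay is "computed from max_iter so gamma reaches gamma_final". One natural reading for a per-step multiplicative factor is \(d = (\gamma_{\text{final}} / \gamma_{\text{init}})^{1/\text{max\_iter}}\), so that \(\gamma_{\text{init}} d^{\text{max\_iter}} = \gamma_{\text{final}}\). The sketch assumes exactly this geometric schedule and the documented default \(\gamma_{\text{init}}\) of half the largest per-class prototype count; the function name `default_gamma_schedule` is hypothetical.

```python
import math

def default_gamma_schedule(n_protos_per_class, max_iter, gamma_final=0.01, gamma_init=None):
    """Hypothetical helper: geometric gamma schedule from gamma_init down to gamma_final."""
    if gamma_init is None:
        # Documented default: max prototypes per class / 2.
        gamma_init = max(n_protos_per_class) / 2
    # Per-step factor chosen so that gamma_init * decay**max_iter == gamma_final.
    decay = (gamma_final / gamma_init) ** (1.0 / max_iter)
    gammas = [gamma_init * decay**t for t in range(max_iter + 1)]
    return decay, gammas

# Three classes with 4, 4 and 2 prototypes, so gamma_init = 2.0.
decay, gammas = default_gamma_schedule([4, 4, 2], max_iter=100)
print(round(decay, 4), gammas[0], round(gammas[-1], 6))
```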
Parameters:
subspace_dim (int) – Dimension of each prototype’s tangent subspace.
beta (float) – Transfer function steepness parameter for sigmoid shaping.
gamma_init (float, optional) – Initial neighborhood range for NG cooperation.
Default: max prototypes per class / 2.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma.
Default: computed from max_iter so gamma reaches gamma_final.
n_prototypes_per_class (int) – Number of prototypes per class.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – number of steps between synchronizations of the slow weights
- ‘slow_step_size’: float (default 0.5) – fraction by which the slow weights move toward the fast weights at each sync
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; forward/backward pass runs in lower
precision for ~2x speed and ~half memory on GPU. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
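For class_weight='balanced' the docs only say the weights are "inversely proportional to class frequencies". The sketch below assumes the scikit-learn normalization \(w_c = n_{\text{samples}} / (n_{\text{classes}} \cdot n_c)\), under which every class contributes the same total weight; the exact constant the library uses may differ. Per-sample weights are then looked up by label.

```python
from collections import Counter

def balanced_class_weights(y):
    """Assumed scikit-learn-style 'balanced' weights: n_samples / (n_classes * count_c)."""
    counts = Counter(y)
    n_samples, n_classes = len(y), len(counts)
    return {c: n_samples / (n_classes * n) for c, n in counts.items()}

y = [0, 0, 0, 0, 0, 0, 1, 1, 2, 2]
weights = balanced_class_weights(y)
per_sample = [weights[label] for label in y]  # factor a weighted loss would apply per sample
print(weights)
```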
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of
initial_prototypes differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses current prototypes (and other fitted params) as starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
Cross-Entropy LVQ with Neural Gas neighborhood cooperation.
For each class, prototypes are ranked by distance and weighted
by \(\exp(-\text{rank} / \gamma)\). The NG-weighted class distances become
logits for cross-entropy loss over all classes simultaneously.
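A small NumPy sketch of this pooling: within each class the prototype distances are sorted, the prototype at rank k (k = 0 for the closest) gets weight \(\exp(-k/\gamma)\), and the weighted distances are summed into one class distance. Two details are assumptions not stated in the docs: the weights are normalized to sum to 1 within each class, and the logits are the negated pooled distances. As \(\gamma \to 0\) only rank 0 survives, so the pooled distance reduces to the per-class minimum.

```python
import numpy as np

def ng_class_distances(dists, proto_labels, n_classes, gamma):
    """dists: (n_samples, n_protos); returns NG-pooled distances of shape (n_samples, n_classes)."""
    out = np.empty((dists.shape[0], n_classes))
    for c in range(n_classes):
        d_c = np.sort(dists[:, proto_labels == c], axis=1)  # rank 0 = closest prototype
        ranks = np.arange(d_c.shape[1])
        w = np.exp(-ranks / gamma)
        w = w / w.sum()  # assumption: NG weights normalized within each class
        out[:, c] = d_c @ w
    return out

def ng_cross_entropy(dists, proto_labels, y, n_classes, gamma):
    """Mean cross-entropy with logits = -NG-pooled class distances (assumed sign convention)."""
    logits = -ng_class_distances(dists, proto_labels, n_classes, gamma)
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_p = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -log_p[np.arange(len(y)), y].mean()

# Two classes with two prototypes each; one sample labeled class 0.
dists = np.array([[0.5, 2.0, 3.0, 1.0]])
proto_labels = np.array([0, 0, 1, 1])
print(ng_class_distances(dists, proto_labels, 2, gamma=1e-3))  # close to the per-class minima
print(ng_cross_entropy(dists, proto_labels, np.array([0]), 2, gamma=1.0))
```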
Parameters:
gamma_init (float, optional) – Initial neighborhood range for NG cooperation.
Default: max prototypes per class / 2.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma.
Default: computed from max_iter so gamma reaches gamma_final.
n_prototypes_per_class (int) – Number of prototypes per class.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – number of steps between synchronizations of the slow weights
- ‘slow_step_size’: float (default 0.5) – fraction by which the slow weights move toward the fast weights at each sync
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; forward/backward pass runs in lower
precision for ~2x speed and ~half memory on GPU. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
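The lookahead wrapper follows the Lookahead optimizer idea. A "fast" optimizer takes sync_period ordinary steps. The "slow" weights then move a fraction slow_step_size toward the fast weights, and the fast weights restart from the slow ones. The sketch wraps plain gradient descent on a 1-D quadratic to show the update rule only. It is not the library's implementation, though optax does provide an `optax.lookahead` wrapper that takes arguments with the same two names.

```python
def lookahead_sgd(grad_fn, x0, lr=0.1, sync_period=6, slow_step_size=0.5, n_steps=120):
    """Lookahead around plain gradient descent (illustrative sketch, not the library code)."""
    slow = fast = x0
    for step in range(1, n_steps + 1):
        fast = fast - lr * grad_fn(fast)  # inner "fast" optimizer step
        if step % sync_period == 0:
            slow = slow + slow_step_size * (fast - slow)  # pull slow weights toward fast
            fast = slow  # restart the fast weights from the slow weights
    return slow

# Minimize f(x) = (x - 3)^2, whose gradient is 2 * (x - 3).
x_final = lookahead_sgd(lambda x: 2.0 * (x - 3.0), x0=0.0)
print(x_final)
```

With slow_step_size=1.0 the slow weights simply copy the fast ones at every sync, so the wrapper reduces to the inner optimizer.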
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of
initial_prototypes differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses current prototypes (and other fitted params) as starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
Matrix Cross-Entropy LVQ with Neural Gas neighborhood cooperation.
Combines three key ideas:
- Cross-entropy loss: softmax over the NG-weighted distances of all classes
- Neural Gas cooperation: all same-class prototypes participate, weighted by rank via \(\exp(-\text{rank} / \gamma)\)
- Global \(\Omega\) projection \(d(x, w) = \|\Omega(x - w)\|^2\), which learns feature correlations and a discriminative subspace
The neighborhood range \(\gamma\) decays during training from
\(\gamma_{\text{init}}\) to \(\gamma_{\text{final}}\).
When \(\gamma \to 0\), MCELVQ-NG recovers a matrix CELVQ.
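The global \(\Omega\) distance can be checked numerically. With a rectangular \(\Omega\) of shape (latent_dim, n_features), \(\|\Omega(x - w)\|^2\) equals the quadratic form \((x - w)^\top \Omega^\top \Omega (x - w)\), so \(\Omega^\top \Omega\) acts as a learned positive semi-definite relevance matrix. The random \(\Omega\), data and prototypes below are placeholders, not fitted values.

```python
import numpy as np

def omega_distances(X, W, omega):
    """Squared distances ||omega @ (x - w)||^2 for all sample/prototype pairs."""
    diff = X[:, None, :] - W[None, :, :]  # (n_samples, n_protos, n_features)
    proj = diff @ omega.T                 # (n_samples, n_protos, latent_dim)
    return np.sum(proj**2, axis=-1)       # (n_samples, n_protos)

rng = np.random.default_rng(0)
n_features, latent_dim = 4, 2
X = rng.standard_normal((5, n_features))
W = rng.standard_normal((3, n_features))
omega = rng.standard_normal((latent_dim, n_features))  # placeholder, not a trained Omega

D = omega_distances(X, W, omega)
# Same value through the relevance matrix Lambda = Omega^T Omega for one pair.
Lam = omega.T @ omega
d01 = (X[0] - W[1]) @ Lam @ (X[0] - W[1])
print(D.shape, np.isclose(D[0, 1], d01))
```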
Parameters:
latent_dim (int, optional) – Dimensionality of the \(\Omega\) projection space. If None, uses input dim.
gamma_init (float, optional) – Initial neighborhood range for NG cooperation.
Default: max prototypes per class / 2.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma.
Default: computed from max_iter so gamma reaches gamma_final.
n_prototypes_per_class (int) – Number of prototypes per class.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – number of steps between synchronizations of the slow weights
- ‘slow_step_size’: float (default 0.5) – fraction by which the slow weights move toward the fast weights at each sync
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; forward/backward pass runs in lower
precision for ~2x speed and ~half memory on GPU. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
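Gradient accumulation relies on the fact that, for a loss defined as a mean over samples, averaging the gradients of k equally sized micro-batches gives the gradient of their union. That is why the effective batch size is batch_size * gradient_accumulation_steps. The sketch checks this for a least-squares loss on a linear model, which is a hypothetical stand-in for the LVQ loss. Whether the library averages or sums the accumulated gradients is not stated; averaging is assumed.

```python
import numpy as np

def mse_grad(theta, X, y):
    """Gradient of mean((X @ theta - y)**2) with respect to theta."""
    return 2.0 * X.T @ (X @ theta - y) / len(y)

rng = np.random.default_rng(1)
X = rng.standard_normal((32, 3))
y = rng.standard_normal(32)
theta = np.zeros(3)

batch_size, accum_steps = 8, 4  # effective batch = 8 * 4 = 32
grads = [mse_grad(theta, X[i:i + batch_size], y[i:i + batch_size])
         for i in range(0, batch_size * accum_steps, batch_size)]
accumulated = np.mean(grads, axis=0)  # a single optimizer update would use this average
full_batch = mse_grad(theta, X, y)
print(np.allclose(accumulated, full_batch))
```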
Predict calibrated probabilities using \(\Omega\)-transformed distances.
Uses NG-weighted pooling with the learned \(\Omega\) metric, matching
the training objective exactly.
Parameters:
X (array-like of shape (n_samples, n_features))
Returns:
proba
Return type:
array of shape (n_samples, n_classes)
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of
initial_prototypes differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses current prototypes (and other fitted params) as starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
Localized Matrix Cross-Entropy LVQ with Neural Gas cooperation.
Combines three key ideas:
- Cross-entropy loss: softmax over the NG-weighted distances of all classes
- Neural Gas cooperation: all same-class prototypes participate, weighted by rank via \(\exp(-\text{rank} / \gamma)\)
- Per-prototype \(\Omega_k\) distances \(d(x, w_k) = \|\Omega_k(x - w_k)\|^2\), which learn local metrics adapted to each prototype’s region
The neighborhood range \(\gamma\) decays during training from
\(\gamma_{\text{init}}\) to \(\gamma_{\text{final}}\).
When \(\gamma \to 0\), LCELVQ-NG recovers a localized
matrix CELVQ.
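With per-prototype \(\Omega_k\), the only change from the global case is that prototype k projects the difference \(x - w_k\) with its own matrix. The batched `einsum` sketch below uses random placeholder matrices stacked as (n_protos, latent_dim, n_features). When every \(\Omega_k\) is the same matrix, it reduces to the global matrix distance.

```python
import numpy as np

def local_omega_distances(X, W, omegas):
    """d[i, k] = ||omegas[k] @ (X[i] - W[k])||^2 with one Omega per prototype."""
    diff = X[:, None, :] - W[None, :, :]            # (n, K, D)
    proj = np.einsum("kld,nkd->nkl", omegas, diff)  # (n, K, latent_dim)
    return np.sum(proj**2, axis=-1)                 # (n, K)

rng = np.random.default_rng(0)
X = rng.standard_normal((6, 4))
W = rng.standard_normal((3, 4))
omegas = rng.standard_normal((3, 2, 4))  # placeholder local metrics, latent_dim = 2

D = local_omega_distances(X, W, omegas)
# Direct check for one sample/prototype pair.
print(D.shape, np.isclose(D[2, 1], np.sum((omegas[1] @ (X[2] - W[1]))**2)))
```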
Parameters:
latent_dim (int, optional) – Latent space dimensionality per prototype. If None, uses input dim.
gamma_init (float, optional) – Initial neighborhood range for NG cooperation.
Default: max prototypes per class / 2.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma.
Default: computed from max_iter so gamma reaches gamma_final.
n_prototypes_per_class (int) – Number of prototypes per class.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – number of steps between synchronizations of the slow weights
- ‘slow_step_size’: float (default 0.5) – fraction by which the slow weights move toward the fast weights at each sync
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; forward/backward pass runs in lower
precision for ~2x speed and ~half memory on GPU. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
Predict calibrated probabilities using per-prototype \(\Omega_k\) distances.
Uses per-class min pooling with the learned local \(\Omega_k\) metrics,
consistent with the training objective.
Parameters:
X (array-like of shape (n_samples, n_features))
Returns:
proba
Return type:
array of shape (n_samples, n_classes)
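A sketch of per-class min pooling followed by a softmax. It assumes the class logits are the negated minimum distances, a common CELVQ convention that these docs do not spell out. The distance matrix is made up; in the model it would come from the local \(\Omega_k\) distances.

```python
import numpy as np

def min_pool_proba(dists, proto_labels, n_classes):
    """Softmax over negated per-class minimum distances (assumed logit convention)."""
    class_d = np.stack([dists[:, proto_labels == c].min(axis=1) for c in range(n_classes)], axis=1)
    logits = -class_d
    logits -= logits.max(axis=1, keepdims=True)  # stabilize the exponential
    p = np.exp(logits)
    return p / p.sum(axis=1, keepdims=True)

dists = np.array([[0.2, 1.5, 0.9, 3.0],
                  [2.0, 2.5, 0.1, 0.4]])
proto_labels = np.array([0, 0, 1, 1])
proba = min_pool_proba(dists, proto_labels, 2)
print(proba.round(3), proba.argmax(axis=1))
```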
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of
initial_prototypes differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses current prototypes (and other fitted params) as starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
measures distance in the orthogonal complement of each prototype’s
learned invariance subspace
The neighborhood range \(\gamma\) decays during training from
\(\gamma_{\text{init}}\) to \(\gamma_{\text{final}}\).
When \(\gamma \to 0\), TCELVQ-NG recovers a tangent CELVQ.
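Measuring distance "in the orthogonal complement of each prototype’s learned invariance subspace" amounts to projecting \(x - w_k\) with \(I - U_k U_k^\top\), where the columns of \(U_k\) are an orthonormal basis of the subspace_dim-dimensional tangent subspace. In the sketch the basis comes from a QR decomposition of a random matrix; in a trained model it would be learned. Moving x along the subspace leaves the distance unchanged, which is the invariance the model is meant to capture.

```python
import numpy as np

def tangent_distance(x, w, U):
    """||(I - U U^T)(x - w)||^2 for an orthonormal basis U of shape (n_features, subspace_dim)."""
    diff = x - w
    residual = diff - U @ (U.T @ diff)  # remove the component inside the subspace
    return float(residual @ residual)

rng = np.random.default_rng(0)
n_features, subspace_dim = 5, 2
U, _ = np.linalg.qr(rng.standard_normal((n_features, subspace_dim)))  # orthonormal columns
w = rng.standard_normal(n_features)
x = rng.standard_normal(n_features)

shift = U @ np.array([3.0, -1.5])  # a move purely along the tangent subspace
print(np.isclose(tangent_distance(x, w, U), tangent_distance(x + shift, w, U)))
```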
Parameters:
subspace_dim (int) – Dimension of each prototype’s tangent subspace.
gamma_init (float, optional) – Initial neighborhood range for NG cooperation.
Default: max prototypes per class / 2.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma.
Default: computed from max_iter so gamma reaches gamma_final.
n_prototypes_per_class (int) – Number of prototypes per class.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – number of steps between synchronizations of the slow weights
- ‘slow_step_size’: float (default 0.5) – fraction by which the slow weights move toward the fast weights at each sync
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; forward/backward pass runs in lower
precision for ~2x speed and ~half memory on GPU. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
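Why float16 needs loss scaling: the smallest positive float16 value is about 6e-8, so smaller gradients flush to zero when cast to float16. Static loss scaling multiplies the loss, and therefore every gradient, by a fixed constant before the low-precision backward pass, then divides the result back out in float32. The sketch only imitates this with NumPy casts. The scale of 1024 is an arbitrary example, not the library's value; a static scale must also stay small enough that the largest scaled gradient remains below the float16 maximum of 65504.

```python
import numpy as np

true_grads = np.array([1e-8, 3e-4, 10.0], dtype=np.float32)

# Without scaling: casting the tiny gradient to float16 flushes it to zero.
naive = true_grads.astype(np.float16).astype(np.float32)

# With a static loss scale: scale up before the float16 cast, unscale in float32.
scale = np.float32(1024.0)  # example value; scaled grads must stay below 65504
scaled_fp16 = (true_grads * scale).astype(np.float16)
recovered = scaled_fp16.astype(np.float32) / scale

print(naive)
print(recovered)
```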
Predict calibrated probabilities using tangent distances.
Uses per-class min pooling with tangent distance, consistent
with the training objective.
Parameters:
X (array-like of shape (n_samples, n_features))
Returns:
proba
Return type:
array of shape (n_samples, n_classes)
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of
initial_prototypes differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses current prototypes (and other fitted params) as starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – number of steps between synchronizations of the slow weights
- ‘slow_step_size’: float (default 0.5) – fraction by which the slow weights move toward the fast weights at each sync
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; forward/backward pass runs in lower
precision for ~2x speed and ~half memory on GPU. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
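The interaction of patience and restore_best can be pictured as a plain Python loop of the kind use_scan=False allows. The loop stops after patience consecutive epochs that fail to improve the best loss by more than a small tolerance. It then optionally returns the parameters from the best epoch instead of the last one. The tolerance `min_delta`, the precomputed toy losses and the function name are assumptions for illustration only.

```python
def train_with_patience(losses, params_per_epoch, patience=3, min_delta=1e-4, restore_best=True):
    """Hypothetical early-stopping loop over precomputed (loss, params) pairs per epoch."""
    best_loss, best_params, wait = float("inf"), None, 0
    params = None
    for epoch, (loss, params) in enumerate(zip(losses, params_per_epoch)):
        if loss < best_loss - min_delta:  # improvement: remember it and reset the counter
            best_loss, best_params, wait = loss, params, 0
        else:
            wait += 1
            if wait >= patience:  # patience exhausted: stop early
                break
    return (best_params if restore_best else params), epoch, best_loss

losses = [1.0, 0.6, 0.5, 0.52, 0.51, 0.55, 0.4]
params = [f"theta_{i}" for i in range(len(losses))]
print(train_with_patience(losses, params, patience=3))
```

Here training stops at epoch 5 and never sees the better loss at epoch 6. That illustrates the usual trade-off of a small patience value.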
Predict with a reject option for uncertain samples.
Samples with scores in [lower, upper) are rejected.
Parameters:
X (array-like of shape (n_samples, n_features))
upper (float) – Scores >= upper are classified as target. Default: 0.5.
lower (float, optional) – Scores < lower are classified as non-target. Default: same
as upper (no rejection zone).
reject_label (int) – Label for rejected samples. Default: -1.
Returns:
labels
Return type:
array of shape (n_samples,)
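The reject rule can be written down directly. A score of at least upper is classified as target, a score below lower as non-target, and anything in [lower, upper) gets reject_label. The sketch assumes target and non-target are encoded as 1 and 0, which these docs imply but do not state. The scores are made up.

```python
import numpy as np

def predict_with_reject(scores, upper=0.5, lower=None, reject_label=-1):
    """Threshold scores with an optional rejection band [lower, upper)."""
    if lower is None:
        lower = upper  # empty band: plain thresholding at `upper`
    scores = np.asarray(scores)
    labels = np.full(scores.shape, reject_label)
    labels[scores >= upper] = 1  # target (assumed label 1)
    labels[scores < lower] = 0   # non-target (assumed label 0)
    return labels

scores = [0.05, 0.35, 0.5, 0.65, 0.95]
print(predict_with_reject(scores, upper=0.7, lower=0.3))
print(predict_with_reject(scores))  # no rejection zone
```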
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of
initial_prototypes differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses current prototypes (and other fitted params) as starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – synchronize the slow weights every sync_period steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor used to move the slow weights toward the fast weights
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward and backward passes run in lower
precision, which can give up to roughly 2x speed and lower activation memory on GPUs with fast half-precision support. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
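Plain NumPy casts show why float16 needs loss scaling. A small gradient underflows to zero in float16, but it survives if it is multiplied by a constant first and divided back out in float32. This is an illustrative sketch with two caveats. The scale factor 2**16 is a hypothetical choice. Real loss scaling multiplies the loss before differentiation rather than the gradient directly; by linearity, the two produce the same scaled gradients.

```python
import numpy as np

grad = np.float32(1e-8)  # a small float32 gradient component

# Direct cast: 1e-8 is below half of the smallest float16 subnormal (about 6e-8),
# so it rounds to 0
naive = np.float16(grad)

# Static loss scaling: multiply by a fixed constant, cast, then unscale in float32
scale = np.float32(2.0 ** 16)            # hypothetical static scale factor
scaled = np.float16(grad * scale)        # about 6.6e-4, representable in float16
recovered = np.float32(scaled) / scale   # back to roughly 1e-8

print(naive, recovered)
```

A static scale must be small enough that the largest scaled gradients stay below float16's maximum of 65504. bfloat16 keeps float32's exponent range, which is why scaling is usually needed only for float16.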
Compute scores using relevance-weighted distances.
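A relevance-weighted squared distance multiplies each feature's squared difference by a non-negative relevance before summing. The sketch below assumes the standard GRLVQ form \(d(x, w) = \sum_j \lambda_j (x_j - w_j)^2\). The function name is hypothetical, and this section does not specify how the model turns these distances into scores.

```python
import numpy as np


def relevance_sq_distances(X, prototypes, relevances):
    """Squared distances d[i, k] = sum_j lambda_j * (X[i, j] - W[k, j]) ** 2."""
    X = np.asarray(X, dtype=float)
    W = np.asarray(prototypes, dtype=float)
    lam = np.asarray(relevances, dtype=float)
    diff = X[:, None, :] - W[None, :, :]  # (n_samples, n_prototypes, n_features)
    return np.einsum("ikj,j->ik", diff ** 2, lam)


X = np.array([[1.0, 0.0], [0.0, 2.0]])
W = np.array([[0.0, 0.0]])
lam = np.array([0.9, 0.1])  # feature 0 is weighted heavily
print(relevance_sq_distances(X, W, lam))
```

With equal relevances (\(\lambda_j = 1/n_{\text{features}}\)), this reduces to a scaled squared Euclidean distance. Relevances near zero mark features the model effectively ignores.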
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of
initial_prototypes differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – synchronize the slow weights every sync_period steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor used to move the slow weights toward the fast weights
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward and backward passes run in lower
precision, which can give up to roughly 2x speed and lower activation memory on GPUs with fast half-precision support. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
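The Lookahead update rule (Zhang et al., 2019) explains the two lookahead keys. A fast optimizer takes sync_period steps. The slow weights then move slow_step_size of the way toward the fast weights, and the fast weights are reset to the slow ones. This is a minimal pure-Python sketch on a one-dimensional quadratic, with plain gradient descent as the inner optimizer. It is not the library's code.

```python
def lookahead_sgd(x0, grad_fn, lr=0.1, sync_period=6, slow_step_size=0.5, n_steps=60):
    """Lookahead wrapped around plain gradient descent on a scalar parameter."""
    slow = fast = x0
    for step in range(1, n_steps + 1):
        fast = fast - lr * grad_fn(fast)  # inner ("fast") optimizer step
        if step % sync_period == 0:
            # move the slow weights part of the way toward the fast weights
            slow = slow + slow_step_size * (fast - slow)
            fast = slow  # restart the fast weights from the slow weights
    return slow


def quadratic_grad(x):
    """Gradient of f(x) = (x - 3) ** 2."""
    return 2.0 * (x - 3.0)


print(lookahead_sgd(0.0, quadratic_grad))  # approaches the minimum at 3.0
```

The dict keys match the sync_period and slow_step_size arguments of optax.lookahead. Whether this library wraps that function is an assumption.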
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of
initial_prototypes differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – synchronize the slow weights every sync_period steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor used to move the slow weights toward the fast weights
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward and backward passes run in lower
precision, which can give up to roughly 2x speed and lower activation memory on GPUs with fast half-precision support. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
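The EMA update is assumed to be the usual per-parameter recurrence \(\text{ema} \leftarrow \text{decay} \cdot \text{ema} + (1 - \text{decay}) \cdot \text{param}\), applied after each optimizer step. The dict-of-floats sketch below shows only the recurrence. The initial EMA value and the update frequency are assumptions, not documented behavior.

```python
def ema_update(ema, params, decay):
    """One EMA step: ema <- decay * ema + (1 - decay) * params, per parameter."""
    return {name: decay * ema[name] + (1.0 - decay) * params[name] for name in ema}


params = {"w": 2.0}
ema = {"w": 0.0}  # hypothetical starting value for the average
for _ in range(50):
    ema = ema_update(ema, params, decay=0.9)

print(ema["w"])  # equals 2 * (1 - 0.9 ** 50), i.e. close to 2.0
```

The average spans roughly 1/(1 − decay) updates. A decay of 0.999 therefore averages over about 1,000 steps, and 0.9999 over about 10,000. If the EMA starts from the initial parameters and training runs far fewer steps than that, the smoothed parameters can remain close to their initialization.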
Compute scores using per-prototype Omega-projected distances.
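With a separate projection matrix per prototype, the distance to prototype k is assumed to take the localized GMLVQ form \(d_k(x) = \lVert \Omega_k (x - w_k) \rVert^2\). The NumPy sketch stores the matrices as an array of shape (n_prototypes, latent_dim, n_features). That layout and the function name are hypothetical.

```python
import numpy as np


def local_omega_sq_distances(X, prototypes, omegas):
    """d[i, k] = || Omega_k (X[i] - W[k]) ||^2, one Omega matrix per prototype.

    omegas has the (assumed) shape (n_prototypes, latent_dim, n_features).
    """
    X = np.asarray(X, dtype=float)
    W = np.asarray(prototypes, dtype=float)
    diff = X[:, None, :] - W[None, :, :]            # (n_samples, K, n_features)
    proj = np.einsum("kld,nkd->nkl", omegas, diff)  # (n_samples, K, latent_dim)
    return np.sum(proj ** 2, axis=-1)


W = np.zeros((2, 2))
omegas = np.array([[[1.0, 0.0]],   # prototype 0 only sees feature 0
                   [[0.0, 1.0]]])  # prototype 1 only sees feature 1
X = np.array([[3.0, 4.0]])
print(local_omega_sq_distances(X, W, omegas))
```

Each \(\Omega_k\) is learned independently, so different prototypes can weight features differently. This is what distinguishes the local variant from a single global \(\Omega\).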
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of
initial_prototypes differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
One-Class GTLVQ with per-prototype tangent subspaces.
Each prototype learns an orthonormal basis \(\Omega_k\) that defines
directions of local invariance. Only the distance orthogonal
to this subspace is used for classification.
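The tangent distance removes the component of \(x - w_k\) that lies in \(\operatorname{span}(\Omega_k)\) and measures what is left: \(d_k(x) = \lVert (I - \Omega_k \Omega_k^\top)(x - w_k) \rVert^2\). This assumes \(\Omega_k\) is stored as an n_features × subspace_dim matrix with orthonormal columns. The NumPy sketch below checks the invariance described above: shifting x along a basis direction leaves the distance unchanged. The function name is hypothetical.

```python
import numpy as np


def tangent_sq_distance(x, w, omega):
    """Squared distance from x to the affine subspace w + span(omega).

    omega: (n_features, subspace_dim) matrix with orthonormal columns.
    """
    diff = np.asarray(x, dtype=float) - np.asarray(w, dtype=float)
    residual = diff - omega @ (omega.T @ diff)  # drop the in-subspace component
    return float(residual @ residual)


rng = np.random.default_rng(0)
omega, _ = np.linalg.qr(rng.normal(size=(3, 2)))  # orthonormal basis, subspace_dim=2
w = np.zeros(3)
x = np.array([1.0, 2.0, 3.0])

d = tangent_sq_distance(x, w, omega)
d_shifted = tangent_sq_distance(x + 5.0 * omega[:, 0], w, omega)
print(d, d_shifted)  # equal up to rounding: the shift lies in the invariant subspace
```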
Parameters:
subspace_dim (int) – Dimensionality of each tangent subspace. Default: 2.
n_prototypes (int) – Number of prototypes for the target class.
target_label (int, optional) – Target (normal) class label. Default: auto-detect.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – synchronize the slow weights every sync_period steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor used to move the slow weights toward the fast weights
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward and backward passes run in lower
precision, which can give up to roughly 2x speed and lower activation memory on GPUs with fast half-precision support. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
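One common definition of 'balanced' class weights, used by scikit-learn, is n_samples / (n_classes · count(class)). The sketch below computes weights under the assumption that this library follows the same convention, which this section does not confirm. The helper name is hypothetical.

```python
from collections import Counter


def balanced_class_weights(y):
    """Hypothetical helper: weight = n_samples / (n_classes * count(class))."""
    counts = Counter(y)
    n_samples, n_classes = len(y), len(counts)
    return {label: n_samples / (n_classes * count) for label, count in counts.items()}


y = [0, 0, 0, 0, 0, 0, 1, 1, 2, 2]
weights = balanced_class_weights(y)
print(weights)  # the majority class 0 gets the smallest weight
```

Under this formula every class contributes the same total weight, n_samples / n_classes, so the per-sample weights still sum to n_samples.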
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of
initial_prototypes differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – synchronize the slow weights every sync_period steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor used to move the slow weights toward the fast weights
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward and backward passes run in lower
precision, which can give up to roughly 2x speed and lower activation memory on GPUs with fast half-precision support. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
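The effective-batch-size formula holds for a mean-reduced loss with equal-sized micro-batches. In that case, averaging the gradients of gradient_accumulation_steps micro-batches reproduces the gradient of one batch of size batch_size × gradient_accumulation_steps. The NumPy sketch checks this for a least-squares loss. This section does not say whether the library averages or sums the accumulated gradients.

```python
import numpy as np


def mse_grad(w, X, y):
    """Gradient of mean((X @ w - y) ** 2) with respect to w."""
    return 2.0 * X.T @ (X @ w - y) / len(y)


rng = np.random.default_rng(0)
X = rng.normal(size=(32, 3))
y = rng.normal(size=32)
w = np.zeros(3)

batch_size, accum_steps = 8, 4  # effective batch size 8 * 4 = 32
micro_grads = [
    mse_grad(w, X[i:i + batch_size], y[i:i + batch_size])
    for i in range(0, batch_size * accum_steps, batch_size)
]
accumulated = np.mean(micro_grads, axis=0)  # applied as a single update
full_batch = mse_grad(w, X, y)

print(np.allclose(accumulated, full_batch))  # True
```

With unequal micro-batch sizes, such as a short final batch, the plain average no longer matches exactly.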
Scores near 1.0 indicate the target class; scores near 0.0 indicate an outlier.
The decision boundary is at score = 0.5 (where \(d = \theta\)).
Parameters:
X (array-like of shape (n_samples, n_features))
Returns:
scores
Return type:
array of shape (n_samples,)
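One mapping consistent with the description above is a logistic function of the margin \(\theta - d\). It equals 0.5 exactly when \(d = \theta\), tends to 1 as d falls below \(\theta\), and tends to 0 as d grows. The logistic form and the temperature beta are assumptions for illustration only; this section does not give the model's actual score function.

```python
import math


def logistic_score(d, theta, beta=1.0):
    """Hypothetical score 1 / (1 + exp((d - theta) / beta)), computed stably."""
    z = (d - theta) / beta
    if z >= 0:
        e = math.exp(-z)  # avoids overflow for large distances
        return e / (1.0 + e)
    return 1.0 / (1.0 + math.exp(z))


for d in (0.5, 2.0, 6.0):
    print(d, round(logistic_score(d, theta=2.0), 3))
```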
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of
initial_prototypes differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – synchronize the slow weights every sync_period steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor used to move the slow weights toward the fast weights
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward and backward passes run in lower
precision, which can give up to roughly 2x speed and lower activation memory on GPUs with fast half-precision support. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
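A custom initializer only has to follow the documented signature (X, y, n_per_class, key) -> (protos, labels). For each class, the NumPy sketch below picks the n_per_class samples nearest that class's mean. Because it is deterministic, it ignores key (presumably a JAX PRNG key). The function name is hypothetical, and it is an assumption that the library accepts NumPy arrays as return values.

```python
import numpy as np


def nearest_to_mean_init(X, y, n_per_class, key):
    """Hypothetical initializer: per class, the n_per_class samples closest to the class mean.

    Deterministic, so `key` is accepted for signature compatibility but unused.
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y)
    protos, labels = [], []
    for c in np.unique(y):
        Xc = X[y == c]
        if len(Xc) < n_per_class:
            raise ValueError(f"class {c!r} has fewer than {n_per_class} samples")
        sq_dist = np.sum((Xc - Xc.mean(axis=0)) ** 2, axis=1)
        protos.append(Xc[np.argsort(sq_dist)[:n_per_class]])
        labels.append(np.full(n_per_class, c))
    return np.concatenate(protos), np.concatenate(labels)


X_demo = np.array([[0.0, 0.0], [1.0, 0.0], [10.0, 10.0],
                   [2.0, 0.0], [11.0, 10.0], [12.0, 10.0]])
y_demo = np.array([0, 0, 1, 0, 1, 1])
protos, labels = nearest_to_mean_init(X_demo, y_demo, n_per_class=1, key=None)
print(protos, labels)
```

Assuming the callable is invoked with its arguments in the documented order, passing prototypes_initializer=nearest_to_mean_init would replace the built-in strategy.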
Compute scores using relevance-weighted distances.
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of
initial_prototypes differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – synchronize the slow weights every sync_period steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor used to move the slow weights toward the fast weights
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward and backward passes run in lower
precision, which can give up to roughly 2x speed and lower activation memory on GPUs with fast half-precision support. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
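The patience rule can be sketched as a counter of consecutive epochs whose loss fails to beat the best value so far by more than a small tolerance. In this pure-Python sketch, the tolerance epsilon is a hypothetical value. Treating patience=None as patience=1 only approximates the documented single-step epsilon check. As the use_scan description notes, the loop only stops early when use_scan=False.

```python
def stopping_epoch(losses, patience=None, epsilon=1e-6):
    """Return the 0-based epoch at which training would stop, or None.

    epsilon is a hypothetical minimum improvement; patience=None is
    approximated here as patience=1.
    """
    patience = 1 if patience is None else patience
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(losses):
        if loss < best - epsilon:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


losses = [1.0, 0.8, 0.8, 0.7, 0.7, 0.7, 0.7]
print(stopping_epoch(losses), stopping_epoch(losses, patience=3))
```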
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of
initial_prototypes differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – synchronize the slow weights every sync_period steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor used to move the slow weights toward the fast weights
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward and backward passes run in lower
precision, which can give up to roughly 2x speed and lower activation memory on GPUs with fast half-precision support. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
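Why static loss scaling matters for float16 can be shown without a GPU. The sketch below uses Python's `struct` half-precision format (`'e'`) to round values to float16. The scale factor 2**15 is an illustrative assumption and not necessarily the value this library uses.

```python
import struct

def to_float16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

LOSS_SCALE = 2.0 ** 15  # illustrative static scale factor (assumption)

true_grad = 1e-8  # below float16's smallest subnormal value (about 6e-8)

# Without scaling, the gradient underflows to zero in float16.
unscaled = to_float16(true_grad)

# With static scaling, the loss (and therefore, by linearity, the gradient) is
# multiplied before the low-precision pass and divided back in full precision.
scaled = to_float16(true_grad * LOSS_SCALE)
recovered = scaled / LOSS_SCALE

print(unscaled)  # 0.0
print(abs(recovered - true_grad) / true_grad < 1e-3)  # True
```

A static scale must stay small enough that scaled values remain below float16's maximum of 65504. bfloat16 shares float32's exponent range, which is why loss scaling is generally only needed for float16.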
Compute scores using per-prototype Omega-projected distances.
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted params) as the starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
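The interplay of patience, restore_best and validation_data can be sketched as a plain Python loop, the kind of loop that use_scan=False permits. This is a minimal sketch under assumptions: `train_step` and `val_loss_fn` are hypothetical callables, and `epsilon` is an assumed name for the improvement threshold. With patience=1 it reduces to the single-step epsilon check described for patience=None.

```python
from typing import Callable, List, Tuple

def train_with_early_stopping(
    params: float,
    train_step: Callable[[float], float],   # hypothetical: returns updated params
    val_loss_fn: Callable[[float], float],  # hypothetical: validation loss of params
    max_iter: int = 100,
    patience: int = 3,
    epsilon: float = 1e-6,
    restore_best: bool = True,
) -> Tuple[float, List[float]]:
    best_loss, best_params, bad_epochs = float("inf"), params, 0
    history: List[float] = []
    for _ in range(max_iter):
        params = train_step(params)
        loss = val_loss_fn(params)
        history.append(loss)
        if loss < best_loss - epsilon:
            # Improvement: remember these params and reset the counter.
            best_loss, best_params, bad_epochs = loss, params, 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                break  # `patience` consecutive non-improving epochs
    return (best_params if restore_best else params), history

best, history = train_with_early_stopping(
    0.0, train_step=lambda p: p + 1.0, val_loss_fn=lambda p: (p - 5.0) ** 2
)
print(best, len(history))  # 5.0 8
```

Here the loss bottoms out at p = 5, the next three epochs fail to improve, and restore_best returns the parameters from epoch 5 rather than the last ones.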
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – sync every k steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward and backward passes run in the lower
precision. On GPUs with native low-precision support this can give up to ~2x speedup and
roughly halve activation memory. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
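One common reading of "inversely proportional to class frequencies" is scikit-learn's convention, weight_c = n_samples / (n_classes * count_c). Assuming this library follows the same formula (this page does not confirm it), the ‘balanced’ weights can be computed as follows:

```python
from collections import Counter
from typing import Dict, Hashable, Sequence

def balanced_class_weight(y: Sequence[Hashable]) -> Dict[Hashable, float]:
    """scikit-learn-style 'balanced' weights: n_samples / (n_classes * count)."""
    counts = Counter(y)
    n_samples, n_classes = len(y), len(counts)
    return {label: n_samples / (n_classes * c) for label, c in counts.items()}

weights = balanced_class_weight([0, 0, 0, 0, 1, 1, 2, 2])
print({k: round(v, 3) for k, v in weights.items()})  # {0: 0.667, 1: 1.333, 2: 1.333}
```

With this formula every class contributes the same total weight (n_samples / n_classes), so minority classes are not drowned out by the majority class.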
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted params) as the starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
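Per-sample weights typically enter the loss as a weighted average. This page does not say whether the weighted sum is normalized by the sum of weights or by n_samples; the sketch below assumes normalization by the weight sum.

```python
from typing import Sequence

def weighted_mean_loss(losses: Sequence[float], sample_weight: Sequence[float]) -> float:
    """Weighted average of per-sample losses (normalized by the weight sum)."""
    if len(losses) != len(sample_weight):
        raise ValueError("losses and sample_weight must have the same length")
    total_w = sum(sample_weight)
    if total_w <= 0:
        raise ValueError("sample weights must sum to a positive value")
    return sum(l * w for l, w in zip(losses, sample_weight)) / total_w

print(weighted_mean_loss([1.0, 2.0, 3.0], [1.0, 1.0, 2.0]))  # 2.25
```

Giving a sample weight 2 counts it as if it appeared twice; with all weights equal this reduces to the plain mean loss.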
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – sync every k steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward and backward passes run in the lower
precision. On GPUs with native low-precision support this can give up to ~2x speedup and
roughly halve activation memory. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
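batch_size and gradient_accumulation_steps interact through the effective batch size. The plain-Python sketch below shuffles indices once per epoch, yields mini-batches, and averages gradients over `accum_steps` batches before applying a single update. The helper names, the toy objective mean((w - x)^2), and dropping an incomplete accumulation window at the end of the epoch are all assumptions for illustration.

```python
import random
from typing import Iterator, List, Sequence, Tuple

def minibatches(n_samples: int, batch_size: int, rng: random.Random) -> Iterator[List[int]]:
    """Yield shuffled index batches that cover every sample once per epoch."""
    idx = list(range(n_samples))
    rng.shuffle(idx)
    for start in range(0, n_samples, batch_size):
        yield idx[start:start + batch_size]

def train_epoch(w: float, xs: Sequence[float], batch_size: int, accum_steps: int,
                lr: float, rng: random.Random) -> Tuple[float, int]:
    """One epoch of gradient descent on mean((w - x)^2); returns (w, n_updates)."""
    grad_sum, n_accum, n_updates = 0.0, 0, 0
    for batch in minibatches(len(xs), batch_size, rng):
        grad_sum += sum(2.0 * (w - xs[i]) for i in batch) / len(batch)  # batch-mean gradient
        n_accum += 1
        if n_accum == accum_steps:
            w -= lr * grad_sum / accum_steps  # one update per accum_steps batches
            grad_sum, n_accum, n_updates = 0.0, 0, n_updates + 1
    # Assumption: gradients from an incomplete accumulation window are dropped.
    return w, n_updates

rng = random.Random(0)
xs = [float(i) for i in range(32)]
w, n_updates = train_epoch(0.0, xs, batch_size=4, accum_steps=2, lr=0.1, rng=rng)
print(n_updates)  # 4, i.e. 32 samples / (batch_size 4 * accum_steps 2)
```

Each update therefore sees 4 * 2 = 8 samples, matching the effective batch size formula above, while only one batch of 4 needs to be held at a time.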
Compute target-likeness scores using soft-weighted distances.
Scores near 1.0 indicate the target class; scores near 0.0 indicate an outlier.
Parameters:
X (array-like of shape (n_samples, n_features))
Returns:
scores
Return type:
array of shape (n_samples,)
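The exact scoring formula is not given on this page. One plausible form, stated purely as an assumption: Gaussian-mixture responsibilities r_k ∝ exp(-d_k / (2σ²)) over squared distances d_k give a soft-weighted distance D = Σ_k r_k d_k, and a logistic function around a threshold θ maps D into (0, 1). The threshold `theta` and the logistic `slope` are hypothetical parameters.

```python
import math
from typing import Sequence

def soft_weighted_distance(sq_dists: Sequence[float], sigma: float) -> float:
    """Gaussian-mixture responsibility-weighted average of squared distances."""
    d_min = min(sq_dists)  # shift for numerical stability; cancels after normalizing
    weights = [math.exp(-(d - d_min) / (2.0 * sigma ** 2)) for d in sq_dists]
    total = sum(weights)
    return sum(w * d for w, d in zip(weights, sq_dists)) / total

def target_score(sq_dists: Sequence[float], sigma: float, theta: float, slope: float = 1.0) -> float:
    """Hypothetical score in (0, 1): near 1 if the soft distance is well below theta."""
    z = slope * (theta - soft_weighted_distance(sq_dists, sigma))
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

print(round(target_score([0.1, 4.0, 9.0], sigma=1.0, theta=2.0), 2))  # 0.79
print(target_score([50.0, 60.0, 70.0], sigma=1.0, theta=2.0) < 1e-6)  # True
```

A small σ makes the responsibilities concentrate on the nearest prototype; a large σ averages distances over all prototypes.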
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted params) as the starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – sync every k steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward and backward passes run in the lower
precision. On GPUs with native low-precision support this can give up to ~2x speedup and
roughly halve activation memory. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
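The EMA tracked under ema_decay is usually the standard update ema ← decay · ema + (1 − decay) · params, applied after every optimizer step. The sketch below uses scalar parameters. Initializing the EMA from the initial parameters is an assumption about this library.

```python
from typing import List

def ema_track(param_trajectory: List[float], decay: float) -> float:
    """Exponential moving average over a sequence of parameter values."""
    if not 0.0 < decay < 1.0:
        raise ValueError("decay must satisfy 0 < decay < 1")
    ema = param_trajectory[0]  # assumption: EMA starts from the initial parameters
    for p in param_trajectory[1:]:
        ema = decay * ema + (1.0 - decay) * p
    return ema

print(ema_track([0.0, 1.0, 1.0], decay=0.5))  # 0.75
```

The averaging window is roughly 1 / (1 − decay) steps, so with the typical value 0.999 a run much shorter than about 1000 steps leaves the EMA noticeably biased toward the initial parameters.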
Compute target-likeness scores using soft-weighted Omega distances.
Scores near 1.0 indicate the target class; scores near 0.0 indicate an outlier.
Parameters:
X (array-like of shape (n_samples, n_features))
Returns:
scores
Return type:
array of shape (n_samples,)
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted params) as the starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
Each prototype \(k\) has its own \(\Omega_k\) matrix for local metric
adaptation, combined with probabilistic soft-weighting and
one-class threshold detection.
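In local-matrix LVQ the distance to prototype k is typically d_k(x) = (x − w_k)ᵀ Ω_kᵀ Ω_k (x − w_k) = ‖Ω_k (x − w_k)‖², with Ω_k of shape (latent_dim, n_features). The plain-Python sketch below computes that quantity; it is assumed, not confirmed, to match this class's internal definition.

```python
from typing import List, Sequence

Matrix = List[List[float]]

def local_omega_distance(x: Sequence[float], w: Sequence[float], omega: Matrix) -> float:
    """Squared distance ||omega @ (x - w)||^2 under one prototype's local metric."""
    diff = [xi - wi for xi, wi in zip(x, w)]
    projected = [sum(o_ij * d_j for o_ij, d_j in zip(row, diff)) for row in omega]
    return sum(p * p for p in projected)

def local_distances(x: Sequence[float], prototypes: List[List[float]],
                    omegas: List[Matrix]) -> List[float]:
    """One distance per prototype, each computed with its own Omega_k."""
    return [local_omega_distance(x, w, om) for w, om in zip(prototypes, omegas)]

print(local_distances([1.0, 2.0], [[0.0, 0.0], [1.0, 2.0]],
                      [[[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.0], [0.0, 2.0]]]))  # [5.0, 0.0]
```

With latent_dim smaller than n_features, each Ω_k is rectangular and projects into a lower-dimensional space, e.g. a single row [[1.0, 1.0]] measures only the squared sum of the two coordinate differences.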
Parameters:
sigma (float) – Bandwidth of Gaussian mixture for prototype weighting.
latent_dim (int, optional) – Latent space dimensionality per prototype. Default: n_features.
n_prototypes (int) – Number of prototypes for the target class. Default: 3.
target_label (int, optional) – Target (normal) class label. Default: auto-detect.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – sync every k steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward and backward passes run in the lower
precision. On GPUs with native low-precision support this can give up to ~2x speedup and
roughly halve activation memory. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
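Lookahead keeps "slow" weights alongside the inner optimizer's "fast" weights. Every sync_period steps, slow ← slow + slow_step_size · (fast − slow), and the fast weights restart from the slow ones. The sketch below wraps plain gradient descent on a scalar; the inner optimizer and the toy objective (w − 3)² are assumptions for illustration.

```python
from typing import Callable

def lookahead_sgd(grad: Callable[[float], float], w0: float, lr: float, steps: int,
                  sync_period: int = 6, slow_step_size: float = 0.5) -> float:
    """Lookahead around plain gradient descent; returns the slow weights."""
    slow = fast = w0
    for t in range(1, steps + 1):
        fast -= lr * grad(fast)                     # inner (fast) optimizer step
        if t % sync_period == 0:
            slow += slow_step_size * (fast - slow)  # interpolate slow toward fast
            fast = slow                             # restart fast weights from slow
    return slow

print(round(lookahead_sgd(lambda w: 2.0 * (w - 3.0), w0=0.0, lr=0.1, steps=60), 2))  # 2.97
```

With slow_step_size=1.0 the slow weights simply copy the fast ones at each sync, so the wrapper reduces to the inner optimizer; smaller values trade convergence speed for more stable updates.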
Compute target-likeness scores using local Omega distances.
Parameters:
X (array-like of shape (n_samples, n_features))
Returns:
scores
Return type:
array of shape (n_samples,)
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted params) as the starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – sync every k steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward and backward passes run in the lower
precision. On GPUs with native low-precision support this can give up to ~2x speedup and
roughly halve activation memory. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
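Freezing by zeroing gradients can be sketched with parameters stored as a dict of named groups. The group names (‘backbone’, ‘prototypes’) and the plain SGD update are assumptions for illustration. With decoupled weight decay (as in AdamW-style optimizers), zeroing the gradient alone may not keep a parameter fixed, because the decay term still moves it.

```python
from typing import Dict, Iterable, List

Params = Dict[str, List[float]]

def sgd_step(params: Params, grads: Params, lr: float, freeze: Iterable[str] = ()) -> Params:
    """One SGD step in which frozen groups receive zero gradients."""
    frozen = set(freeze)
    unknown = frozen - set(params)
    if unknown:
        raise KeyError(f"unknown parameter groups: {sorted(unknown)}")
    new_params: Params = {}
    for name, values in params.items():
        g = [0.0] * len(values) if name in frozen else grads[name]  # zero frozen grads
        new_params[name] = [v - lr * gi for v, gi in zip(values, g)]
    return new_params

params = {"backbone": [1.0, 2.0], "prototypes": [0.5]}
grads = {"backbone": [1.0, 1.0], "prototypes": [1.0]}
print(sgd_step(params, grads, lr=0.5, freeze=["backbone"]))
# {'backbone': [1.0, 2.0], 'prototypes': [0.0]}
```

Raising on unknown group names catches typos in freeze_params that would otherwise silently leave everything trainable.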
Compute target-likeness scores using combined Gaussian+NG weights.
Uses final (converged) gamma for NG modulation at inference time.
Parameters:
X (array-like of shape (n_samples, n_features))
Returns:
scores – Scores near 1.0 indicate the target class; scores near 0.0 indicate an outlier.
Return type:
array of shape (n_samples,)
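Neural Gas weights a prototype by its distance rank r_k (0 for the closest): h_k = exp(−r_k / γ). A large γ spreads weight over many prototypes, and a small γ approaches winner-takes-all, which is why the converged γ matters at inference time. This page does not document how the Gaussian and NG weights are combined; the sketch below multiplies and renormalizes them, which is an assumption.

```python
import math
from typing import List, Sequence

def ng_weights(sq_dists: Sequence[float], gamma: float) -> List[float]:
    """Rank-based Neural Gas weights exp(-rank / gamma); rank 0 is the closest."""
    order = sorted(range(len(sq_dists)), key=lambda k: sq_dists[k])
    ranks = [0] * len(sq_dists)
    for r, k in enumerate(order):
        ranks[k] = r
    return [math.exp(-r / gamma) for r in ranks]

def gaussian_weights(sq_dists: Sequence[float], sigma: float) -> List[float]:
    """Unnormalized Gaussian weights, shifted by the minimum distance for stability."""
    d_min = min(sq_dists)
    return [math.exp(-(d - d_min) / (2.0 * sigma ** 2)) for d in sq_dists]

def combined_weights(sq_dists: Sequence[float], sigma: float, gamma: float) -> List[float]:
    """Hypothetical combination: elementwise product, renormalized to sum to 1."""
    raw = [g * h for g, h in zip(gaussian_weights(sq_dists, sigma), ng_weights(sq_dists, gamma))]
    total = sum(raw)
    return [w / total for w in raw]

print([round(w, 4) for w in ng_weights([4.0, 1.0, 9.0], gamma=1.0)])  # [0.3679, 1.0, 0.1353]
```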
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted params) as the starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – sync every k steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward and backward passes run in the lower
precision. On GPUs with native low-precision support this can give up to ~2x speedup and
roughly halve activation memory. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
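The ‘warmup_cosine_decay’ option presumably wraps optax's warmup_cosine_decay_schedule, with lr_scheduler_kwargs supplying its arguments; that mapping is an assumption to check against the library source. The optax schedule ramps linearly from init_value to peak_value over warmup_steps, then follows a cosine decay to end_value, with decay_steps counting the whole schedule including warmup. A dependency-free sketch of the same shape:

```python
import math

def warmup_cosine(step: int, init_value: float, peak_value: float,
                  warmup_steps: int, decay_steps: int, end_value: float = 0.0) -> float:
    """Linear warmup to peak_value, then cosine decay to end_value.

    decay_steps is the total schedule length including warmup (mirroring
    optax's convention); steps beyond it stay at end_value.
    """
    if step < warmup_steps:
        return init_value + (peak_value - init_value) * step / warmup_steps
    cosine_steps = max(decay_steps - warmup_steps, 1)
    t = min(step - warmup_steps, cosine_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * t / cosine_steps))
    return end_value + (peak_value - end_value) * cosine

kwargs = dict(init_value=0.0, peak_value=0.1, warmup_steps=10, decay_steps=110, end_value=0.001)
print([round(warmup_cosine(s, **kwargs), 4) for s in (0, 5, 10, 60, 110, 200)])
# [0.0, 0.05, 0.1, 0.0505, 0.001, 0.001]
```

Halfway through the cosine phase (step 60) the rate sits midway between peak_value and end_value.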
Compute target-likeness scores using combined Gaussian+NG weights.
Uses final (converged) gamma for NG modulation at inference time.
Parameters:
X (array-like of shape (n_samples, n_features))
Returns:
scores – Scores near 1.0 indicate the target class; scores near 0.0 indicate an outlier.
Return type:
array of shape (n_samples,)
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted params) as the starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’
(default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’,
‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’.
Or pass a callable (X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – sync every k steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’.
Master weights stay in float32; the forward and backward passes run in the lower
precision. On GPUs with native low-precision support this can give up to ~2x speedup and
roughly halve activation memory. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
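A custom prototypes_initializer only has to follow the documented signature (X, y, n_per_class, key) -> (protos, labels). The hypothetical initializer below uses the first n_per_class samples of each class as prototypes and ignores `key` because it needs no randomness. Whether the library accepts NumPy arrays as return values (rather than JAX arrays) is an assumption to verify.

```python
import numpy as np

def first_samples_initializer(X, y, n_per_class, key):
    """Hypothetical deterministic initializer: first n_per_class samples per class.

    `key` is accepted to match the documented signature but is unused.
    """
    X = np.asarray(X)
    y = np.asarray(y)
    protos, labels = [], []
    for c in np.unique(y):
        members = X[y == c]
        if len(members) < n_per_class:
            raise ValueError(f"class {c!r} has fewer than {n_per_class} samples")
        protos.append(members[:n_per_class])
        labels.append(np.full(n_per_class, c))
    return np.concatenate(protos), np.concatenate(labels)

protos, labels = first_samples_initializer(
    [[0, 0], [1, 1], [5, 5], [6, 6], [0, 1]], [0, 0, 1, 1, 0], n_per_class=2, key=None
)
print(protos.tolist(), labels.tolist())  # [[0, 0], [1, 1], [5, 5], [6, 6]] [0, 0, 1, 1]
```

The function would then be passed as prototypes_initializer=first_samples_initializer in place of one of the built-in strings.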
Compute target-likeness scores using combined Gaussian+NG weights.
Uses final (converged) gamma for NG modulation at inference time.
Parameters:
X (array-like of shape (n_samples, n_features))
Returns:
scores – Scores near 1.0 indicate the target class; scores near 0.0 indicate an outlier.
Return type:
array of shape (n_samples,)
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)¶
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class produces.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted params) as the starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster,
JIT-compiled, but runs all max_iter iterations even after
convergence). If False, use a Python for-loop with true early
stopping (no wasted compute after convergence, but slower per
iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this
size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings:
‘stratified_random’ (default), ‘class_mean’,
‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’,
‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable
(X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step
(epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default:
False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None
(uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an
update. Effective batch size = batch_size *
gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters
(0 < ema_decay < 1). After training, model parameters are
replaced with EMA-smoothed values. Typical values: 0.999,
0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train
prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – sync every k steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or
‘bfloat16’. Master weights stay in float32, while the forward/backward
pass runs in the lower precision, which on GPU can give up to ~2x speed
and roughly halve activation memory. Float16 uses static loss scaling to
prevent gradient underflow. Default: None (disabled).
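The two lookahead keys correspond to the usual Lookahead scheme: the inner optimizer takes sync_period "fast" steps, then the "slow" weights move slow_step_size of the way toward the fast weights, and the fast weights restart from there. The sketch below assumes plain SGD as the inner optimizer and a toy quadratic; it illustrates the scheme and is not the library's wrapper.

```python
import numpy as np

def lookahead_sgd(grad_fn, x0, lr=0.1, n_steps=60, sync_period=6, slow_step_size=0.5):
    """Minimize with plain SGD as the inner (fast) optimizer wrapped in Lookahead."""
    slow = np.array(x0, dtype=float)
    fast = slow.copy()
    for step in range(1, n_steps + 1):
        fast = fast - lr * grad_fn(fast)  # one inner SGD step on the fast weights
        if step % sync_period == 0:
            # Move the slow weights part of the way toward the fast weights,
            # then restart the fast weights from the new slow weights.
            slow = slow + slow_step_size * (fast - slow)
            fast = slow.copy()
    return slow

# Quadratic f(x) = x**2 with gradient 2x; the minimum is at 0.
x_min = lookahead_sgd(lambda x: 2.0 * x, [5.0])
```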
Predict with a reject option for uncertain samples.
Samples with scores between lower and upper are rejected
(labeled reject_label) instead of being forced into a class.
Parameters:
X (array-like of shape (n_samples, n_features))
upper (float) – Scores >= upper are classified as target. Default: 0.5.
lower (float, optional) – Scores < lower are classified as non-target. Scores in
[lower, upper) are rejected. Default: same as upper
(no rejection zone, equivalent to predict).
reject_label (int) – Label for rejected samples. Default: -1.
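The reject rule above can be sketched directly on an array of scores. The label used for non-target samples is not stated here, so target_label and non_target_label below are hypothetical parameters; only the thresholds and reject_label follow the documented behaviour.

```python
import numpy as np

def reject_from_scores(scores, upper=0.5, lower=None,
                       target_label=1, non_target_label=0, reject_label=-1):
    """Map target-likeness scores to labels with a [lower, upper) rejection zone."""
    scores = np.asarray(scores, dtype=float)
    if lower is None:
        lower = upper  # empty rejection zone: behaves like a plain threshold
    labels = np.full(scores.shape, reject_label, dtype=int)
    labels[scores >= upper] = target_label
    labels[scores < lower] = non_target_label
    return labels

labels = reject_from_scores([0.9, 0.55, 0.4, 0.1], upper=0.6, lower=0.3)
# labels: 1 (target), -1, -1 (rejected), 0 (non-target)
```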
Scores near 1.0 indicate target class, near 0.0 indicate outlier.
Parameters:
X (array-like of shape (n_samples, n_features))
Returns:
scores
Return type:
array of shape (n_samples,)
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted params) as the starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
Extends SVQ-OCC with per-feature relevance weighting (like GRLVQ).
Learns which features are most important for distinguishing
target from non-target data.
Parameters:
n_prototypes (int) – Number of prototypes for the target class.
target_label (int, optional) – Which label is the target (normal) class. Default: auto-detect
as the most frequent class.
alpha (float) – Balance between representation (R) and classification (C) cost.
E = alpha * R + (1 - alpha) * C. Default: 0.5.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster,
JIT-compiled, but runs all max_iter iterations even after
convergence). If False, use a Python for-loop with true early
stopping (no wasted compute after convergence, but slower per
iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this
size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings:
‘stratified_random’ (default), ‘class_mean’,
‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’,
‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable
(X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step
(epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default:
False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None
(uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an
update. Effective batch size = batch_size *
gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters
(0 < ema_decay < 1). After training, model parameters are
replaced with EMA-smoothed values. Typical values: 0.999,
0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train
prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – sync every k steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or
‘bfloat16’. Master weights stay in float32, while the forward/backward
pass runs in the lower precision, which on GPU can give up to ~2x speed
and roughly halve activation memory. Float16 uses static loss scaling to
prevent gradient underflow. Default: None (disabled).
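The "effective batch size = batch_size * gradient_accumulation_steps" statement holds because, for a mean loss and equal-sized micro-batches, averaging the micro-batch gradients gives the same gradient as one large batch. The sketch below checks this on a hypothetical least-squares loss; it is not the library's loss or update code.

```python
import numpy as np

def mse_grad(w, X, y):
    """Gradient of mean((X @ w - y) ** 2) with respect to w."""
    return 2.0 * X.T @ (X @ w - y) / len(y)

def accumulated_grad(w, X, y, batch_size, accumulation_steps):
    """Average micro-batch gradients before a single optimizer update."""
    grads = []
    for i in range(accumulation_steps):
        batch = slice(i * batch_size, (i + 1) * batch_size)
        grads.append(mse_grad(w, X[batch], y[batch]))
    return np.mean(grads, axis=0)

rng = np.random.default_rng(0)
X = rng.normal(size=(32, 3))
y = rng.normal(size=32)
w = np.zeros(3)
# 4 micro-batches of 8 samples behave like one batch of 8 * 4 = 32 samples.
g_accumulated = accumulated_grad(w, X, y, batch_size=8, accumulation_steps=4)
g_full = mse_grad(w, X, y)
```

With unequal micro-batch sizes (e.g. a short final batch) a plain average of gradients no longer matches the full-batch gradient exactly.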
Compute scores using relevance-weighted distances.
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted params) as the starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
Predict with a reject option for uncertain samples.
Samples with scores between lower and upper are rejected
(labeled reject_label) instead of being forced into a class.
Parameters:
X (array-like of shape (n_samples, n_features))
upper (float) – Scores >= upper are classified as target. Default: 0.5.
lower (float, optional) – Scores < lower are classified as non-target. Scores in
[lower, upper) are rejected. Default: same as upper
(no rejection zone, equivalent to predict).
reject_label (int) – Label for rejected samples. Default: -1.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster,
JIT-compiled, but runs all max_iter iterations even after
convergence). If False, use a Python for-loop with true early
stopping (no wasted compute after convergence, but slower per
iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this
size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings:
‘stratified_random’ (default), ‘class_mean’,
‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’,
‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable
(X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step
(epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default:
False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None
(uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an
update. Effective batch size = batch_size *
gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters
(0 < ema_decay < 1). After training, model parameters are
replaced with EMA-smoothed values. Typical values: 0.999,
0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train
prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – sync every k steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or
‘bfloat16’. Master weights stay in float32, while the forward/backward
pass runs in the lower precision, which on GPU can give up to ~2x speed
and roughly halve activation memory. Float16 uses static loss scaling to
prevent gradient underflow. Default: None (disabled).
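The ema_decay option keeps a running average of the parameters alongside training. A minimal sketch of the standard update, applied per parameter group, is below; the group name "prototypes" is hypothetical, and the library may initialise or bias-correct its EMA differently.

```python
import numpy as np

def ema_update(ema_params, params, decay):
    """One EMA step per parameter group: ema <- decay * ema + (1 - decay) * params."""
    return {name: decay * ema_params[name] + (1.0 - decay) * params[name]
            for name in params}

ema = {"prototypes": np.zeros(2)}
for _ in range(3):
    ema = ema_update(ema, {"prototypes": np.ones(2)}, decay=0.5)
# After three steps toward a constant 1.0, each entry is 1 - 0.5**3 = 0.875.
```

With typical values such as 0.999, the average effectively spans on the order of 1 / (1 - decay) = 1000 recent steps, which is why it smooths out late-training noise.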
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted params) as the starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
Predict with a reject option for uncertain samples.
Samples with scores between lower and upper are rejected
(labeled reject_label) instead of being forced into a class.
Parameters:
X (array-like of shape (n_samples, n_features))
upper (float) – Scores >= upper are classified as target. Default: 0.5.
lower (float, optional) – Scores < lower are classified as non-target. Scores in
[lower, upper) are rejected. Default: same as upper
(no rejection zone, equivalent to predict).
reject_label (int) – Label for rejected samples. Default: -1.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster,
JIT-compiled, but runs all max_iter iterations even after
convergence). If False, use a Python for-loop with true early
stopping (no wasted compute after convergence, but slower per
iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this
size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings:
‘stratified_random’ (default), ‘class_mean’,
‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’,
‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable
(X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step
(epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default:
False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None
(uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an
update. Effective batch size = batch_size *
gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters
(0 < ema_decay < 1). After training, model parameters are
replaced with EMA-smoothed values. Typical values: 0.999,
0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train
prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – sync every k steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or
‘bfloat16’. Master weights stay in float32, while the forward/backward
pass runs in the lower precision, which on GPU can give up to ~2x speed
and roughly halve activation memory. Float16 uses static loss scaling to
prevent gradient underflow. Default: None (disabled).
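The patience and restore_best options interact as follows: a counter of consecutive non-improving epochs triggers the stop, and the best epoch seen so far is what restore_best rolls back to. The sketch below runs on a fixed list of losses; min_delta is a hypothetical stand-in for the epsilon check, and patience=None in the docs behaves like patience=1 here.

```python
def early_stopping(losses, patience, min_delta=0.0):
    """Return (stop_epoch, best_epoch) for a sequence of per-epoch losses."""
    best_loss, best_epoch, wait = float("inf"), -1, 0
    for epoch, loss in enumerate(losses):
        if loss < best_loss - min_delta:
            # Improvement: remember this epoch (what restore_best would keep).
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return len(losses) - 1, best_epoch

stop_epoch, best_epoch = early_stopping([1.0, 0.8, 0.7, 0.72, 0.71, 0.75, 0.6], patience=3)
# Training stops at epoch 5; restore_best would roll back to epoch 2 (loss 0.7).
```

Note that the later improvement at epoch 6 is never seen, which is the usual trade-off of a small patience.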
Compute scores using per-prototype Omega-projected distances.
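Assuming the usual local-matrix form, a per-prototype \(\Omega_k\)-projected distance is \(\|\Omega_k (x - w_k)\|^2\). The vectorized sketch below computes it for all sample/prototype pairs; the shapes and the squared form are assumptions, not taken from the library.

```python
import numpy as np

def omega_distances(X, prototypes, omegas):
    """Squared distances ||Omega_k (x - w_k)||^2 for every sample/prototype pair.

    X: (n_samples, n_features); prototypes: (n_protos, n_features);
    omegas: (n_protos, proj_dim, n_features), one projection matrix per prototype.
    """
    diff = X[:, None, :] - prototypes[None, :, :]          # (n, k, d)
    projected = np.einsum("kpd,nkd->nkp", omegas, diff)   # (n, k, p)
    return np.sum(projected ** 2, axis=-1)                 # (n, k)

X = np.array([[1.0, 2.0], [0.0, 0.0]])
W = np.array([[0.0, 0.0]])
d_identity = omega_distances(X, W, np.eye(2)[None])        # plain squared Euclidean
d_first = omega_distances(X, W, np.array([[[1.0, 0.0]]]))  # only feature 0 counts
```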
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted params) as the starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
Predict with a reject option for uncertain samples.
Samples with scores between lower and upper are rejected
(labeled reject_label) instead of being forced into a class.
Parameters:
X (array-like of shape (n_samples, n_features))
upper (float) – Scores >= upper are classified as target. Default: 0.5.
lower (float, optional) – Scores < lower are classified as non-target. Scores in
[lower, upper) are rejected. Default: same as upper
(no rejection zone, equivalent to predict).
reject_label (int) – Label for rejected samples. Default: -1.
Tangent SVQ-OCC with per-prototype tangent subspaces.
Each prototype learns an orthonormal basis \(\Omega_k\) that defines
directions of local invariance. Only the distance orthogonal
to this subspace is used for classification.
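With an orthonormal basis \(\Omega_k\) stored as columns, "the distance orthogonal to this subspace" is \(\|(I - \Omega_k \Omega_k^\top)(x - w_k)\|^2\). The sketch below computes that for one sample and one prototype, and obtains an orthonormal basis via a QR decomposition as one possible choice; how the library parametrizes and re-orthonormalizes \(\Omega_k\) during training is not covered here.

```python
import numpy as np

def tangent_distance(x, prototype, basis):
    """Squared distance from x to prototype after removing the tangent-subspace component.

    basis: (n_features, subspace_dim) with orthonormal columns (Omega_k).
    """
    diff = x - prototype
    # Project the difference onto the subspace and subtract it: (I - Omega Omega^T) diff.
    orthogonal = diff - basis @ (basis.T @ diff)
    return float(orthogonal @ orthogonal)

rng = np.random.default_rng(0)
basis, _ = np.linalg.qr(rng.normal(size=(5, 2)))  # orthonormal columns, shape (5, 2)
w = np.zeros(5)
along = basis @ np.array([3.0, -1.0])             # lies inside the subspace
e = rng.normal(size=5)
perp = e - basis @ (basis.T @ e)                  # orthogonal to the subspace
```

A displacement inside the subspace costs nothing (local invariance), while an orthogonal displacement is charged its full squared length.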
Parameters:
subspace_dim (int) – Dimensionality of each tangent subspace. Default: 2.
n_prototypes (int) – Number of prototypes for the target class.
target_label (int, optional) – Which label is the target (normal) class. Default: auto-detect
as the most frequent class.
alpha (float) – Balance between representation (R) and classification (C) cost.
E = alpha * R + (1 - alpha) * C. Default: 0.5.
callbacks (list, optional) – List of Callback objects.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster,
JIT-compiled, but runs all max_iter iterations even after
convergence). If False, use a Python for-loop with true early
stopping (no wasted compute after convergence, but slower per
iteration).
batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this
size.
lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’,
‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’,
‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’,
‘sgdr’. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler
(e.g. decay_rate, transition_steps). Default: None.
prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings:
‘stratified_random’ (default), ‘class_mean’,
‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’,
‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable
(X,y,n_per_class,key)->(protos,labels).
patience (int, optional) – Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step
(epsilon check). Requires use_scan=False for true early stopping.
restore_best (bool) – If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default:
False.
class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights
inversely proportional to class frequencies. Default: None
(uniform).
gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an
update. Effective batch size = batch_size *
gradient_accumulation_steps. Default: None (no accumulation).
ema_decay (float, optional) – Exponential moving average decay for parameters
(0 < ema_decay < 1). After training, model parameters are
replaced with EMA-smoothed values. Typical values: 0.999,
0.9999. Default: None (no EMA).
freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients).
E.g. [‘backbone’] to freeze the backbone and only train
prototypes. Default: None (all parameters trainable).
lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys:
- ‘sync_period’: int (default 6) – sync every k steps
- ‘slow_step_size’: float (default 0.5) – interpolation factor
Default: None (no lookahead).
mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or
‘bfloat16’. Master weights stay in float32, while the forward/backward
pass runs in the lower precision, which on GPU can give up to ~2x speed
and roughly halve activation memory. Float16 uses static loss scaling to
prevent gradient underflow. Default: None (disabled).
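For class_weight='balanced', one common convention for "inversely proportional to class frequencies" is scikit-learn's n_samples / (n_classes * count_c). The sketch below assumes that convention; the library may normalise its weights differently.

```python
import numpy as np

def balanced_class_weights(y):
    """Weights inversely proportional to class frequency: n_samples / (n_classes * count)."""
    classes, counts = np.unique(np.asarray(y), return_counts=True)
    weights = len(y) / (len(classes) * counts)
    return {cls.item(): float(w) for cls, w in zip(classes, weights)}

weights = balanced_class_weights([0, 0, 0, 1])
# The rare class 1 gets weight 2.0; the common class 0 gets 2/3.
```

Under this convention the weighted class counts are all equal (here 3 * 2/3 = 1 * 2.0 = 2), so each class contributes equally to the loss.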
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted params) as the starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
Predict with a reject option for uncertain samples.
Samples with scores between lower and upper are rejected
(labeled reject_label) instead of being forced into a class.
Parameters:
X (array-like of shape (n_samples, n_features))
upper (float) – Scores >= upper are classified as target. Default: 0.5.
lower (float, optional) – Scores < lower are classified as non-target. Scores in
[lower, upper) are rejected. Default: same as upper
(no rejection zone, equivalent to predict).
reject_label (int) – Label for rejected samples. Default: -1.
This implementation provides:
- Full vectorization (no Python loops)
- JIT compilation for speed
- Automatic GPU acceleration
- Immutable state management
- Numerical stability
JAX-based Possibilistic C-Means clustering with GPU acceleration.
PCM is a clustering algorithm that assigns typicality values to data points,
representing the degree to which they belong to each cluster. Unlike FCM,
the typicality of a point to one cluster is independent of its typicality
to other clusters.
Parameters:
fuzzifier (float, default=2.0) – Fuzzification parameter (\(m > 1\)). Higher values result in fuzzier
clusters.
k (float, default=1.0) – Parameter for \(\gamma\) computation. Typical values are in [0.01, 1.0].
Lower values make the algorithm more sensitive to outliers.
init_method ({'fcm', 'random'}, default='fcm') – Initialization method:
- ‘fcm’: Initialize using FCM results (recommended)
- ‘random’: Random initialization
n_clusters (int) – Number of clusters (must be >= 2).
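The independence of typicalities across clusters is visible in the standard PCM update, \(t_{ij} = 1 / (1 + (d_{ij}^2/\gamma_j)^{1/(m-1)})\) with \(\gamma_j = k \sum_i u_{ij}^m d_{ij}^2 / \sum_i u_{ij}^m\). The sketch below assumes that textbook form, with the fuzzifier as \(m\) and k as the \(\gamma\) scale; the library's exact update may differ.

```python
import numpy as np

def pcm_typicality(X, centroids, U, fuzzifier=2.0, k=1.0):
    """Standard PCM typicality update.

    X: (n, d); centroids: (c, d); U: (n, c) fuzzy memberships used to estimate gamma.
    """
    d2 = np.sum((X[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)  # (n, c)
    Um = U ** fuzzifier
    gamma = k * np.sum(Um * d2, axis=0) / np.sum(Um, axis=0)           # (c,)
    # Each column uses only its own cluster's distances and gamma, so a point's
    # typicality to one cluster does not depend on the other clusters.
    return 1.0 / (1.0 + (d2 / gamma) ** (1.0 / (fuzzifier - 1.0)))

X = np.array([[0.0, 0.0], [1.0, 0.0], [3.0, 0.0]])
C = np.array([[0.0, 0.0], [3.0, 0.0]])
T = pcm_typicality(X, C, np.full((3, 2), 0.5))
```

Unlike FCM memberships, the rows of T need not sum to 1: a point can be highly typical of several clusters, or of none.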
FPCM maintains TWO matrices: \(U\) (fuzzy membership) and \(T\) (typicality).
\(U\) has row-sum-to-1 constraint (standard FCM), while \(T\) has column-sum-to-1
constraint per the original Pal, Pal & Bezdek (1997) formulation.
Algorithm:
Initialize \(U\) and \(T\) (randomly or using FCM)
Update centroids using combined fuzzy and typicality weights
Update \(U\) using FCM rule with fuzzifier \(m\) (row-normalized)
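The two normalization directions can be sketched as one FPCM-style iteration. The sketch assumes \(u_{ij} \propto d_{ij}^{-2/(m-1)}\) normalized over clusters, \(t_{ij} \propto d_{ij}^{-2/(\eta-1)}\) normalized over samples, and centroids weighted by \(u^m + t^\eta\). The name eta for the typicality exponent is an assumption for this class (it mirrors the KFPCM parameter), and eps is a hypothetical guard against zero distances.

```python
import numpy as np

def fpcm_updates(X, centroids, fuzzifier=2.0, eta=2.0, eps=1e-12):
    """One FPCM-style update of U (rows sum to 1), T (columns sum to 1) and centroids."""
    d2 = np.sum((X[:, None, :] - centroids[None, :, :]) ** 2, axis=-1) + eps  # (n, c)
    inv_u = d2 ** (-1.0 / (fuzzifier - 1.0))
    U = inv_u / inv_u.sum(axis=1, keepdims=True)   # normalize over clusters
    inv_t = d2 ** (-1.0 / (eta - 1.0))
    T = inv_t / inv_t.sum(axis=0, keepdims=True)   # normalize over samples
    W = U ** fuzzifier + T ** eta                  # combined centroid weights
    new_centroids = (W.T @ X) / W.sum(axis=0)[:, None]
    return U, T, new_centroids

X = np.array([[0.0, 0.0], [1.0, 0.0], [3.0, 0.0]])
U, T, C_new = fpcm_updates(X, np.array([[0.0, 0.0], [3.0, 0.0]]))
```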
AFCM is an adaptive variant that combines fuzzy and possibilistic approaches
with specific parameter combinations.
Key features:
- Centroids use \(a \cdot U^m + b \cdot T\) (\(T\) to power 1, not \(m\)!)
- \(\gamma\) computed with Euclidean distance (not squared)
- Exponential \(T\) update with parameter \(b\)
- Standard FCM \(U\) update
Algorithm:
Initialize \(U\) using FCM
Compute \(\gamma\) parameters using Euclidean distance
Update \(T\) using exponential update
Update \(U\) using standard FCM rule
Update centroids using combined fuzzy-possibilistic weights
Improved Possibilistic C-Means clustering with JAX.
IPCM uses a two-phase approach to improve clustering performance:
- Phase 0: Initialize \(\gamma\) using fuzzy membership only
- Phase 1: Refine \(\gamma\) using both membership and typicality
Key differences from PCM:
- Uses the product of \(U^{m_f}\) and \(T^{m_p}\) in centroid computation
- Modified \(U\) update that depends on \(T\)
- Two-phase \(\gamma\) computation
Algorithm (Phase 0):
Initialize \(U\) using FCM, \(T = 0\)
Compute \(\gamma\) parameters from fuzzy membership
Improved Possibilistic C-Means 2 clustering with JAX.
IPCM2 is a variant of IPCM with key differences:
- Uses exponential \(T\) update: \(t_{ij} = \exp(-d_{ij}^2 / \gamma_j)\)
- Centroids use \(U^{m_f} \cdot T\) (\(T\) without power!)
- Modified \(U\) update with exponential distance
- Different objective function
Algorithm (Phase 0):
Initialize \(U\) using FCM, \(T = 0\)
Compute \(\gamma\) parameters from fuzzy membership
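The IPCM2 typicality and centroid rules listed above can be sketched as one step. \(\gamma\) is taken as an input here because its phase-0 formula is not given, and m_f is the fuzzy exponent; the modified \(U\) update is not shown.

```python
import numpy as np

def ipcm2_step(X, centroids, U, gamma, m_f=2.0):
    """Exponential typicality and U^{m_f} * T weighted centroids (one sketch step)."""
    d2 = np.sum((X[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)  # (n, c)
    T = np.exp(-d2 / gamma)                 # t_ij = exp(-d_ij^2 / gamma_j)
    W = (U ** m_f) * T                      # T enters with power 1
    new_centroids = (W.T @ X) / W.sum(axis=0)[:, None]
    return T, new_centroids

X = np.array([[0.0, 0.0], [1.0, 0.0], [3.0, 0.0]])
T, C_new = ipcm2_step(X, np.array([[0.0, 0.0], [3.0, 0.0]]),
                      np.full((3, 2), 0.5), gamma=np.array([1.0, 1.0]))
```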
KFCM uses a Gaussian kernel to map data into a high-dimensional feature space
where clustering is performed. This allows handling non-linearly separable data.
KFPCM maintains two matrices (\(U\) and \(T\)) in kernel space. \(U\) has row-sum-to-1
constraint (standard FCM), while \(T\) has column-sum-to-1 constraint per the
original Pal, Pal & Bezdek (1997) FPCM formulation.
Parameters:
fuzzifier (float, default=2.0) – Fuzziness parameter for \(U\) matrix (must be > 1.0).
eta (float, default=2.0) – Fuzziness parameter for \(T\) matrix (must be > 1.0).
Kernel distances are bounded in \([0, 2]\), so thresholds
\(\theta_k\) are initialized in kernel distance scale.
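The \([0, 2]\) bound follows from the Gaussian kernel having \(K(x, x) = 1\): the squared feature-space distance is \(K(x,x) + K(w,w) - 2K(x,w) = 2 - 2K(x,w)\). The sketch assumes the bandwidth convention \(\exp(-\|x-w\|^2 / (2\sigma^2))\), which may differ from the library's, and lets sigma be a scalar or one value per prototype.

```python
import numpy as np

def gaussian_kernel_distance(X, prototypes, sigma):
    """Squared feature-space distance ||phi(x) - phi(w)||^2 = 2 - 2 * K(x, w).

    sigma: scalar or array of shape (n_prototypes,), broadcast over prototypes.
    """
    d2 = np.sum((X[:, None, :] - prototypes[None, :, :]) ** 2, axis=-1)  # (n, k)
    K = np.exp(-d2 / (2.0 * np.asarray(sigma) ** 2))  # Gaussian kernel, K(x, x) = 1
    return 2.0 - 2.0 * K

X = np.array([[0.0, 0.0], [10.0, 0.0]])
D = gaussian_kernel_distance(X, np.array([[0.0, 0.0], [1.0, 0.0]]), np.array([1.0, 2.0]))
```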
Parameters:
sigma_init (str or float) – Initialization strategy for per-prototype bandwidths.
‘median’ (default): per-prototype median distance from prototype
to target class members.
‘mean’: per-prototype mean distance.
float: fixed value for all prototypes.
sigma_min (float) – Lower bound for sigma to prevent bandwidth collapse. Default: 1e-3.
n_prototypes (int) – Number of prototypes for the target class.
target_label (int, optional) – Target (normal) class label. Default: auto-detect.
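The sigma_init strategies and the sigma_min floor can be sketched as follows. Euclidean (not squared) input-space distances between each prototype and the target-class samples are assumed.

```python
import numpy as np

def init_sigmas(X_target, prototypes, strategy="median", sigma_min=1e-3):
    """Per-prototype bandwidths from distances to target-class samples, floored at sigma_min."""
    dist = np.linalg.norm(X_target[:, None, :] - prototypes[None, :, :], axis=-1)  # (n, k)
    if strategy == "median":
        sigmas = np.median(dist, axis=0)
    elif strategy == "mean":
        sigmas = dist.mean(axis=0)
    else:
        sigmas = np.full(len(prototypes), float(strategy))  # fixed value for all
    return np.maximum(sigmas, sigma_min)  # prevent bandwidth collapse

X_t = np.array([[0.0, 0.0], [0.0, 1.0], [0.0, 3.0]])
P = np.array([[0.0, 0.0]])
```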
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted params) as the starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
where \(\lambda = \text{softmax}(\text{relevances})\) are
learned per-feature weights.
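The formula that \(\lambda\) qualifies is not shown here. As a sketch, assume it is a relevance-weighted squared Euclidean distance \(\sum_f \lambda_f (x_f - w_f)^2\), with the softmax keeping the weights positive and summing to 1.

```python
import numpy as np

def relevance_distances(X, prototypes, relevances):
    """Squared distances sum_f lambda_f (x_f - w_f)^2 with lambda = softmax(relevances)."""
    r = np.asarray(relevances, dtype=float)
    lam = np.exp(r - r.max())   # subtract max for numerical stability
    lam /= lam.sum()            # softmax: positive weights summing to 1
    diff = X[:, None, :] - prototypes[None, :, :]
    return np.sum(lam * diff ** 2, axis=-1), lam

X = np.array([[2.0, 0.0]])
W = np.array([[0.0, 0.0]])
d_uniform, lam_uniform = relevance_distances(X, W, [0.0, 0.0])
d_skewed, lam_skewed = relevance_distances(X, W, [0.0, 10.0])  # feature 1 dominates
```

In the skewed case the large difference on feature 0 barely counts, which is how learned relevances suppress uninformative features.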
Parameters:
sigma_init (str or float) – Initialization strategy for per-prototype bandwidths.
‘median’ (default): per-prototype median distance from prototype
to target class members.
‘mean’: per-prototype mean distance.
float: fixed value for all prototypes.
sigma_min (float) – Lower bound for sigma to prevent bandwidth collapse. Default: 1e-3.
n_prototypes (int) – Number of prototypes for the target class.
target_label (int, optional) – Target (normal) class label. Default: auto-detect.
Return the learned relevance weights (normalized via softmax).
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted params) as the starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the params with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted params) as the starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
sigma_init (str or float) – Initialization strategy for per-prototype bandwidths.
‘median’ (default): per-prototype median distance from prototype
to target class members.
‘mean’: per-prototype mean distance.
float: fixed value for all prototypes.
sigma_min (float) – Lower bound for sigma to prevent bandwidth collapse. Default: 1e-3.
Compute target-likeness scores using kernel distance.
Scores near 1.0 indicate the target class; scores near 0.0 indicate outliers.
The decision boundary is at score = 0.5 (where \(d_\kappa = \theta\)).
Parameters:
X (array-like of shape (n_samples, n_features))
Returns:
scores
Return type:
array of shape (n_samples,)
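One way to get scores with the stated properties (close to 1.0 for target samples, close to 0.0 for outliers, exactly 0.5 where \(d_\kappa = \theta\)) is a sigmoid of \(\theta - d_\kappa\). The sketch below assumes that form, a nearest-prototype rule, and the Gaussian kernel distance \(d_\kappa = \sqrt{2 - 2\exp(-\|x - w_k\|^2 / (2\sigma_k^2))}\). The steepness `beta` and both function names are illustrative assumptions, not the library's documented internals.

```python
import numpy as np

def gaussian_kernel_distance(X, W, sigmas):
    """d_kappa(x, w_k) = sqrt(2 - 2 exp(-||x - w_k||^2 / (2 sigma_k^2))), shape (n, K)."""
    sq = np.sum((X[:, None, :] - W[None, :, :]) ** 2, axis=-1)
    k = np.exp(-sq / (2.0 * sigmas[None, :] ** 2))
    return np.sqrt(np.maximum(2.0 - 2.0 * k, 0.0))

def target_scores(X, W, sigmas, theta, beta=10.0):
    """Hypothetical score: sigmoid(beta * (theta - min_k d_kappa)); equals 0.5 at d_kappa == theta."""
    d_min = gaussian_kernel_distance(X, W, sigmas).min(axis=1)
    return 1.0 / (1.0 + np.exp(-beta * (theta - d_min)))
```

Because the Gaussian kernel distance saturates at \(\sqrt{2}\), far-away outliers all receive nearly the same low score instead of ever-decreasing values.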
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, used to warm-start from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point,
with a fresh optimizer. Cannot be combined with initial_prototypes.
where \(\lambda = \text{softmax}(\text{relevances})\).
Parameters:
sigma_init (str or float) – Initialization strategy for per-prototype bandwidths.
‘median’ (default): per-prototype median distance from prototype
to target class members.
‘mean’: per-prototype mean distance.
float: fixed value for all prototypes.
sigma_min (float) – Lower bound for sigma to prevent bandwidth collapse. Default: 1e-3.
Compute target-likeness scores using relevance-weighted kernel distance.
Scores near 1.0 indicate the target class; scores near 0.0 indicate outliers.
Parameters:
X (array-like of shape (n_samples, n_features))
Returns:
scores
Return type:
array of shape (n_samples,)
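The relevance weighting can be sketched as follows. Raw relevances pass through a softmax, so the weights \(\lambda\) are positive and sum to one, and \(\lambda\) then weights each feature's squared difference before the Gaussian kernel is applied. Exactly where the library places \(\lambda\) inside its kernel is an assumption here, and the function names are hypothetical.

```python
import numpy as np

def softmax(z):
    z = z - z.max()  # shift for numerical stability
    e = np.exp(z)
    return e / e.sum()

def relevance_kernel_distance(X, W, sigmas, relevances):
    """Hypothetical: d^2 = 2 - 2 exp(-sum_j lam_j (x_j - w_j)^2 / (2 sigma_k^2)), shape (n, K)."""
    lam = softmax(relevances)                      # (n_features,), positive, sums to 1
    diff2 = (X[:, None, :] - W[None, :, :]) ** 2   # (n, K, n_features)
    weighted = diff2 @ lam                         # (n, K) relevance-weighted squared distance
    return 2.0 - 2.0 * np.exp(-weighted / (2.0 * sigmas[None, :] ** 2))
```

Driving a feature's raw relevance strongly negative effectively removes that feature from the distance, which is what makes the learned weights interpretable.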
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, used to warm-start from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point,
with a fresh optimizer. Cannot be combined with initial_prototypes.
Compute target-likeness scores using exponential kernel distance.
Scores near 1.0 indicate the target class; scores near 0.0 indicate outliers.
Parameters:
X (array-like of shape (n_samples, n_features))
Returns:
scores
Return type:
array of shape (n_samples,)
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, used to warm-start from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point,
with a fresh optimizer. Cannot be combined with initial_prototypes.
GLVQ with Gaussian kernel distance in feature space. Each prototype
\(w_k\) has a learnable bandwidth \(\sigma_k\) that is
adapted via gradient descent alongside the prototypes.
sigma_init (str or float) – Initialization strategy for per-prototype bandwidths.
‘median’ (default): per-class median distance from prototype to class members.
‘mean’: per-class mean distance.
float: fixed value for all prototypes.
sigma_min (float) – Lower bound for sigma to prevent bandwidth collapse. Default: 1e-3.
beta (float) – Transfer function steepness parameter.
n_prototypes_per_class (int) – Number of prototypes per class.
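To show how the kernel distance enters GLVQ, the sketch below computes the standard relative distance \(\mu = (d^+ - d^-)/(d^+ + d^-)\) from per-prototype squared Gaussian kernel distances and shapes it with a sigmoid of steepness beta. The sigmoid form \(f(\mu) = 1/(1 + \exp(-\beta\mu))\) and the function name `dkglvq_loss` are assumptions for illustration.

```python
import numpy as np

def dkglvq_loss(X, y, W, w_labels, sigmas, beta=10.0):
    """Hypothetical per-sample GLVQ loss with kernel distances, shape (n,)."""
    sq = np.sum((X[:, None, :] - W[None, :, :]) ** 2, axis=-1)
    d = 2.0 - 2.0 * np.exp(-sq / (2.0 * sigmas[None, :] ** 2))   # squared kernel distance
    same = y[:, None] == w_labels[None, :]                       # (n, K) correct-class mask
    d_plus = np.where(same, d, np.inf).min(axis=1)               # closest correct prototype
    d_minus = np.where(~same, d, np.inf).min(axis=1)             # closest wrong prototype
    mu = (d_plus - d_minus) / (d_plus + d_minus + 1e-12)         # in [-1, 1]; < 0 means correct
    return 1.0 / (1.0 + np.exp(-beta * mu))                      # sigmoid shaping
```

Values below 0.5 correspond to correctly classified samples. Because each sigma_k appears inside the exponential, gradients of this loss also reach the bandwidths, which is what allows them to be adapted alongside the prototypes.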
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, used to warm-start from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point,
with a fresh optimizer. Cannot be combined with initial_prototypes.
sigma_init (str or float) – Initialization strategy for per-prototype bandwidths.
‘median’ (default): per-class median distance from prototype to class members.
‘mean’: per-class mean distance.
float: fixed value for all prototypes.
sigma_min (float) – Lower bound for sigma to prevent bandwidth collapse. Default: 1e-3.
beta (float) – Transfer function steepness parameter.
n_prototypes_per_class (int) – Number of prototypes per class.
Predict using learned kernel distance with relevance weighting.
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, used to warm-start from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point,
with a fresh optimizer. Cannot be combined with initial_prototypes.
Predict using learned exponential kernel distance.
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, used to warm-start from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point,
with a fresh optimizer. Cannot be combined with initial_prototypes.
Differentiating Kernel GLVQ with Neural Gas cooperation.
Combines Gaussian kernel distance with per-prototype bandwidth
adaptation and Neural Gas rank-weighted loss. All same-class
prototypes participate, weighted by \(\exp(-\text{rank} / \gamma)\).
When \(\gamma \to 0\), DKGLVQ-NG recovers standard DKGLVQ.
Parameters:
sigma_init (str or float) – Initialization strategy for per-prototype bandwidths.
‘median’ (default): per-class median distance from prototype
to class members.
‘mean’: per-class mean distance.
float: fixed value for all prototypes.
sigma_min (float) – Lower bound for sigma to prevent bandwidth collapse. Default: 1e-3.
beta (float) – Transfer function steepness. Default: 10.0.
gamma_init (float, optional) – Initial neighborhood range. Default: half the maximum number of prototypes per class.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay for gamma.
Default: computed from max_iter so gamma reaches gamma_final.
n_prototypes_per_class (int) – Number of prototypes per class.
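The rank weighting \(\exp(-\text{rank} / \gamma)\) and the gamma schedule can be sketched as below. The geometric rule gamma_decay = (gamma_final / gamma_init) ** (1 / max_iter) is one natural way to make gamma reach gamma_final after max_iter steps. It is an assumption here, not a formula confirmed by the library, and both function names are hypothetical.

```python
import numpy as np

def ng_rank_weights(d_same, gamma):
    """Weights exp(-rank / gamma) for same-class prototype distances, shape (n, K_same)."""
    # Rank 0 is the closest prototype; argsort of argsort gives each entry's rank.
    ranks = np.argsort(np.argsort(d_same, axis=1), axis=1)
    return np.exp(-ranks / gamma)

def gamma_schedule(gamma_init, gamma_final, max_iter):
    """Hypothetical geometric decay so that gamma_init * decay**max_iter == gamma_final."""
    decay = (gamma_final / gamma_init) ** (1.0 / max_iter)
    return gamma_init * decay ** np.arange(max_iter + 1)
```

As gamma shrinks, every weight except the closest prototype's goes to zero, so only the winner contributes. This is why DKGLVQ-NG recovers DKGLVQ in the limit \(\gamma \to 0\).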
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, used to warm-start from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point,
with a fresh optimizer. Cannot be combined with initial_prototypes.
where \(\lambda = \text{softmax}(\text{relevances})\).
When \(\gamma \to 0\), DKGRLVQ-NG recovers standard DKGRLVQ.
Parameters:
sigma_init (str or float) – Initialization strategy for per-prototype bandwidths.
‘median’ (default): per-class median distance.
‘mean’: per-class mean distance.
float: fixed value for all prototypes.
sigma_min (float) – Lower bound for sigma. Default: 1e-3.
beta (float) – Transfer function steepness. Default: 10.0.
gamma_init (float, optional) – Initial neighborhood range. Default: half the maximum number of prototypes per class.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay for gamma.
n_prototypes_per_class (int) – Number of prototypes per class.
Return the learned relevance weights (normalized via softmax).
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, used to warm-start from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point,
with a fresh optimizer. Cannot be combined with initial_prototypes.
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, used to warm-start from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point,
with a fresh optimizer. Cannot be combined with initial_prototypes.
Neural Gas with Gaussian kernel distance for ranking. The kernel
bandwidth \(\sigma\) is a fixed hyperparameter (not learned).
The competitive Hebbian update rule operates in the original data
space — only the distance metric changes.
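A single online step matching this description might look like the sketch below. Prototypes are ranked by Gaussian kernel distance, but the update moves them toward the sample in the original data space. The learning rate `lr`, the neighborhood range `lam`, and the function name are illustrative assumptions.

```python
import numpy as np

def kernel_ng_step(W, x, lr=0.1, lam=1.0, sigma=1.0):
    """One hypothetical Neural Gas step: rank by kernel distance, update in data space."""
    sq = np.sum((W - x) ** 2, axis=1)
    d_kernel = 2.0 - 2.0 * np.exp(-sq / (2.0 * sigma ** 2))   # squared Gaussian kernel distance
    ranks = np.argsort(np.argsort(d_kernel))                  # 0 = closest prototype
    h = np.exp(-ranks / lam)                                  # rank-based neighborhood weight
    return W + lr * h[:, None] * (x - W)                      # move toward x in data space
```

With a single shared, fixed \(\sigma\), \(2 - 2\exp(-r^2/(2\sigma^2))\) is strictly increasing in the Euclidean distance r. The kernel ranking in this sketch therefore coincides with the Euclidean ranking.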
Standard Kohonen SOM with Gaussian kernel distance for BMU selection.
The kernel bandwidth \(\sigma\) is a fixed hyperparameter (not learned).
The grid-based neighborhood and competitive update rule operate in the
original data space — only the data-space distance metric changes.
Heskes SOM with Gaussian kernel distance. The kernel bandwidth
\(\sigma\) is a fixed hyperparameter (not learned). The Heskes
BMU criterion and batch update operate in the original data space —
only the data-space distance metric changes.
The Heskes BMU criterion uses kernel distance:
\[c^*(x) = \arg\min_c \sum_k h(k, c) \cdot d_\kappa^2(x, w_k)\]
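The criterion above can be evaluated directly. The sketch assumes a Gaussian grid neighborhood \(h(k, c) = \exp(-\|r_k - r_c\|^2 / (2s^2))\) over hypothetical grid coordinates `grid` with width `s`, and the squared Gaussian kernel distance \(d_\kappa^2 = 2 - 2\exp(-\|x - w_k\|^2 / (2\sigma^2))\) with a fixed \(\sigma\).

```python
import numpy as np

def heskes_bmu(X, W, grid, s=1.0, sigma=1.0):
    """Hypothetical: c*(x) = argmin_c sum_k h(k, c) * d_kappa^2(x, w_k)."""
    sq = np.sum((X[:, None, :] - W[None, :, :]) ** 2, axis=-1)         # (n, K)
    d2 = 2.0 - 2.0 * np.exp(-sq / (2.0 * sigma ** 2))                  # kernel distance^2
    g = np.sum((grid[:, None, :] - grid[None, :, :]) ** 2, axis=-1)    # (K, K) grid distances^2
    h = np.exp(-g / (2.0 * s ** 2))                                    # h[k, c], symmetric
    cost = d2 @ h                                                      # cost[i, c] = sum_k d2[i, k] h[k, c]
    return np.argmin(cost, axis=1)
```

Unlike the raw-distance BMU, this criterion can select a unit other than the nearest prototype, for example near the border of the grid. That difference is what distinguishes the Heskes variant.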
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping.
patience (int, optional) – Epochs with no improvement before early stopping. Default: None.
restore_best (bool) – If True, restore parameters from the lowest-loss epoch. Default: False.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping.
patience (int, optional) – Epochs with no improvement before early stopping. Default: None.
restore_best (bool) – If True, restore parameters from the lowest-loss epoch. Default: False.
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping.
patience (int, optional) – Epochs with no improvement before early stopping. Default: None.
restore_best (bool) – If True, restore parameters from the lowest-loss epoch. Default: False.
Uses a modified BMU definition that takes the neighborhood structure into account,
and a pure batch update (a weighted average of the data).
This guarantees a monotonic decrease of the Heskes energy function.
Differences from KohonenSOM:
The BMU is chosen to minimize the neighborhood-weighted sum of distances,
not the raw distance to the closest prototype.
Prototypes are updated via a weighted average (no learning rate).
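The batch update described above can be sketched as a neighborhood-weighted mean of the data. The sketch takes precomputed BMU indices `bmu` and a neighborhood matrix `h[k, c]` as inputs. Both names are assumptions, and the averaging is done in data space, as the text describes.

```python
import numpy as np

def heskes_batch_update(X, bmu, h):
    """w_k = sum_i h[k, bmu_i] x_i / sum_i h[k, bmu_i]; no learning rate."""
    weights = h[:, bmu]                       # (K, n): weight of sample i for prototype k
    return (weights @ X) / weights.sum(axis=1, keepdims=True)
```

When h is the identity (no neighborhood), the update reduces to the batch k-means centroid step, which makes a useful sanity check. A prototype that receives zero total weight would need special handling, which this sketch omits.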
use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping.
patience (int, optional) – Epochs with no improvement before early stopping. Default: None.
restore_best (bool) – If True, restore parameters from the lowest-loss epoch. Default: False.
Neural Gas cooperation: all same-class prototypes participate in
the loss, weighted by rank via \(\exp(-\text{rank} / \gamma)\).
Geodesic distance: \(d(x, w)\) computed via the manifold’s
intrinsic metric (matrix logarithm followed by the Frobenius norm).
Prototypes live on the manifold and are updated via projected gradient
descent: optax computes Euclidean gradients, then
manifold.project() maps prototypes back to the manifold after
each step.
The neighborhood range \(\gamma\) decays during training from
\(\gamma_{\text{init}}\) to \(\gamma_{\text{final}}\).
When \(\gamma \to 0\), RiemannianSRNG recovers a Riemannian GLVQ.
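The projected-gradient step can be illustrated on the rotation group SO(3). The sketch assumes projection onto the nearest rotation via the SVD (the orthogonal polar factor, with a sign fix to keep the determinant at +1). It measures geodesic distance as the Frobenius norm of \(\log(R_1^T R_2)\), which for SO(3) equals \(\sqrt{2}\) times the rotation angle. The library's manifold.project() may be implemented differently, and all function names here are hypothetical.

```python
import numpy as np

def project_to_so3(M):
    """Nearest rotation matrix to M in Frobenius norm (Kabsch / polar factor)."""
    U, _, Vt = np.linalg.svd(M)
    D = np.diag([1.0, 1.0, np.sign(np.linalg.det(U @ Vt))])  # enforce det = +1
    return U @ D @ Vt

def so3_geodesic(R1, R2):
    """||log(R1^T R2)||_F = sqrt(2) * rotation angle of R1^T R2."""
    R = R1.T @ R2
    cos_angle = np.clip((np.trace(R) - 1.0) / 2.0, -1.0, 1.0)
    return np.sqrt(2.0) * np.arccos(cos_angle)

def projected_gradient_step(W, grad, lr=0.1):
    """Euclidean gradient step followed by projection back onto SO(3)."""
    return project_to_so3(W - lr * grad)
```

The Euclidean step generally leaves the manifold. Projecting after every step keeps each prototype a valid rotation, so the geodesic distance stays well defined.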
Parameters:
manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance defining the geometry.
beta (float) – Transfer function steepness parameter for sigmoid shaping.
gamma_init (float, optional) – Initial neighborhood range for NG cooperation.
Default: half the maximum number of prototypes per class.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma.
Default: computed from max_iter so gamma reaches gamma_final.
tau (float) – Injectivity radius safety factor for manifold projection.
Default: 0.95.
n_prototypes_per_class (int) – Number of prototypes per class.
X (array-like of shape (n_samples, n_features_flat))
Returns:
labels
Return type:
array of shape (n_samples,)
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, used to warm-start from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point,
with a fresh optimizer. Cannot be combined with initial_prototypes.
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, used to warm-start from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point,
with a fresh optimizer. Cannot be combined with initial_prototypes.
Predict using per-prototype tangent-space omega metric.
Parameters:
X (array-like of shape (n_samples, n_features_flat))
Returns:
labels
Return type:
array of shape (n_samples,)
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, used to warm-start from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point,
with a fresh optimizer. Cannot be combined with initial_prototypes.
Extends RiemannianSRNG with per-prototype tangent subspace projection.
Each prototype \(w_k\) has an orthonormal basis \(\Omega_k\)
defining an invariant subspace; the distance measures how far the
tangent vector lies outside this subspace:
X (array-like of shape (n_samples, n_features_flat))
Returns:
labels
Return type:
array of shape (n_samples,)
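The residual distance described above can be sketched as the squared norm of the part of the flattened tangent vector \(v = \text{Log}_{w_k}(x)\) that the orthonormal basis \(\Omega_k\) does not capture, that is \(\|v - \Omega_k \Omega_k^T v\|^2\). The sketch takes v as given and assumes this squared-residual form. It orthonormalizes the basis with a reduced QR decomposition, which is one common choice. The function names are hypothetical.

```python
import numpy as np

def subspace_residual_distance(v, omega):
    """Squared norm of v outside span(omega); omega has orthonormal columns, shape (d, r)."""
    proj = omega @ (omega.T @ v)   # orthogonal projection of v onto the subspace
    r = v - proj
    return float(r @ r)

def orthonormalize(A):
    """Orthonormal basis for the column space of A via reduced QR."""
    Q, _ = np.linalg.qr(A)
    return Q
```

Any tangent direction inside span(\(\Omega_k\)) costs nothing, which is how the learned subspace encodes invariances.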
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, used to warm-start from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point,
with a fresh optimizer. Cannot be combined with initial_prototypes.
Extends RiemannianSRNG with Gaussian kernel distance. Each prototype
\(w_k\) has a learnable bandwidth \(\sigma_k\) that controls
the sensitivity range of the kernel on the manifold.
manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance defining the geometry.
sigma_init (str or float) – Initialization strategy for per-prototype bandwidths.
‘median’ (default): per-class median geodesic distance.
‘mean’: per-class mean geodesic distance.
float: fixed value for all prototypes.
sigma_min (float) – Lower bound for sigma to prevent bandwidth collapse. Default: 1e-3.
beta (float) – Transfer function steepness parameter.
gamma_init (float, optional) – Initial neighborhood range for NG cooperation.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma.
tau (float) – Injectivity radius safety factor. Default: 0.95.
n_prototypes_per_class (int) – Number of prototypes per class.
Predict class labels using kernel-wrapped geodesic distance.
Parameters:
X (array-like of shape (n_samples, n_features_flat))
Returns:
labels
Return type:
array of shape (n_samples,)
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, used to warm-start from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point,
with a fresh optimizer. Cannot be combined with initial_prototypes.
Extends RiemannianSRNG with relevance-weighted Gaussian kernel
distance in tangent space. Each prototype \(w_k\) has a learnable
bandwidth \(\sigma_k\), and a shared relevance vector
\(\lambda\) weights the tangent space features.
manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance defining the geometry.
sigma_init (str or float) – Initialization strategy for per-prototype bandwidths.
‘median’ (default): per-class median geodesic distance.
‘mean’: per-class mean geodesic distance.
float: fixed value for all prototypes.
sigma_min (float) – Lower bound for sigma to prevent bandwidth collapse. Default: 1e-3.
beta (float) – Transfer function steepness parameter.
gamma_init (float, optional) – Initial neighborhood range for NG cooperation.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma.
tau (float) – Injectivity radius safety factor. Default: 0.95.
n_prototypes_per_class (int) – Number of prototypes per class.
Predict using relevance-weighted kernel distance in tangent space.
Parameters:
X (array-like of shape (n_samples, n_features_flat))
Returns:
labels
Return type:
array of shape (n_samples,)
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, used to warm-start from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point,
with a fresh optimizer. Cannot be combined with initial_prototypes.
Extends RiemannianSRNG with exponential kernel distance applied in
tangent space. A global transformation matrix \(\hat\Omega\)
(d_flat x latent_dim) is learned such that:
where \(v = \text{Log}_{w_k}(x)_{\text{flat}}\). Since the
prototype maps to the zero vector in its own tangent space, the
exponential kernel simplifies from the full three-term formula to
\(\exp(v^T \hat\Lambda v) - 1\).
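The simplification can be checked from the three-term kernel distance. The derivation below is a sketch. It assumes the exponential kernel \(\kappa(a, b) = \exp(a^T \hat\Lambda b)\) with \(\hat\Lambda = \hat\Omega \hat\Omega^T\); this relation is an assumption, since the defining equation is not shown here. The prototype sits at the origin of its own tangent space:

\[d_\kappa(v, 0) = \kappa(v, v) - 2\,\kappa(v, 0) + \kappa(0, 0) = \exp(v^T \hat\Lambda v) - 2\exp(0) + \exp(0) = \exp(v^T \hat\Lambda v) - 1\]

Under the same assumption, \(v^T \hat\Lambda v = \|\hat\Omega^T v\|^2 \ge 0\). The distance is therefore non-negative and grows exponentially with the projected norm, which is consistent with using a small omega_hat_scale to avoid overflow at initialization.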
Parameters:
manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance defining the geometry.
latent_dim (int, optional) – Dimensionality of the transformation. If None, uses d_flat.
omega_hat_scale (float) – Scale factor for omega_hat initialization. Default: 0.1.
Smaller values prevent exp overflow at initialization.
beta (float) – Transfer function steepness parameter.
gamma_init (float, optional) – Initial neighborhood range for NG cooperation.
gamma_final (float) – Final neighborhood range. Default: 0.01.
gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma.
tau (float) – Injectivity radius safety factor. Default: 0.95.
n_prototypes_per_class (int) – Number of prototypes per class.
Predict using exponential kernel distance in tangent space.
Parameters:
X (array-like of shape (n_samples, n_features_flat))
Returns:
labels
Return type:
array of shape (n_samples,)
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, used to warm-start from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point,
with a fresh optimizer. Cannot be combined with initial_prototypes.
sigma_init (str or float) – Initialization strategy for per-prototype bandwidths.
‘median’ (default): per-class median of omega-projected distances.
‘mean’: per-class mean.
float: fixed value for all prototypes.
sigma_min (float) – Lower bound for sigma. Default: 1e-3.
latent_dim (int, optional) – Projection dimensionality for omega. Default: d_flat.
Predict using kernel-wrapped omega-projected tangent distance.
Parameters:
X (array-like of shape (n_samples, n_features_flat))
Returns:
labels
Return type:
array of shape (n_samples,)
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, used to warm-start from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number that n_prototypes_per_class would produce.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters with the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point,
with a fresh optimizer. Cannot be combined with initial_prototypes.
sigma_init (str or float) – Initialization strategy for per-prototype bandwidths.
‘median’ (default): per-class median of local-omega distances.
‘mean’: per-class mean.
float: fixed value for all prototypes.
sigma_min (float) – Lower bound for sigma. Default: 1e-3.
latent_dim (int, optional) – Projection dimensionality for local omegas. Default: d_flat.
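With local omegas, each prototype carries its own projection matrix, so the squared distance of a sample to prototype k is taken in that prototype's own latent_dim space. The sketch below assumes each Omega_k has shape (latent_dim, d_flat) and uses the plain difference x - w_k as a Euclidean stand-in for the tangent vector; a Riemannian model would use the log map at w_k instead. `local_omega_sq_distances` is a hypothetical name.

```python
import numpy as np


def local_omega_sq_distances(X, W, omegas):
    """Squared local-omega distances, one metric matrix per prototype.

    X: (n_samples, d_flat) flattened inputs.
    W: (n_prototypes, d_flat) flattened prototypes.
    omegas: (n_prototypes, latent_dim, d_flat), one Omega_k per prototype.
    Returns an array of shape (n_samples, n_prototypes).
    """
    # Stand-in for tangent vectors: Euclidean differences x - w_k (assumption).
    V = X[:, None, :] - W[None, :, :]              # (n, K, d_flat)
    # Project each difference with its own prototype's Omega_k.
    Z = np.einsum("kld,nkd->nkl", omegas, V)       # (n, K, latent_dim)
    return np.sum(Z ** 2, axis=-1)
```

With latent_dim = d_flat and every Omega_k equal to the identity, this reduces to the squared Euclidean distance.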
Predict using kernel-wrapped local-omega tangent distance.
Parameters:
X (array-like of shape (n_samples, n_features_flat))
Returns:
labels
Return type:
array of shape (n_samples,)
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters that achieved the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
Extends RiemannianSTNG with a Gaussian kernel wrapping the tangent
subspace residual distance. Each prototype has an orthonormal
subspace basis \(\Omega_k\) and a learnable bandwidth
\(\sigma_k\):
sigma_init (str or float) – Initialization strategy for per-prototype bandwidths.
‘median’ (default): per-class median of subspace residual distances.
‘mean’: per-class mean.
float: fixed value for all prototypes.
sigma_min (float) – Lower bound for sigma. Default: 1e-3.
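The subspace residual distance that this kernel wraps can be sketched as the squared norm of the part of a tangent vector that lies outside the span of Omega_k. The sketch assumes Omega_k stores its orthonormal basis as columns (shape (d_flat, s)) and that the tangent vector is supplied directly (for example x - w_k as a Euclidean stand-in). The Gaussian kernel applied on top is not reproduced here, and `subspace_residual_sq` is a hypothetical name.

```python
import numpy as np


def subspace_residual_sq(v, basis):
    """Squared norm of the component of v orthogonal to span(basis).

    v: (d_flat,) tangent vector (stand-in: x - w_k).
    basis: (d_flat, s) matrix with orthonormal columns (Omega_k).
    """
    # Orthogonal projection onto the subspace, then subtract it to get the residual.
    r = v - basis @ (basis.T @ v)
    return float(r @ r)
```

An orthonormal basis for experimentation can be obtained from `np.linalg.qr(A)[0]` for any full-rank A of shape (d_flat, s).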
Predict using kernel-wrapped subspace residual distance.
Parameters:
X (array-like of shape (n_samples, n_features_flat))
Returns:
labels
Return type:
array of shape (n_samples,)
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters that achieved the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
Extends RiemannianSMNG with a relevance-weighted Gaussian kernel in
the omega-projected tangent space. Each prototype has a learnable
bandwidth \(\sigma_k\), and a shared relevance vector
\(\lambda\) weights the projected features:
sigma_init (str or float) – Initialization strategy for per-prototype bandwidths.
‘median’ (default): per-class median of relevance-weighted distances.
‘mean’: per-class mean.
float: fixed value for all prototypes.
sigma_min (float) – Lower bound for sigma. Default: 1e-3.
latent_dim (int, optional) – Projection dimensionality for omega. Default: d_flat.
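One natural reading of "a shared relevance vector weights the projected features" is a weighted sum of squared projected coordinates. The sketch below assumes exactly that form, with lambda of shape (latent_dim,) and non-negative entries; how the library parameterizes or normalizes lambda is not stated here, and `relevance_weighted_sq` is a hypothetical name.

```python
import numpy as np


def relevance_weighted_sq(v, omega, lam):
    """Relevance-weighted squared distance in the omega-projected space (assumed form).

    v: (d_flat,) tangent vector (stand-in: x - w_k).
    omega: (latent_dim, d_flat) global projection matrix.
    lam: (latent_dim,) non-negative relevance weights.
    """
    z = omega @ v                       # project into latent_dim coordinates
    return float(np.sum(lam * z ** 2))  # weight each projected coordinate by its relevance
```

Setting every relevance weight to 1 recovers the unweighted squared norm of the projection.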
Predict using relevance-weighted kernel in omega-projected space.
Parameters:
X (array-like of shape (n_samples, n_features_flat))
Returns:
labels
Return type:
array of shape (n_samples,)
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters that achieved the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
Extends RiemannianSMNG with an exponential kernel in the omega-projected
tangent space. The global omega \(\Omega\) projects tangent vectors
from d_flat to latent_dim, then a learned matrix
\(\hat\Lambda = \hat\Omega \hat\Omega^T\) provides further
metric adaptation in the projected space:
Predict using exponential kernel in omega-projected space.
Parameters:
X (array-like of shape (n_samples, n_features_flat))
Returns:
labels
Return type:
array of shape (n_samples,)
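The two-stage metric described for this model (a global Omega from d_flat to latent_dim, then \(\hat\Lambda = \hat\Omega \hat\Omega^T\) in the projected space) can be checked numerically: the quadratic form \(z^T \hat\Lambda z\) equals \(\|\hat\Omega^T z\|^2\), so it is never negative. The sketch stops at this quadratic form and does not reproduce the kernel applied to it. Matrix shapes and the name `projected_quadratic_form` are assumptions.

```python
import numpy as np


def projected_quadratic_form(v, omega, omega_hat):
    """Compute z^T Lambda_hat z with z = Omega v and Lambda_hat = Omega_hat Omega_hat^T.

    v: (d_flat,) tangent vector (stand-in: x - w).
    omega: (latent_dim, d_flat) global projection from d_flat to latent_dim.
    omega_hat: (latent_dim, m) factor of the projected-space metric (m is assumed).
    """
    z = omega @ v          # project into the latent_dim space
    u = omega_hat.T @ z    # z^T (Omega_hat Omega_hat^T) z == ||Omega_hat^T z||^2
    return float(u @ u)    # a squared norm, so Lambda_hat is positive semidefinite
```

Computing \(\|\hat\Omega^T z\|^2\) avoids forming \(\hat\Lambda\) explicitly, which is cheaper when m is smaller than latent_dim.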
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters that achieved the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
Extends RiemannianSLNG with a relevance-weighted Gaussian kernel in
the per-prototype omega-projected tangent space. Each prototype has
its own metric matrix \(\Omega_k\) and bandwidth \(\sigma_k\),
plus a shared relevance vector \(\lambda\):
sigma_init (str or float) – Initialization strategy for per-prototype bandwidths.
‘median’ (default): per-class median of relevance-weighted distances.
‘mean’: per-class mean.
float: fixed value for all prototypes.
sigma_min (float) – Lower bound for sigma. Default: 1e-3.
latent_dim (int, optional) – Projection dimensionality for local omegas. Default: d_flat.
Predict using relevance-weighted kernel in local-omega space.
Parameters:
X (array-like of shape (n_samples, n_features_flat))
Returns:
labels
Return type:
array of shape (n_samples,)
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters that achieved the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
Extends RiemannianSLNG with an exponential kernel in the per-prototype
omega-projected tangent space. Each prototype has its own metric matrix
\(\Omega_k\), and a shared \(\hat\Lambda = \hat\Omega
\hat\Omega^T\) provides further metric adaptation in the projected space:
Predict using exponential kernel in local-omega-projected space.
Parameters:
X (array-like of shape (n_samples, n_features_flat))
Returns:
labels
Return type:
array of shape (n_samples,)
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters that achieved the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
Extends RiemannianSTNG with a relevance-weighted Gaussian kernel on
the tangent subspace residual. Each prototype has an orthonormal
subspace basis \(\Omega_k\) and a learnable bandwidth
\(\sigma_k\), plus a shared relevance vector \(\lambda\)
that weights features of the residual:
sigma_init (str or float) – Initialization strategy for per-prototype bandwidths.
‘median’ (default): per-class median of relevance-weighted residuals.
‘mean’: per-class mean.
float: fixed value for all prototypes.
sigma_min (float) – Lower bound for sigma. Default: 1e-3.
Predict using relevance-weighted kernel on subspace residual.
Parameters:
X (array-like of shape (n_samples, n_features_flat))
Returns:
labels
Return type:
array of shape (n_samples,)
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters that achieved the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
Extends RiemannianSTNG with an exponential kernel on the tangent
subspace residual. Each prototype has an orthonormal subspace basis
\(\Omega_k\), and a shared \(\hat\Lambda = \hat\Omega
\hat\Omega^T\) provides metric adaptation on the residual:
\[d_\kappa^2(x, w_k) = \exp\left(r^T \hat\Lambda r\right) - 1\]
where \(r\) denotes the tangent subspace residual of \(x\) with respect to prototype \(w_k\).
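The formula above can be evaluated directly once the residual is available. The sketch assumes Omega_k stores its orthonormal basis as columns, that \(\hat\Omega\) has shape (d_flat, m), and that the tangent vector is supplied directly (x - w_k serves as a Euclidean stand-in); `exp_kernel_residual_sq` is a hypothetical name.

```python
import numpy as np


def exp_kernel_residual_sq(v, basis, omega_hat):
    """d_kappa^2 = exp(r^T Lambda_hat r) - 1 on the subspace residual r.

    v: (d_flat,) tangent vector (stand-in: x - w_k).
    basis: (d_flat, s) orthonormal subspace basis Omega_k.
    omega_hat: (d_flat, m) factor with Lambda_hat = omega_hat @ omega_hat.T.
    """
    r = v - basis @ (basis.T @ v)   # residual orthogonal to span(Omega_k)
    u = omega_hat.T @ r             # r^T Lambda_hat r == ||u||^2
    return float(np.expm1(u @ u))   # exp(.) - 1, numerically accurate near 0
```

Because \(r^T \hat\Lambda r \ge 0\), the distance is zero exactly when the quadratic form vanishes, and for small residuals it behaves like \(r^T \hat\Lambda r\) itself, while large residuals are amplified exponentially.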
Predict using exponential kernel on subspace residual.
Parameters:
X (array-like of shape (n_samples, n_features_flat))
Returns:
labels
Return type:
array of shape (n_samples,)
fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)
Fit the model.
Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
y (array-like of shape (n_samples,)) – Target labels.
initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.
initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes
differs from the number implied by n_prototypes_per_class.
validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True,
the model restores the parameters that achieved the lowest validation loss.
sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.
resume (bool, default=False) – If True, resume training from the model’s current fitted state.
Uses the current prototypes (and other fitted parameters) as the starting point
with a fresh optimizer. Cannot be combined with initial_prototypes.
KNN is a simple, non-parametric classifier that predicts from the
k nearest training samples.
Algorithm:
For each test sample, compute distances to all training samples,
find the k nearest neighbors, predict the label as the mode (most common label)
among the k neighbors, and compute the probability as the fraction of those
neighbors that carry the predicted label.
Parameters:
n_neighbors (int, default=5) – Number of neighbors to use.
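The KNN algorithm above can be sketched in a few lines of NumPy. This is a minimal illustration assuming Euclidean distance and smallest-label tie-breaking, not the library's implementation; `knn_predict` is a hypothetical name.

```python
import numpy as np


def knn_predict(X_train, y_train, X_test, n_neighbors=5):
    """Minimal k-nearest-neighbor classifier (Euclidean distance assumed).

    Returns (labels, probabilities); the probability is the fraction of the
    k neighbors that carry the predicted label.
    """
    X_train = np.asarray(X_train, dtype=float)
    X_test = np.asarray(X_test, dtype=float)
    y_train = np.asarray(y_train)
    # Pairwise squared distances, shape (n_test, n_train).
    d2 = np.sum((X_test[:, None, :] - X_train[None, :, :]) ** 2, axis=-1)
    # Indices of the k nearest training samples for each test sample.
    nn = np.argsort(d2, axis=1)[:, :n_neighbors]
    labels, probs = [], []
    for row in y_train[nn]:
        values, counts = np.unique(row, return_counts=True)
        best = np.argmax(counts)          # ties go to the smallest label
        labels.append(values[best])
        probs.append(counts[best] / len(row))
    return np.array(labels), np.array(probs)
```

Storing all training samples and comparing against each one makes prediction cost grow linearly with the training set size, which is the usual trade-off of this non-parametric approach.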
NPC is a supervised prototype-based classifier that iteratively optimizes
its prototypes based on an accuracy metric. It uses softmin for probability estimation.
Algorithm:
Initialize prototypes (one per class).
Predict labels using the nearest prototype.
Compute accuracy.
If accuracy >= threshold or max_iter is reached, stop.
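The prediction, probability, and stopping steps above can be sketched as follows. The sketch assumes squared Euclidean distances and takes the softmin over prototype distances; the prototype update rule between iterations is not described here and is therefore omitted. Both function names are hypothetical.

```python
import numpy as np


def npc_predict_proba(X, prototypes, proto_labels):
    """Nearest-prototype labels plus softmin probabilities over prototypes.

    Assumes one prototype per class and squared Euclidean distances.
    """
    X = np.asarray(X, dtype=float)
    P = np.asarray(prototypes, dtype=float)
    d2 = np.sum((X[:, None, :] - P[None, :, :]) ** 2, axis=-1)  # (n, n_classes)
    # Softmin = softmax of negated distances; shift by the row minimum for stability.
    logits = -(d2 - d2.min(axis=1, keepdims=True))
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)
    labels = np.asarray(proto_labels)[np.argmin(d2, axis=1)]
    return labels, probs


def reached_threshold(labels, y, threshold):
    """Stopping test from the algorithm: accuracy >= threshold."""
    return float(np.mean(np.asarray(labels) == np.asarray(y))) >= threshold
```

The softmin assigns equal probability to equidistant prototypes and approaches a hard nearest-prototype decision as distance gaps grow.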
SOM is an unsupervised learning algorithm that creates a low-dimensional
(typically 2D) representation of high-dimensional data while preserving
topological relationships.
Algorithm:
Initialize a grid of neurons with random weights.
For each iteration, select a random sample from the data,
find its Best Matching Unit (BMU), update the BMU and its neighbors
towards the sample, and decay the learning rate and neighborhood range.
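The SOM training loop described above can be sketched in NumPy. The Gaussian neighborhood on grid distance, the uniform random initialization, the linear decay schedules, and the default lr0 are assumptions chosen for illustration; `som_step` and `train_som` are hypothetical names.

```python
import numpy as np


def som_step(weights, grid, x, lr, radius):
    """One SOM update: find the BMU, then pull it and its grid neighbors toward x.

    weights: (n_units, n_features) neuron weight vectors.
    grid: (n_units, 2) grid coordinates of each neuron.
    """
    # Best Matching Unit: the neuron whose weight vector is closest to the sample.
    bmu = int(np.argmin(np.sum((weights - x) ** 2, axis=1)))
    # Gaussian neighborhood strength, decaying with squared grid distance to the BMU.
    g2 = np.sum((grid - grid[bmu]) ** 2, axis=1)
    h = np.exp(-g2 / (2.0 * radius ** 2))
    return weights + lr * h[:, None] * (x - weights), bmu


def train_som(X, grid_size, max_iter, lr0=0.5, seed=0):
    rng = np.random.default_rng(seed)
    X = np.asarray(X, dtype=float)
    # (row, col) coordinates for a grid_size x grid_size map.
    grid = np.array([(i, j) for i in range(grid_size) for j in range(grid_size)], dtype=float)
    weights = rng.random((grid_size * grid_size, X.shape[1]))   # random initial weights
    radius0 = grid_size / 2.0
    for t in range(max_iter):
        frac = t / max_iter
        lr = lr0 * (1.0 - frac)                     # decay the learning rate
        radius = max(radius0 * (1.0 - frac), 0.5)   # shrink the neighborhood range
        x = X[rng.integers(len(X))]                 # random training sample
        weights, _ = som_step(weights, grid, x, lr, radius)
    return weights, grid
```

Early iterations with a wide neighborhood order the map globally; later iterations with a small radius and learning rate fine-tune individual neurons.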
Parameters:
grid_size (int, optional) – Size of the SOM grid (grid_size x grid_size).
If None, computed as int(sqrt(5 * sqrt(n_samples))).
max_iter (int, optional) – Number of training iterations.
If None, set to 500 * grid_size^2.
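The documented defaults can be resolved with a few lines of Python. For example, n_samples=1000 gives grid_size=12 (the grid then holds 144 neurons, close to 5 * sqrt(1000) ≈ 158) and max_iter=72000. `som_defaults` is a hypothetical helper that only restates the two formulas above.

```python
import math


def som_defaults(n_samples, grid_size=None, max_iter=None):
    """Resolve the default grid_size and max_iter from the documented formulas."""
    if grid_size is None:
        # Roughly 5 * sqrt(n_samples) neurons in total, arranged as a square grid.
        grid_size = int(math.sqrt(5 * math.sqrt(n_samples)))
    if max_iter is None:
        # 500 iterations per neuron.
        max_iter = 500 * grid_size ** 2
    return grid_size, max_iter
```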