Models API Reference

All models are importable from prosemble.models.

Supervised LVQ

class prosemble.models.GLVQ(beta=10.0, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)

Generalized Learning Vector Quantization.

Loss:

\[\mu = \frac{d^+ - d^-}{d^+ + d^-}\]

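where \(d^+\) is the distance to the closest prototype with the same label as the sample and \(d^-\) is the distance to the closest prototype with a different label. Each sample contributes \(f(\mu)\) to the loss, where \(f\) is the optional transfer function (identity by default).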
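As a sketch of the loss above, the NumPy snippet below computes \(\mu\) for every sample using squared Euclidean distances and averages a sigmoid transfer over the batch. The sigmoid form 1 / (1 + exp(-beta * mu)), the mean reduction, and the helper names glvq_mu and sigmoid are illustrative assumptions, not prosemble's internal implementation.

```python
import numpy as np

def glvq_mu(X, y, prototypes, proto_labels):
    """Relative distance difference mu = (d+ - d-) / (d+ + d-) for each sample."""
    # Squared Euclidean distances, shape (n_samples, n_prototypes)
    d = ((X[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)
    same = y[:, None] == proto_labels[None, :]
    # d+: closest prototype with the correct label; d-: closest with a wrong label
    d_plus = np.where(same, d, np.inf).min(axis=1)
    d_minus = np.where(same, np.inf, d).min(axis=1)
    return (d_plus - d_minus) / (d_plus + d_minus)

def sigmoid(z, beta=10.0):
    # Assumed transfer function; beta controls the steepness
    return 1.0 / (1.0 + np.exp(-beta * z))

X = np.array([[0.0, 0.0], [1.0, 1.0], [0.9, 0.2]])
y = np.array([0, 1, 0])
prototypes = np.array([[0.1, 0.0], [1.0, 0.9]])
proto_labels = np.array([0, 1])

mu = glvq_mu(X, y, prototypes, proto_labels)
loss = sigmoid(mu).mean()
```

Here \(\mu\) lies in [-1, 1] and is negative for correctly classified samples. The third sample is closer to the class-1 prototype than to its own class-0 prototype, so its \(\mu\) is positive and it dominates the loss.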

Parameters:
  • beta (float) – Parameter \(\beta\) for transfer function (e.g., sigmoid steepness).

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys:
      - ‘sync_period’: int (default 6). Synchronize the slow weights every sync_period steps.
      - ‘slow_step_size’: float (default 0.5). Interpolation factor between the slow and fast weights.
    Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32 while the forward and backward passes run in the lower precision, which can speed up training (up to roughly 2x) and roughly halve activation memory on GPUs with fast half-precision support. float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
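The ‘balanced’ option for class_weight can be sketched with the common heuristic n_samples / (n_classes * count_c), as used by scikit-learn. Whether prosemble applies exactly this normalization is an assumption, and the helpers balanced_class_weights and per_sample_weights are hypothetical.

```python
import numpy as np

def balanced_class_weights(y):
    """Weights inversely proportional to class frequency (scikit-learn convention)."""
    classes, counts = np.unique(y, return_counts=True)
    weights = len(y) / (len(classes) * counts)
    return dict(zip(classes.tolist(), weights.tolist()))

def per_sample_weights(y, class_weight):
    # Expand a {label: weight} dict into one weight per training sample
    return np.array([class_weight[label] for label in y.tolist()])

y = np.array([0, 0, 0, 1])          # class 1 is the minority class
cw = balanced_class_weights(y)      # minority class gets the larger weight
sw = per_sample_weights(y, cw)
```

With this normalization the per-sample weights sum to n_samples, so the overall loss scale matches uniform weighting while each class contributes equally.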

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number that n_prototypes_per_class would produce.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

classmethod load(path)

Load a fitted model from an NPZ file.

Parameters:

path (str) – Path to the .npz file.

Returns:

Reconstructed fitted model.

Return type:

model

predict(X)

Predict class labels via Winner-Takes-All Competition.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)
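Winner-takes-all prediction assigns each sample the label of its nearest prototype. A minimal NumPy sketch, assuming squared Euclidean distance (the documented default distance_fn); the helper wta_predict is hypothetical:

```python
import numpy as np

def wta_predict(X, prototypes, proto_labels):
    # Squared Euclidean distance from every sample to every prototype
    d = ((X[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)
    # Each sample takes the label of its nearest prototype
    return proto_labels[np.argmin(d, axis=1)]

prototypes = np.array([[0.0, 0.0], [0.2, 0.1], [1.0, 1.0]])
proto_labels = np.array([0, 0, 1])   # two prototypes for class 0, one for class 1
X = np.array([[0.1, 0.1], [0.8, 0.9]])
labels = wta_predict(X, prototypes, proto_labels)
```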

predict_proba(X)

Predict class probabilities via softmin of distances.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

proba

Return type:

array of shape (n_samples, n_classes)
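A softmin turns distances into probabilities by applying a softmax to negated distances. The sketch below assumes that each class is first reduced to its closest prototype and that no temperature is applied; both are assumptions about how prosemble aggregates multiple prototypes per class. The helper softmin_proba is hypothetical.

```python
import numpy as np

def softmin_proba(X, prototypes, proto_labels, n_classes):
    d = ((X[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)
    # Assumption: reduce to the closest prototype of each class first
    class_d = np.stack(
        [np.where(proto_labels == c, d, np.inf).min(axis=1) for c in range(n_classes)],
        axis=1,
    )
    # Softmin = softmax of negated distances; subtract the row max for stability
    logits = -class_d
    logits -= logits.max(axis=1, keepdims=True)
    e = np.exp(logits)
    return e / e.sum(axis=1, keepdims=True)

prototypes = np.array([[0.0, 0.0], [1.0, 1.0]])
proto_labels = np.array([0, 1])
X = np.array([[0.0, 0.0], [1.0, 1.0], [0.5, 0.5]])
proba = softmin_proba(X, prototypes, proto_labels, n_classes=2)
```

Under these assumptions the argmax of each row matches the winner-takes-all label, and a sample equidistant from both classes receives equal probabilities.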

save(path, quantize=None)

Save fitted model to an NPZ file.

Parameters:
  • path (str) – File path (.npz extension added if not present).

  • quantize (str, optional) – Quantize before saving: 'float16', 'bfloat16', or 'int8'. The model in memory is unchanged.

Return type:

None

class prosemble.models.GRLVQ(beta=10.0, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)

Generalized Relevance Learning Vector Quantization.

Learns per-feature relevance weights \(\lambda_j\) such that the weighted distance is:

\[d(x, w) = \sum_j \lambda_j (x_j - w_j)^2\]
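The relevance-weighted distance can be sketched in NumPy as a weighted sum of squared per-feature differences. The normalization of the relevances (non-negative, summing to 1) is a common GRLVQ convention assumed here for illustration; the helper relevance_distance is hypothetical.

```python
import numpy as np

def relevance_distance(X, prototypes, relevances):
    # d(x, w) = sum_j lambda_j (x_j - w_j)^2 for every sample/prototype pair
    diff2 = (X[:, None, :] - prototypes[None, :, :]) ** 2   # (n, K, d)
    return diff2 @ relevances                              # (n, K)

# Hypothetical relevances: non-negative and normalized to sum to 1
raw = np.array([2.0, 0.0, 1.0])
relevances = raw / raw.sum()

X = np.array([[1.0, 5.0, 0.0]])
prototypes = np.array([[0.0, 0.0, 0.0]])
D = relevance_distance(X, prototypes, relevances)
```

Feature 1 has zero relevance, so its large squared difference (25) does not affect the distance; only features 0 and 2 contribute.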

Parameters:
  • beta (float) – Transfer function steepness parameter.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys:
      - ‘sync_period’: int (default 6). Synchronize the slow weights every sync_period steps.
      - ‘slow_step_size’: float (default 0.5). Interpolation factor between the slow and fast weights.
    Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32 while the forward and backward passes run in the lower precision, which can speed up training (up to roughly 2x) and roughly halve activation memory on GPUs with fast half-precision support. float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
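The lookahead option keeps a set of slow weights that are pulled toward the fast (inner optimizer) weights every sync_period steps. The sketch below uses plain gradient descent as the inner optimizer on a toy quadratic; it illustrates the general Lookahead scheme, not prosemble's code. Optax provides optax.lookahead(fast_optimizer, sync_period, slow_step_size), whose argument names match these keys, so prosemble may wrap it, but that is an assumption. The helper lookahead_sgd is hypothetical.

```python
import numpy as np

def lookahead_sgd(grad_fn, w0, lr=0.1, steps=60, sync_period=6, slow_step_size=0.5):
    slow = np.array(w0, dtype=float)
    fast = slow.copy()
    for t in range(1, steps + 1):
        fast -= lr * grad_fn(fast)      # inner (fast) step: plain gradient descent here
        if t % sync_period == 0:
            # Move the slow weights part of the way toward the fast weights, then reset
            slow += slow_step_size * (fast - slow)
            fast = slow.copy()
    return slow

# Minimize f(w) = ||w - target||^2 / 2, whose gradient is w - target
target = np.array([3.0, -1.0])
w = lookahead_sgd(lambda w: w - target, np.zeros(2))
```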

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number that n_prototypes_per_class would produce.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict(X)

Predict class labels via Winner-Takes-All Competition.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

predict_proba(X)

Predict class probabilities via softmin of distances.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

proba

Return type:

array of shape (n_samples, n_classes)

class prosemble.models.GMLVQ(latent_dim=None, beta=10.0, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)

Generalized Matrix Learning Vector Quantization.

Learns a global linear mapping \(\Omega\) of shape (d × latent_dim), where d is the number of input features, such that distances are computed in the transformed space:

\[d(x, w) = (x - w)^T \Omega \Omega^T (x - w) = \|\Omega^T (x - w)\|^2\]

The d × d relevance matrix \(\Lambda = \Omega \Omega^T\) captures feature correlations.
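The sketch below computes the matrix distance in NumPy and checks it against the relevance-matrix form. It assumes \(\Omega\) is stored with shape (d × latent_dim) as described above, so each difference is projected as (x − w)Ω and the matching d × d relevance matrix is ΩΩ^T (if Ω were stored transposed, the same quantity would be written Ω^TΩ). The helper omega_distance is hypothetical.

```python
import numpy as np

def omega_distance(X, prototypes, omega):
    # omega has shape (n_features, latent_dim); project differences into latent space
    diff = X[:, None, :] - prototypes[None, :, :]     # (n, K, d)
    proj = diff @ omega                              # (n, K, latent_dim)
    return (proj ** 2).sum(axis=-1)                  # (n, K)

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 3))
prototypes = rng.normal(size=(2, 3))
omega = rng.normal(size=(3, 2))                      # n_features=3, latent_dim=2

D = omega_distance(X, prototypes, omega)

# Same distances via the d x d relevance matrix
Lam = omega @ omega.T
diff = X[:, None, :] - prototypes[None, :, :]
D_lam = np.einsum("nkd,de,nke->nk", diff, Lam, diff)
```

With latent_dim smaller than d, Λ has rank at most latent_dim, so the model effectively measures distances in a learned low-dimensional subspace.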

Parameters:
  • latent_dim (int, optional) – Dimensionality of the latent space. If None, uses input dim.

  • beta (float) – Transfer function steepness.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys:
      - ‘sync_period’: int (default 6). Synchronize the slow weights every sync_period steps.
      - ‘slow_step_size’: float (default 0.5). Interpolation factor between the slow and fast weights.
    Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32 while the forward and backward passes run in the lower precision, which can speed up training (up to roughly 2x) and roughly halve activation memory on GPUs with fast half-precision support. float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
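The ema_decay option can be sketched as a running update ema ← decay · ema + (1 − decay) · params applied after each optimizer step. Initializing the average from the starting parameters and omitting bias correction are assumptions about prosemble's behavior; the helper ema_update is hypothetical.

```python
import numpy as np

def ema_update(ema, params, decay):
    """One EMA step: ema <- decay * ema + (1 - decay) * params."""
    return decay * ema + (1.0 - decay) * params

# Hypothetical trajectory: parameters jump from 0 to 1 and then stay there
params = np.zeros(3)
ema = params.copy()              # assumption: the EMA starts from the initial parameters
params = np.ones(3)
for _ in range(10):
    ema = ema_update(ema, params, decay=0.9)
# After 10 steps the EMA has closed 1 - 0.9**10 (about 65%) of the gap
```

As a rule of thumb, the average spans roughly 1 / (1 − decay) recent steps, so values such as 0.999 smooth over about 1000 updates and suit long runs.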

predict(X)

Predict using learned \(\Omega\) distance.

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number that n_prototypes_per_class would produce.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict_proba(X)

Predict class probabilities via softmin of distances.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

proba

Return type:

array of shape (n_samples, n_classes)

class prosemble.models.LGMLVQ(latent_dim=None, beta=10.0, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)

Localized Generalized Matrix Learning Vector Quantization.

Each prototype \(k\) has its own \(\Omega_k\) matrix. The distance from sample \(x\) to prototype \(w_k\) is:

\[d(x, w_k) = (x - w_k)^T \Omega_k^T \Omega_k (x - w_k)\]
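A NumPy sketch of the localized distance, following the formula above with one matrix per prototype. It assumes each \(\Omega_k\) is stored with shape (latent_dim × d) so that \(\Omega_k (x - w_k)\) is defined; the storage layout and the helper local_omega_distance are assumptions.

```python
import numpy as np

def local_omega_distance(X, prototypes, omegas):
    # omegas: one matrix per prototype, shape (n_prototypes, latent_dim, n_features)
    diff = X[:, None, :] - prototypes[None, :, :]          # (n, K, d)
    proj = np.einsum("kld,nkd->nkl", omegas, diff)         # Omega_k (x - w_k)
    return (proj ** 2).sum(axis=-1)                        # (n, K)

omegas = np.array([
    [[1.0, 0.0], [0.0, 1.0]],   # prototype 0: identity, plain squared Euclidean
    [[1.0, 0.0], [0.0, 0.0]],   # prototype 1: ignores the second feature
])
prototypes = np.zeros((2, 2))
X = np.array([[1.0, 2.0]])
D = local_omega_distance(X, prototypes, omegas)
```

Both prototypes sit at the origin, yet the sample is at distance 5 from prototype 0 and 1 from prototype 1, because each prototype weighs the features with its own matrix.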

Parameters:
  • latent_dim (int, optional) – Latent space dimensionality per prototype. If None, uses input dim.

  • beta (float) – Transfer function steepness.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys:
      - ‘sync_period’: int (default 6). Synchronize the slow weights every sync_period steps.
      - ‘slow_step_size’: float (default 0.5). Interpolation factor between the slow and fast weights.
    Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32 while the forward and backward passes run in the lower precision, which can speed up training (up to roughly 2x) and roughly halve activation memory on GPUs with fast half-precision support. float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).

predict(X)

Predict using local \(\Omega\) distances.

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number that n_prototypes_per_class would produce.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict_proba(X)

Predict class probabilities via softmin of distances.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

proba

Return type:

array of shape (n_samples, n_classes)

class prosemble.models.GTLVQ(subspace_dim=2, beta=10.0, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)

Generalized Tangent Learning Vector Quantization.

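Each prototype \(k\) has an orthonormal subspace basis \(\Omega_k\) of shape (d × subspace_dim), where d is the number of input features. The tangent distance is:

\[d(x, w_k) = \|P_k(x - w_k)\|^2\]

where \(P_k = I - \Omega_k \Omega_k^T\) is the orthogonal projector onto the complement of the tangent subspace spanned by the columns of \(\Omega_k\). Differences along the tangent directions therefore do not contribute to the distance.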
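The tangent distance can be sketched in NumPy by subtracting each difference's component inside the tangent subspace. The bases here are orthonormalized with a QR decomposition so that \(P_k\) is a true projector; how prosemble keeps its bases orthonormal during training is not covered by this sketch, and the helper tangent_distance is hypothetical.

```python
import numpy as np

def tangent_distance(X, prototypes, bases):
    # bases: orthonormal tangent bases, shape (n_prototypes, n_features, subspace_dim)
    diff = X[:, None, :] - prototypes[None, :, :]               # (n, K, d)
    coeff = np.einsum("kds,nkd->nks", bases, diff)              # Omega_k^T (x - w_k)
    residual = diff - np.einsum("kds,nks->nkd", bases, coeff)   # (I - Omega_k Omega_k^T)(x - w_k)
    return (residual ** 2).sum(axis=-1)                         # (n, K)

rng = np.random.default_rng(0)
d, s = 4, 2
# Orthonormalize random matrices with QR so that P_k is a true projector
bases = np.stack([np.linalg.qr(rng.normal(size=(d, s)))[0] for _ in range(3)])
prototypes = rng.normal(size=(3, d))

# A point reached by moving along prototype 0's tangent subspace
x_on_tangent = prototypes[0] + bases[0] @ np.array([1.5, -2.0])
D = tangent_distance(x_on_tangent[None, :], prototypes, bases)
```

D[0, 0] is zero up to floating-point error, since the displacement lies entirely within prototype 0's tangent subspace, and every tangent distance is at most the corresponding squared Euclidean distance.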

Parameters:
  • subspace_dim (int) – Dimension of each prototype’s tangent subspace.

  • beta (float) – Transfer function steepness.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys:
      - ‘sync_period’: int (default 6). Synchronize the slow weights every sync_period steps.
      - ‘slow_step_size’: float (default 0.5). Interpolation factor between the slow and fast weights.
    Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32 while the forward and backward passes run in the lower precision, which can speed up training (up to roughly 2x) and roughly halve activation memory on GPUs with fast half-precision support. float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).

predict(X)

Predict class labels via Winner-Takes-All Competition.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number that n_prototypes_per_class would produce.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict_proba(X)

Predict class probabilities via softmin of distances.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

proba

Return type:

array of shape (n_samples, n_classes)

class prosemble.models.CELVQ(n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, gradient_checkpointing=False, devices=None)

Cross-Entropy Learning Vector Quantization.

Computes per-class minimum distances, negates them to get logits, then applies cross-entropy loss against true labels.
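The steps above (per-class minimum distance, negation into logits, cross-entropy) can be sketched in NumPy as follows. The mean reduction over samples is an assumption, margin and sample weights are left out, and the helper celvq_loss is hypothetical.

```python
import numpy as np

def celvq_loss(X, y, prototypes, proto_labels, n_classes):
    d = ((X[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)      # (n, K)
    # Per-class minimum distance, shape (n, n_classes)
    class_d = np.stack(
        [np.where(proto_labels == c, d, np.inf).min(axis=1) for c in range(n_classes)],
        axis=1,
    )
    logits = -class_d                      # closer classes get larger logits
    # Numerically stable log-softmax
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Mean negative log-likelihood of the true labels
    return -log_probs[np.arange(len(y)), y].mean()

prototypes = np.array([[0.0, 0.0], [1.0, 1.0]])
proto_labels = np.array([0, 1])
X = np.array([[0.0, 0.0], [1.0, 1.0]])
loss_right = celvq_loss(X, np.array([0, 1]), prototypes, proto_labels, n_classes=2)
loss_wrong = celvq_loss(X, np.array([1, 0]), prototypes, proto_labels, n_classes=2)
```

Each sample sits on the prototype of its true class, so the correct labeling yields a small loss while swapping the labels yields a much larger one.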

Parameters:
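  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • margin (float) – Margin for loss computation.

  • callbacks (list) – List of Callback objects.

  • use_scan (bool) – If True, train with jax.lax.scan; if False, use a Python for-loop with true early stopping.

  • batch_size (int | None) – Mini-batch size. None means full-batch training.

  • lr_scheduler_kwargs (dict | None) – Keyword arguments passed to the learning rate scheduler.

  • patience (int | None) – Number of consecutive epochs with no improvement before stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest (validation) loss.

  • gradient_accumulation_steps (int | None) – Number of steps to accumulate gradients over before applying an update.

  • ema_decay (float | None) – Exponential moving average decay for parameters.

  • freeze_params (list | None) – Names of parameter groups whose gradients are zeroed.

  • lookahead (dict | None) – Lookahead optimizer settings (‘sync_period’, ‘slow_step_size’).

  • mixed_precision (str | None) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’.

  • gradient_checkpointing (bool)

  • devices (list | None)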

See also

SupervisedPrototypeModel

Full list of base parameters (optimizer, distance_fn, lr_scheduler, callbacks, patience, etc.).

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict(X)

Predict class labels via winner-takes-all competition.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)
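Winner-takes-all prediction assigns each sample the label of its closest prototype. A minimal sketch, assuming squared Euclidean distance and prototypes given as plain lists; wta_predict is a hypothetical helper, not part of prosemble.models.

```python
def sq_dist(x, w):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((xi - wi) ** 2 for xi, wi in zip(x, w))


def wta_predict(X, prototypes, proto_labels):
    """Assign each sample the label of its nearest prototype (the winner)."""
    preds = []
    for x in X:
        dists = [sq_dist(x, w) for w in prototypes]
        winner = min(range(len(prototypes)), key=dists.__getitem__)
        preds.append(proto_labels[winner])
    return preds
```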

predict_proba(X)

Predict class probabilities via softmin of distances.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

proba

Return type:

array of shape (n_samples, n_classes)

class prosemble.models.LVQ1(n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)

Learning Vector Quantization 1.

For each sample:

  • Find nearest prototype (winner)

  • If same class: \(w \leftarrow w + \eta (x - w)\) (attract)

  • If different class: \(w \leftarrow w - \eta (x - w)\) (repel)

Uses batch updates (all samples per iteration).
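One batch iteration of the LVQ1 rules above might look like the following sketch. It assumes squared Euclidean distance and that the per-sample updates are summed and applied once at the end of the iteration; whether prosemble sums or averages them is not stated here. lvq1_epoch is a hypothetical name.

```python
def sq_dist(x, w):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((xi - wi) ** 2 for xi, wi in zip(x, w))


def lvq1_epoch(X, y, prototypes, proto_labels, lr):
    """One batch LVQ1 iteration: find all winners with the current prototypes,
    sum the attract/repel updates, then apply them together."""
    deltas = [[0.0] * len(w) for w in prototypes]
    for x, label in zip(X, y):
        # Winner = nearest prototype to x.
        k = min(range(len(prototypes)), key=lambda j: sq_dist(x, prototypes[j]))
        sign = 1.0 if proto_labels[k] == label else -1.0  # attract vs. repel
        for i, (xi, wi) in enumerate(zip(x, prototypes[k])):
            deltas[k][i] += sign * lr * (xi - wi)
    return [[wi + di for wi, di in zip(w, d)] for w, d in zip(prototypes, deltas)]
```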

Parameters:
  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6: sync every k steps) and ‘slow_step_size’ (float, default 0.5: interpolation factor). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
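For class_weight='balanced', a common inverse-frequency formula (the one scikit-learn uses) is n_samples / (n_classes * count(c)). The sketch below assumes prosemble uses the same formula, which this reference does not confirm; balanced_class_weights is a hypothetical helper.

```python
from collections import Counter


def balanced_class_weights(y):
    """Inverse-frequency weights: n_samples / (n_classes * count(c))."""
    counts = Counter(y)
    n_samples, n_classes = len(y), len(counts)
    return {c: n_samples / (n_classes * n) for c, n in counts.items()}
```

With these weights every class contributes the same total weight, n_samples / n_classes, to the loss.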

fit(X, y, initial_prototypes=None)

Fit LVQ1 using competitive learning (no gradients).

predict(X)

Predict class labels via winner-takes-all competition.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

class prosemble.models.LVQ21(n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)

Learning Vector Quantization 2.1.

For each sample:

  • Find closest same-class prototype \(w^+\) and closest different-class \(w^-\)

  • \(w^+ \leftarrow w^+ + \eta (x - w^+)\) (attract \(w^+\))

  • \(w^- \leftarrow w^- - \eta (x - w^-)\) (repel \(w^-\))
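The two LVQ2.1 updates for one sample can be sketched as below. The sketch applies both updates unconditionally, exactly as listed; classic LVQ2.1 additionally restricts updates to samples falling inside a window around the decision boundary, and whether prosemble does so is not stated here. Squared Euclidean distance and the name lvq21_step are assumptions.

```python
def sq_dist(x, w):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((xi - wi) ** 2 for xi, wi in zip(x, w))


def lvq21_step(x, label, prototypes, proto_labels, lr):
    """Apply the LVQ2.1 attract/repel updates for one sample; returns new prototypes."""
    same = [j for j, c in enumerate(proto_labels) if c == label]
    other = [j for j, c in enumerate(proto_labels) if c != label]
    jp = min(same, key=lambda j: sq_dist(x, prototypes[j]))   # index of w+
    jm = min(other, key=lambda j: sq_dist(x, prototypes[j]))  # index of w-
    new = [list(w) for w in prototypes]
    new[jp] = [wi + lr * (xi - wi) for xi, wi in zip(x, prototypes[jp])]  # attract w+
    new[jm] = [wi - lr * (xi - wi) for xi, wi in zip(x, prototypes[jm])]  # repel w-
    return new
```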

Parameters:
  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6: sync every k steps) and ‘slow_step_size’ (float, default 0.5: interpolation factor). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
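The lookahead option wraps an inner optimizer: the fast weights take sync_period inner steps, then the slow weights move toward them by slow_step_size and the fast weights restart from the slow ones. A minimal scalar sketch with plain SGD as the inner optimizer (an assumption; prosemble's default optimizer is 'adam'), where lookahead_sgd and grad_fn are hypothetical names:

```python
def lookahead_sgd(grad_fn, w0, lr, steps, sync_period=6, slow_step_size=0.5):
    """Lookahead around plain SGD for a single scalar parameter."""
    slow = fast = w0
    for t in range(1, steps + 1):
        fast = fast - lr * grad_fn(fast)  # inner (fast) SGD step
        if t % sync_period == 0:
            # Move the slow weights part of the way toward the fast weights...
            slow = slow + slow_step_size * (fast - slow)
            fast = slow  # ...and restart the fast weights from there.
    return slow
```

For example, minimizing (w - 3)^2 with lookahead_sgd(lambda w: 2.0 * (w - 3.0), 0.0, 0.1, 600) converges to w close to 3.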

fit(X, y, initial_prototypes=None)

Fit LVQ2.1 using competitive learning.

predict(X)

Predict class labels via winner-takes-all competition.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

class prosemble.models.MedianLVQ(n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)

Median Learning Vector Quantization.

Prototypes are always actual data points. The algorithm alternates:

  1. E-step: compute soft assignments (GLVQ-like weights)

  2. M-step: for each prototype, find the data point that minimizes the weighted sum of distances
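The M-step above can be sketched for one prototype as an exhaustive search over the data points. The E-step weights are taken as given, since their exact ("GLVQ-like") form is not specified here; squared Euclidean distance and the helper name median_m_step are assumptions.

```python
def sq_dist(x, w):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((xi - wi) ** 2 for xi, wi in zip(x, w))


def median_m_step(X, weights):
    """Index of the data point x_j minimizing sum_i weights[i] * d(x_i, x_j)."""
    def cost(j):
        return sum(a * sq_dist(xi, X[j]) for a, xi in zip(weights, X))
    return min(range(len(X)), key=cost)
```

This exhaustive search needs O(n²) distance evaluations per prototype, which is the main cost of keeping prototypes restricted to actual data points.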

Parameters:
  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6: sync every k steps) and ‘slow_step_size’ (float, default 0.5: interpolation factor). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
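The ema_decay option keeps a running average of the parameters, updated after each step as ema = decay * ema + (1 - decay) * params. A sketch of that update on a flat list of parameters; initializing the average from the starting parameters and the helper name ema_update are assumptions.

```python
def ema_update(ema, params, decay):
    """One EMA step: ema = decay * ema + (1 - decay) * params (elementwise)."""
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema, params)]


# Usage: start the average at the initial parameters, update it after every
# step, and swap the averaged values in once training ends.
params = [0.0, 0.0]
ema = list(params)
for _ in range(200):
    params = [1.0, 2.0]  # stand-in for the result of an optimizer step
    ema = ema_update(ema, params, decay=0.9)
```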

fit(X, y, initial_prototypes=None)

Fit MedianLVQ.

predict(X)

Predict class labels via winner-takes-all competition.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

class prosemble.models.SLVQ(sigma=1.0, rejection_confidence=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)

Soft Learning Vector Quantization.

Uses Gaussian mixture probabilities:

\[p(k|x) = \frac{\exp(-d_k^2 / (2\sigma^2))}{\sum_j \exp(-d_j^2 / (2\sigma^2))}\]
\[P(\text{class}|x) = \sum_{k \in \text{class}} p(k|x)\]

Loss: \(-\log(P(\text{correct}) / P(\text{wrong}))\)
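The two formulas and the loss above can be sketched for one sample in plain Python. Assumptions: d2 holds squared Euclidean distances to each prototype, labels are integers 0..n_classes-1, and P(wrong) is the total probability of all other classes. The helper names class_probs and slvq_loss are hypothetical.

```python
import math


def class_probs(d2, proto_labels, n_classes, sigma=1.0):
    """P(class|x) = sum of Gaussian-mixture responsibilities p(k|x) per class."""
    logits = [-d / (2.0 * sigma ** 2) for d in d2]
    m = max(logits)  # subtract the max for numerical stability
    e = [math.exp(z - m) for z in logits]
    total = sum(e)
    p_k = [v / total for v in e]  # p(k|x)
    return [sum(p for p, c in zip(p_k, proto_labels) if c == k) for k in range(n_classes)]


def slvq_loss(d2, proto_labels, label, n_classes, sigma=1.0):
    """-log(P(correct|x) / P(wrong|x)), with P(wrong|x) summed over the other classes."""
    P = class_probs(d2, proto_labels, n_classes, sigma)
    p_wrong = sum(p for k, p in enumerate(P) if k != label)
    return -math.log(P[label] / p_wrong)
```

For example, with d2 = [0.0, 2.0], one prototype per class and sigma = 1, the ratio P(correct)/P(wrong) is e, so the loss is -1; this log-ratio loss is unbounded below.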

Parameters:
  • sigma (float) – Bandwidth of Gaussian mixture.

  • rejection_confidence (float, optional) – Minimum class probability for a confident prediction (0 to 1). Samples below this threshold are rejected (labeled -1) when using predict_with_rejection(). Default is None (no rejection).

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6: sync every k steps) and ‘slow_step_size’ (float, default 0.5: interpolation factor). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
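The interplay of epsilon, patience and restore_best can be illustrated on a recorded loss history. This is a sketch of one plausible rule, not prosemble's exact logic: a step counts as an improvement only if it lowers the best loss by more than epsilon, patience=None behaves like patience=1, and restore_best would return the parameters saved at best_epoch. early_stopping is a hypothetical helper.

```python
def early_stopping(losses, epsilon=1e-6, patience=None):
    """Return (stop_epoch, best_epoch) for a recorded loss history."""
    limit = 1 if patience is None else patience
    best, best_epoch, bad = float("inf"), 0, 0
    for epoch, loss in enumerate(losses):
        if loss < best - epsilon:  # improvement larger than epsilon
            best, best_epoch, bad = loss, epoch, 0
        else:
            bad += 1
            if bad >= limit:  # too many consecutive non-improving epochs
                return epoch, best_epoch
    return len(losses) - 1, best_epoch  # ran to the end without stopping
```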

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict(X)

Predict class labels via winner-takes-all competition.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

predict_proba(X)

Predict class probabilities via softmin of distances.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

proba

Return type:

array of shape (n_samples, n_classes)

class prosemble.models.RSLVQ(sigma=1.0, rejection_confidence=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)

Robust Soft Learning Vector Quantization.

Like SLVQ but with a more robust denominator:

\[\text{loss} = -\log\frac{P(\text{correct}|x)}{P(\text{all}|x)}\]
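Because the numerator only sums over the correct-class components of the denominator, the ratio is at most 1 and this loss is non-negative, unlike the SLVQ log-ratio. A one-sample sketch, assuming d2 holds squared Euclidean distances and the same Gaussian components as SLVQ; rslvq_loss is a hypothetical name.

```python
import math


def rslvq_loss(d2, proto_labels, label, sigma=1.0):
    """-log(P(correct|x) / P(all|x)) with Gaussian components exp(-d / (2 sigma^2))."""
    logits = [-d / (2.0 * sigma ** 2) for d in d2]
    m = max(logits)  # shared shift cancels in the ratio; it only prevents underflow
    num = sum(math.exp(z - m) for z, c in zip(logits, proto_labels) if c == label)
    den = sum(math.exp(z - m) for z in logits)
    return -math.log(num / den)
```

When every prototype belongs to the correct class, the ratio is 1 and the loss is exactly 0.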
Parameters:
  • sigma (float) – Bandwidth of Gaussian mixture.

  • rejection_confidence (float, optional) – Minimum class probability for a confident prediction (0 to 1). Samples below this threshold are rejected (labeled -1) when using predict_with_rejection(). Default is None (no rejection).

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6: sync every k steps) and ‘slow_step_size’ (float, default 0.5: interpolation factor). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
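Rejection with rejection_confidence can be sketched as a threshold on the largest class probability, returning -1 for rejected samples. The sketch assumes class indices 0..n_classes-1 serve as labels and treats a probability equal to the threshold as confident; reject_low_confidence is a hypothetical helper, not the library's predict_with_rejection().

```python
def reject_low_confidence(proba, threshold):
    """Argmax class index per row, or -1 if the top probability is below threshold."""
    labels = []
    for row in proba:
        k = max(range(len(row)), key=row.__getitem__)
        labels.append(k if row[k] >= threshold else -1)
    return labels
```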

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict(X)

Predict class labels via winner-takes-all competition.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

predict_proba(X)

Predict class probabilities via softmin of distances.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

proba

Return type:

array of shape (n_samples, n_classes)

class prosemble.models.MRSLVQ(sigma=1.0, latent_dim=None, rejection_confidence=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)

Matrix Robust Soft Learning Vector Quantization.

Combines the RSLVQ probabilistic loss with a learned global linear mapping \(\Omega\) (d x latent_dim) for metric adaptation:

\[d(x, w) = (x - w)^T \Omega \Omega^T (x - w) = \lVert \Omega^T (x - w) \rVert^2\]

The relevance matrix \(\Lambda = \Omega \Omega^T\) captures feature correlations in the probabilistic framework.
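The Omega-projected distance can be sketched by projecting the difference vector and taking its squared norm. The sketch assumes Omega is stored as d rows of length latent_dim, matching the (d x latent_dim) shape given above, so the projection is (x - w) @ Omega; omega_distance is a hypothetical helper.

```python
def omega_distance(x, w, omega):
    """Squared norm of the projected difference (x - w) @ omega.

    omega is given as d rows of length latent_dim, i.e. shape (d, latent_dim).
    """
    diff = [xi - wi for xi, wi in zip(x, w)]
    latent_dim = len(omega[0])
    proj = [sum(diff[i] * omega[i][j] for i in range(len(diff))) for j in range(latent_dim)]
    return sum(p * p for p in proj)
```

With the identity matrix the distance reduces to squared Euclidean; with latent_dim smaller than d, Omega also acts as a learned dimensionality reduction.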

Parameters:
  • sigma (float) – Bandwidth of Gaussian mixture.

  • latent_dim (int, optional) – Dimensionality of the latent space. If None, uses input dim.

  • rejection_confidence (float, optional) – Minimum class probability for a confident prediction (0 to 1). Samples below this threshold are rejected (labeled -1) when using predict_with_rejection(). Default is None (no rejection).

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6): sync every k steps; ‘slow_step_size’ (float, default 0.5): interpolation factor. Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32, while the forward/backward pass runs in lower precision, which can give up to ~2x speedup and roughly halve memory use on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
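The interplay of epsilon, patience, and restore_best can be illustrated with a toy loop over a precomputed loss history. This is a hypothetical sketch, not prosemble’s training loop: the helper name simulate_early_stopping is invented, and counting an epoch as an improvement only when its loss is more than epsilon below the best loss so far is an assumption. It models the use_scan=False path, where stopping actually ends training.

```python
def simulate_early_stopping(losses, epsilon=1e-6, patience=None, restore_best=False):
    """Return (stop_epoch, selected_epoch) for a sequence of per-epoch losses.

    Hypothetical helper: an epoch counts as an improvement only if its loss is
    more than epsilon below the best loss so far (assumed criterion).
    """
    # patience=None behaves like patience=1: stop after one non-improving epoch.
    limit = 1 if patience is None else patience
    best_loss = float("inf")
    best_epoch = 0
    bad_epochs = 0
    stop_epoch = len(losses) - 1  # ran to the end unless we break early
    for epoch, loss in enumerate(losses):
        if loss < best_loss - epsilon:
            best_loss, best_epoch, bad_epochs = loss, epoch, 0
        else:
            bad_epochs += 1
            if bad_epochs >= limit:
                stop_epoch = epoch
                break
    # restore_best keeps the best epoch's parameters; otherwise the last ones.
    selected_epoch = best_epoch if restore_best else stop_epoch
    return stop_epoch, selected_epoch


losses = [1.0, 0.8, 0.85, 0.7, 0.9, 0.95, 0.6]
print(simulate_early_stopping(losses))                                  # (2, 2)
print(simulate_early_stopping(losses, patience=2, restore_best=True))  # (5, 3)
```

With the default patience=None, the first uptick (epoch 2) ends training; patience=2 tolerates it, and restore_best=True then returns the epoch-3 parameters instead of the final ones.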

predict(X)

Predict class labels using the learned \(\Omega\) distance.

predict_proba(X)

Predict class probabilities using \(\Omega\)-projected distances.

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores the parameters with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

class prosemble.models.LMRSLVQ(sigma=1.0, latent_dim=None, rejection_confidence=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)

Localized Matrix Robust Soft Learning Vector Quantization.

Each prototype \(k\) has its own \(\Omega_k\) matrix. The distance from sample \(x\) to prototype \(w_k\) is:

\[d(x, w_k) = (x - w_k)^T \Omega_k^T \Omega_k (x - w_k)\]

This distance is combined with the RSLVQ probabilistic loss, yielding metric-adaptive soft classification with local relevance learning.
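As a rough sketch of the two ingredients, the NumPy code below computes the per-prototype distances from the formula above and turns them into RSLVQ-style class probabilities. The Gaussian kernel exp(-d / (2 sigma^2)) with equal prototype priors is an assumed form of the mixture, and the helper names are hypothetical, not part of prosemble.

```python
import numpy as np

def local_omega_distances(X, W, Omegas):
    """d(x, w_k) = (x - w_k)^T Omega_k^T Omega_k (x - w_k) for all samples and prototypes.

    X: (n_samples, n_features); W: (n_protos, n_features);
    Omegas: (n_protos, latent_dim, n_features), one matrix per prototype.
    """
    diff = X[:, None, :] - W[None, :, :]              # (n, k, d)
    proj = np.einsum("kld,nkd->nkl", Omegas, diff)     # Omega_k (x - w_k)
    return np.sum(proj ** 2, axis=-1)                  # (n, k)

def rslvq_class_probs(D, proto_labels, n_classes, sigma=1.0):
    """P(class | x) from a Gaussian mixture over prototypes (equal priors assumed)."""
    logits = -D / (2.0 * sigma ** 2)
    logits -= logits.max(axis=1, keepdims=True)        # numerical stability
    p = np.exp(logits)
    p /= p.sum(axis=1, keepdims=True)                  # P(prototype k | x)
    return p @ np.eye(n_classes)[proto_labels]         # sum per class

X = np.array([[0.0, 0.0], [2.0, 2.0]])
W = np.array([[0.0, 0.0], [2.0, 2.0]])
Omegas = np.stack([np.eye(2), 0.5 * np.eye(2)])       # one Omega_k per prototype
D = local_omega_distances(X, W, Omegas)               # [[0., 2.], [8., 0.]]
P = rslvq_class_probs(D, np.array([0, 1]), n_classes=2, sigma=1.0)
```

The RSLVQ loss for a sample with label y would then be -log of the corresponding entry of P.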

Parameters:
  • sigma (float) – Bandwidth of Gaussian mixture.

  • latent_dim (int, optional) – Latent space dimensionality per prototype. If None, uses input dim.

  • rejection_confidence (float, optional) – Minimum class probability for a confident prediction (0 to 1). Samples below this threshold are rejected (labeled -1) when using predict_with_rejection(). Default is None (no rejection).

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6): sync every k steps; ‘slow_step_size’ (float, default 0.5): interpolation factor. Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32, while the forward/backward pass runs in lower precision, which can give up to ~2x speedup and roughly halve memory use on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
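‘balanced’ class weighting can be sketched as below. The scikit-learn convention n_samples / (n_classes * count_c) is used here as an assumption about the normalization; only the inverse proportionality is documented above. The helper names are hypothetical.

```python
import numpy as np

def balanced_class_weights(y):
    """Map each class label to n_samples / (n_classes * count_c) (assumed convention)."""
    classes, counts = np.unique(y, return_counts=True)
    weights = len(y) / (len(classes) * counts)
    return {int(c): float(w) for c, w in zip(classes, weights)}

def per_sample_weights(y, class_weight):
    """Expand a {label: weight} dict into one weight per training sample."""
    return np.array([class_weight[int(label)] for label in y], dtype=float)

y = np.array([0, 0, 0, 1])
cw = balanced_class_weights(y)   # rare class 1 gets 2.0, class 0 gets 2/3
sw = per_sample_weights(y, cw)
```

With this convention the per-sample weights average to 1, so the overall loss scale stays comparable to unweighted training.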

predict(X)

Predict class labels using local \(\Omega\) distances.

predict_proba(X)

Predict class probabilities using local \(\Omega\)-projected distances.

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores the parameters with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

class prosemble.models.RSLVQ_NG(sigma=1.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, rejection_confidence=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)

Robust Soft LVQ with Neural Gas Cooperation.

Combines:

  • RSLVQ probabilistic loss: \(-\log(P(\text{correct}|x))\)

  • Neural Gas cooperation: all prototypes weighted by rank via \(\exp(-\text{rank} / \gamma)\)

  • Euclidean distance

The NG neighborhood modulates RSLVQ’s Gaussian probabilities, emphasizing nearby prototypes. \(\gamma\) decays during training from \(\gamma_{\text{init}}\) to \(\gamma_{\text{final}}\).
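The rank-based weighting and a decay schedule consistent with the description can be sketched as follows. The geometric schedule gamma_init * decay**t with decay = (gamma_final / gamma_init)**(1 / max_iter) is one way to make gamma reach gamma_final; whether prosemble uses exactly this form, and how it combines the weights with the Gaussian probabilities, are assumptions not covered by this sketch. Helper names are hypothetical.

```python
import numpy as np

def ng_rank_weights(D, gamma):
    """Neural Gas weights exp(-rank / gamma) for a (n_samples, n_protos) distance matrix.

    Rank 0 marks the closest prototype of each sample.
    """
    ranks = np.argsort(np.argsort(D, axis=1), axis=1)
    return np.exp(-ranks / gamma)

def gamma_schedule(gamma_init, gamma_final, max_iter):
    """Geometric decay with gamma_init * decay**max_iter == gamma_final (assumed form)."""
    decay = (gamma_final / gamma_init) ** (1.0 / max_iter)
    return gamma_init * decay ** np.arange(max_iter + 1)

D = np.array([[3.0, 1.0, 2.0]])
weights = ng_rank_weights(D, gamma=1.0)            # ranks [2, 0, 1]
gammas = gamma_schedule(2.0, 0.01, max_iter=100)   # gammas[0] == 2.0, gammas[-1] ≈ 0.01
```

A large gamma spreads weight over many ranked prototypes; as gamma shrinks toward gamma_final, the weights concentrate on the rank-0 prototype.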

Parameters:
  • sigma (float) – Bandwidth for RSLVQ Gaussian mixture probability computation.

  • gamma_init (float, optional) – Initial neighborhood range for NG cooperation. Default: half the maximum number of prototypes per class.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma. Default: computed from max_iter so that gamma reaches gamma_final by the end of training.

  • rejection_confidence (float, optional) – Minimum class probability for a confident prediction (0 to 1). Samples below this threshold are rejected (labeled -1).

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6): sync every k steps; ‘slow_step_size’ (float, default 0.5): interpolation factor. Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32, while the forward/backward pass runs in lower precision, which can give up to ~2x speedup and roughly halve memory use on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
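The EMA behind ema_decay is usually written ema ← decay · ema + (1 − decay) · params. A minimal sketch over a dict of NumPy arrays; the parameter-dict layout, the key name "prototypes", and the helper name are assumptions for illustration.

```python
import numpy as np

def ema_update(ema_params, params, decay):
    """One EMA step per parameter: ema <- decay * ema + (1 - decay) * value."""
    return {name: decay * ema_params[name] + (1.0 - decay) * value
            for name, value in params.items()}

ema = {"prototypes": np.zeros(2)}
current = {"prototypes": np.ones(2)}
for _ in range(2):
    ema = ema_update(ema, current, decay=0.9)   # 0.1, then 0.19

# Weight still carried by the starting values after 100 steps at decay 0.999.
init_weight = 0.999 ** 100                       # ≈ 0.905
```

Under this formulation, if the average starts from the initial parameters, the starting values still carry about 90% of the weight after the default max_iter=100 steps at ema_decay=0.999, so such high decay values mainly pay off for long runs or many mini-batch steps. Whether prosemble updates per epoch or per mini-batch, or applies bias correction, is not stated here.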

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores the parameters with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict(X)

Predict class labels via Winner-Takes-All competition.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)
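Winner-takes-all prediction assigns each sample the label of its nearest prototype. A minimal NumPy sketch, assuming squared Euclidean distance; the names wta_predict, prototypes, and proto_labels are hypothetical, not prosemble attributes.

```python
import numpy as np

def wta_predict(X, prototypes, proto_labels):
    """Label each sample with the label of its nearest prototype (squared Euclidean)."""
    D = ((X[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)  # (n, k)
    return proto_labels[np.argmin(D, axis=1)]

prototypes = np.array([[0.0, 0.0], [5.0, 5.0]])
proto_labels = np.array([0, 1])
X = np.array([[1.0, 0.0], [4.0, 6.0]])
labels = wta_predict(X, prototypes, proto_labels)   # [0, 1]
```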

predict_proba(X)

Predict class probabilities via softmin of distances.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

proba

Return type:

array of shape (n_samples, n_classes)
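A softmin turns distances into probabilities by applying a softmax to the negated distances. The sketch below assumes that, when a class has several prototypes, its probability is the sum over those prototypes; prosemble may aggregate differently, and the helper name is hypothetical.

```python
import numpy as np

def softmin_proba(D, proto_labels, n_classes):
    """Class probabilities from a softmin over (n_samples, n_protos) distances."""
    logits = -D
    logits = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    p = np.exp(logits)
    p = p / p.sum(axis=1, keepdims=True)                  # per-prototype probabilities
    return p @ np.eye(n_classes)[proto_labels]            # summed per class

D = np.array([[0.0, np.log(3.0)]])
proba = softmin_proba(D, np.array([0, 1]), n_classes=2)   # [[0.75, 0.25]]
```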

class prosemble.models.MRSLVQ_NG(sigma=1.0, latent_dim=None, gamma_init=None, gamma_final=0.01, gamma_decay=None, rejection_confidence=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)

Matrix Robust Soft LVQ with Neural Gas Cooperation.

Combines:

  • RSLVQ probabilistic loss: \(-\log(P(\text{correct}|x))\)

  • Neural Gas cooperation: all prototypes weighted by rank via \(\exp(-\text{rank} / \gamma)\)

  • Global \(\Omega\) matrix for metric adaptation:

    \[d(x, w) = (x - w)^T \Omega^T \Omega (x - w)\]
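Writing the distance as \(\|\Omega (x - w)\|^2\) shows why it is never negative and what latent_dim controls: \(\Omega\) maps the difference vector into a latent_dim-dimensional space. The NumPy check below (illustrative, with arbitrary random values) confirms that the quadratic form and the projected squared norm agree, and that \(\Lambda = \Omega^T \Omega\), commonly called the relevance matrix, is positive semi-definite with rank at most latent_dim.

```python
import numpy as np

rng = np.random.default_rng(0)
n_features, latent_dim = 4, 2
Omega = rng.normal(size=(latent_dim, n_features))   # latent_dim x n_features
x = rng.normal(size=n_features)
w = rng.normal(size=n_features)

diff = x - w
d_quadratic = diff @ Omega.T @ Omega @ diff          # (x - w)^T Omega^T Omega (x - w)
d_projected = np.sum((Omega @ diff) ** 2)            # ||Omega (x - w)||^2

Lambda = Omega.T @ Omega                             # relevance matrix
eigvals = np.linalg.eigvalsh(Lambda)                 # all >= 0 up to round-off
```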

Parameters:
  • sigma (float) – Bandwidth for RSLVQ Gaussian mixture probability computation.

  • latent_dim (int, optional) – Dimensionality of the \(\Omega\) projection space. If None, uses input dim.

  • gamma_init (float, optional) – Initial neighborhood range for NG cooperation. Default: half the maximum number of prototypes per class.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma. Default: computed from max_iter so that gamma reaches gamma_final by the end of training.

  • rejection_confidence (float, optional) – Minimum class probability for a confident prediction (0 to 1). Samples below this threshold are rejected (labeled -1).

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6): sync every k steps; ‘slow_step_size’ (float, default 0.5): interpolation factor. Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32, while the forward/backward pass runs in lower precision, which can give up to ~2x speedup and roughly halve memory use on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
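To illustrate what sync_period and slow_step_size control, the sketch below wraps plain gradient descent in the lookahead scheme: the fast weights take sync_period inner steps, then the slow weights move a fraction slow_step_size toward them and the fast weights restart from the slow ones. It is a NumPy illustration with a hypothetical helper name, not prosemble’s implementation.

```python
import numpy as np

def lookahead_gd(grad_fn, w0, lr=0.1, n_steps=30, sync_period=6, slow_step_size=0.5):
    """Gradient descent wrapped in lookahead (illustration only)."""
    slow = np.asarray(w0, dtype=float)
    fast = slow.copy()
    for step in range(1, n_steps + 1):
        fast = fast - lr * grad_fn(fast)                  # inner (fast) step
        if step % sync_period == 0:
            slow = slow + slow_step_size * (fast - slow)  # interpolate toward fast
            fast = slow.copy()                            # restart fast from slow
    return slow, fast

# Quadratic f(w) = ||w||^2 / 2, whose gradient is w itself.
slow, fast = lookahead_gd(lambda w: w, [1.0], lr=0.1, n_steps=6)
```

On this quadratic with lr=0.1, each 6-step cycle shrinks the slow weight by the factor 0.5 * (1 + 0.9**6) ≈ 0.766.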

predict(X)

Predict class labels using the learned \(\Omega\) distance.

predict_proba(X)

Predict class probabilities using \(\Omega\)-projected distances.

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores the parameters with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

class prosemble.models.LMRSLVQ_NG(sigma=1.0, latent_dim=None, gamma_init=None, gamma_final=0.01, gamma_decay=None, rejection_confidence=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)

Localized Matrix Robust Soft LVQ with Neural Gas Cooperation.

Each prototype \(k\) has its own \(\Omega_k\) matrix. These local distances are combined with the RSLVQ probabilistic loss and NG rank-based neighborhood cooperation.

Parameters:
  • sigma (float) – Bandwidth for RSLVQ Gaussian mixture probability computation.

  • latent_dim (int, optional) – Latent space dimensionality per prototype. If None, uses input dim.

  • gamma_init (float, optional) – Initial neighborhood range for NG cooperation. Default: half the maximum number of prototypes per class.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma. Default: computed from max_iter so that gamma reaches gamma_final by the end of training.

  • rejection_confidence (float, optional) – Minimum class probability for a confident prediction (0 to 1). Samples below this threshold are rejected (labeled -1).

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6): sync every k steps; ‘slow_step_size’ (float, default 0.5): interpolation factor. Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32, while the forward/backward pass runs in lower precision, which can give up to ~2x speedup and roughly halve memory use on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
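The rejection rule described for rejection_confidence amounts to thresholding the largest class probability. A standalone sketch: reject_low_confidence is a hypothetical helper, not the model’s own method, and treating a probability exactly equal to the threshold as confident is an assumption.

```python
import numpy as np

def reject_low_confidence(proba, classes, rejection_confidence):
    """Return the argmax label per row, or -1 where the top probability is below the threshold."""
    proba = np.asarray(proba)
    labels = np.asarray(classes)[np.argmax(proba, axis=1)]
    confident = proba.max(axis=1) >= rejection_confidence
    return np.where(confident, labels, -1)

proba = np.array([[0.9, 0.1], [0.55, 0.45]])
decisions = reject_low_confidence(proba, classes=[0, 1], rejection_confidence=0.6)  # [0, -1]
```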

predict(X)

Predict class labels using local \(\Omega_k\) distances.

predict_proba(X)

Predict class probabilities using local \(\Omega_k\)-projected distances.

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores the parameters with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

class prosemble.models.GLVQ1(n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, gradient_checkpointing=False, devices=None)

GLVQ with LVQ1-style loss (gradient-based).

Loss: \(d^+\) when the winning (nearest) prototype has the correct label, and \(-d^-\) otherwise.
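The LVQ1-style loss can be sketched from the closest correct distance \(d^+\) and the closest incorrect distance \(d^-\). Minimizing \(d^+\) pulls the correct prototype toward the sample, while minimizing \(-d^-\) pushes a wrongly winning prototype away, mirroring the attract/repel rule of classic LVQ1. The helper names and the tie-breaking rule (\(d^+ \le d^-\) counts as correct) are assumptions for illustration.

```python
import numpy as np

def dp_dm(D, y, proto_labels):
    """Closest correct (d+) and closest incorrect (d-) distance per sample."""
    same = proto_labels[None, :] == y[:, None]
    d_plus = np.where(same, D, np.inf).min(axis=1)
    d_minus = np.where(~same, D, np.inf).min(axis=1)
    return d_plus, d_minus

def lvq1_loss(D, y, proto_labels):
    """d+ if the nearest prototype is correct, otherwise -d-."""
    d_plus, d_minus = dp_dm(D, y, proto_labels)
    return np.where(d_plus <= d_minus, d_plus, -d_minus)

D = np.array([[1.0, 4.0], [3.0, 2.0]])   # distances to prototypes labeled 0 and 1
y = np.array([0, 0])
proto_labels = np.array([0, 1])
loss = lvq1_loss(D, y, proto_labels)     # [1.0, -2.0]
```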

Parameters:
  • n_prototypes_per_class (int) – Prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float)

  • random_seed (int)

  • margin (float)

  • callbacks (list)

  • use_scan (bool)

  • batch_size (int | None)

  • lr_scheduler_kwargs (dict | None)

  • patience (int | None)

  • restore_best (bool)

  • gradient_accumulation_steps (int | None)

  • ema_decay (float | None)

  • freeze_params (list | None)

  • lookahead (dict | None)

  • mixed_precision (str | None)

  • gradient_checkpointing (bool)

  • devices (list | None)

See also

SupervisedPrototypeModel

Full list of base parameters (optimizer, distance_fn, lr_scheduler, callbacks, patience, etc.).

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores the parameters with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict(X)

Predict class labels via Winner-Takes-All competition.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

class prosemble.models.GLVQ21(n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, gradient_checkpointing=False, devices=None)

GLVQ with LVQ2.1-style loss (gradient-based, unnormalized).

Loss: \(d^+ - d^-\) (no normalization by \(d^+ + d^-\)).
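A minimal NumPy sketch contrasting the unnormalized LVQ2.1-style cost with the normalized GLVQ relative distance \(\mu\). Here \(d^+\) is the distance to the nearest prototype of the correct class and \(d^-\) the distance to the nearest prototype of any other class; the helper names are hypothetical, and how GLVQ21 applies margin, transfer function, and averaging is not shown.

```python
import numpy as np

def closest_distances(X, y, prototypes, proto_labels):
    """Return (d_plus, d_minus): squared distance to nearest correct / wrong prototype."""
    d = ((X[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)
    same = proto_labels[None, :] == y[:, None]
    d_plus = np.where(same, d, np.inf).min(axis=1)
    d_minus = np.where(same, np.inf, d).min(axis=1)
    return d_plus, d_minus

def lvq21_loss(d_plus, d_minus):
    # Unnormalized LVQ2.1-style cost: unbounded, scales with the data
    return d_plus - d_minus

def glvq_mu(d_plus, d_minus):
    # GLVQ relative distance difference, bounded in [-1, 1]
    return (d_plus - d_minus) / (d_plus + d_minus)

prototypes = np.array([[0.0, 0.0], [4.0, 0.0]])
proto_labels = np.array([0, 1])
X = np.array([[1.0, 0.0]])
y = np.array([0])
dp, dm = closest_distances(X, y, prototypes, proto_labels)
```

For this point, \(d^+ = 1\) and \(d^- = 9\), so the LVQ2.1-style cost is \(-8\) while \(\mu = -0.8\); rescaling the data would change the former but not the latter.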

Parameters:
  • n_prototypes_per_class (int) – Prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float)

  • random_seed (int)

  • margin (float)

  • callbacks (list)

  • use_scan (bool)

  • batch_size (int | None)

  • lr_scheduler_kwargs (dict | None)

  • patience (int | None)

  • restore_best (bool)

  • gradient_accumulation_steps (int | None)

  • ema_decay (float | None)

  • freeze_params (list | None)

  • lookahead (dict | None)

  • mixed_precision (str | None)

  • gradient_checkpointing (bool)

  • devices (list | None)

See also

SupervisedPrototypeModel

Full list of base parameters (optimizer, distance_fn, lr_scheduler, callbacks, patience, etc.).

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from n_prototypes_per_class times the number of classes.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict(X)

Predict class labels via winner-takes-all competition.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

Deep and Siamese Variants

class prosemble.models.LVQMLN(hidden_sizes=None, latent_dim=2, activation='sigmoid', beta=10.0, bb_lr=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

LVQ Multi-Layer Network.

An MLP backbone maps inputs into a latent space. Prototypes reside directly in that latent space. The GLVQ loss trains both the backbone and the prototypes jointly via gradient descent.

Architecture:

Input (d) -> MLP -> Latent (latent_dim)
                      |
                      v
            distance(latent_x, prototypes)
                      |
                      v
                  GLVQ loss
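The architecture above can be sketched as a NumPy forward pass. This covers only the forward computation; in LVQMLN the backbone and prototypes are trained jointly by gradient descent, which this sketch omits. Assumptions: sigmoid activation after every layer (including the output), identity transfer function, and hypothetical helper names.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def init_mlp(sizes, rng):
    """Random weights for layer sizes such as [n_features, 10, latent_dim]."""
    return [(rng.normal(scale=0.5, size=(m, n)), np.zeros(n))
            for m, n in zip(sizes[:-1], sizes[1:])]

def mlp_forward(params, X):
    # Assumption: the activation is applied after every layer, including the last
    h = X
    for W, b in params:
        h = sigmoid(h @ W + b)
    return h

def lvqmln_loss(params, prototypes, proto_labels, X, y):
    z = mlp_forward(params, X)                                   # (n_samples, latent_dim)
    d = ((z[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)
    same = proto_labels[None, :] == y[:, None]
    d_plus = np.where(same, d, np.inf).min(axis=1)               # nearest correct prototype
    d_minus = np.where(same, np.inf, d).min(axis=1)              # nearest wrong prototype
    mu = (d_plus - d_minus) / (d_plus + d_minus)
    return mu.mean()                                             # identity transfer function

rng = np.random.default_rng(0)
X = rng.normal(size=(6, 4))
y = np.array([0, 0, 0, 1, 1, 1])
params = init_mlp([4, 10, 2], rng)        # hidden_sizes=[10], latent_dim=2
prototypes = rng.uniform(size=(2, 2))     # prototypes live in the 2-D latent space
proto_labels = np.array([0, 1])
loss = lvqmln_loss(params, prototypes, proto_labels, X, y)
```

Note that the prototypes have shape (n_prototypes, latent_dim), not (n_prototypes, n_features): they are never mapped through the backbone.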
Parameters:
  • hidden_sizes (list of int) – Sizes of hidden layers. e.g. [10] for one hidden layer of 10 units.

  • latent_dim (int) – Dimension of the latent/embedding space where prototypes live.

  • activation (str) – Activation function: ‘sigmoid’, ‘relu’, ‘tanh’, ‘leaky_relu’, ‘selu’.

  • beta (float) – Transfer function parameter for GLVQ loss.

  • bb_lr (float, optional) – Separate learning rate for the backbone network. If None, uses the same lr as prototypes. Default: None.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; synchronize every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32 while the forward and backward passes run in the lower precision, which can give up to roughly 2x speedup and about half the memory use on GPUs with fast low-precision arithmetic. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
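The ema_decay option keeps an exponential moving average of the parameters. A minimal sketch of the update rule, assuming the EMA starts from the initial parameters and uses no bias correction; the parameter-group key "prototypes" and the helper ema_update are hypothetical.

```python
import numpy as np

def ema_update(ema_params, params, decay):
    """One EMA step per parameter group: ema <- decay * ema + (1 - decay) * params."""
    return {k: decay * ema_params[k] + (1.0 - decay) * params[k] for k in params}

params = {"prototypes": np.zeros(2)}               # hypothetical parameter group
ema = {k: v.copy() for k, v in params.items()}     # assumption: EMA starts at init
for step in range(3):
    # Stand-in for an optimizer step that moves the parameters by +1
    params = {"prototypes": params["prototypes"] + 1.0}
    ema = ema_update(ema, params, decay=0.5)
```

A small decay (0.5) is used here so the effect is visible after three steps; with the typical 0.999 the average reacts over roughly a thousand steps.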

predict(X)[source]

Predict class labels.

Transforms X through the backbone, then assigns the label of the nearest prototype in latent space.

predict_proba(X)[source]

Predict class probabilities.

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from n_prototypes_per_class times the number of classes.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

class prosemble.models.PLVQ(hidden_sizes=None, latent_dim=2, activation='sigmoid', sigma=1.0, loss_type='rslvq', n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

Probabilistic LVQ with learned nonlinear transformation.

Combines an MLP backbone (learned metric) with probabilistic soft assignment via Gaussian mixtures. The loss is the negative log-likelihood of the correct class:

\[p(k|x) = \frac{\exp\left(-d(f(x), w_k)^2 / (2\sigma^2)\right)}{Z}\]
\[P(\text{class}|x) = \sum_{k \in \text{class}} p(k|x)\]
\[\text{loss} = -\log\frac{P(\text{correct}|x)}{P(\text{all}|x)}\]

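The three formulas above can be sketched in NumPy for points z = f(x) that have already been mapped by the backbone. Assumptions: \(d(\cdot,\cdot)^2\) is taken as the squared Euclidean distance, \(Z\) normalizes over all prototypes, and the helper names are hypothetical; this is the 'rslvq' form only.

```python
import numpy as np

def responsibilities(z, prototypes, sigma=1.0):
    """p(k|x) for latent points z = f(x), normalized over all prototypes (Z)."""
    # d2 plays the role of d(f(x), w_k)^2 (assumed squared Euclidean)
    d2 = ((z[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)
    logits = -d2 / (2.0 * sigma ** 2)
    logits = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    p = np.exp(logits)
    return p / p.sum(axis=1, keepdims=True)

def class_proba(p, proto_labels, n_classes):
    # P(class|x): sum of p(k|x) over the prototypes k belonging to that class
    return np.stack([p[:, proto_labels == c].sum(axis=1) for c in range(n_classes)], axis=1)

def rslvq_loss(p, y, proto_labels):
    correct = proto_labels[None, :] == y[:, None]
    p_correct = (p * correct).sum(axis=1)   # P(correct|x)
    p_all = p.sum(axis=1)                   # P(all|x); equals 1 here because p is normalized
    return float(-np.log(p_correct / p_all).mean())

z = np.array([[0.0, 0.0]])
prototypes = np.array([[0.0, 0.0], [1.0, 0.0]])
proto_labels = np.array([0, 1])
p = responsibilities(z, prototypes, sigma=1.0)
loss = rslvq_loss(p, np.array([0]), proto_labels)
```

Here the correct prototype sits on the point and the wrong one is at squared distance 1, so the loss is \(\log(1 + e^{-1/2}) \approx 0.474\).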
Parameters:
  • hidden_sizes (list of int) – Hidden layer sizes.

  • latent_dim (int) – Latent space dimension.

  • activation (str) – Activation: ‘sigmoid’, ‘relu’, ‘tanh’, ‘leaky_relu’, ‘selu’.

  • sigma (float) – Bandwidth of the Gaussian mixture.

  • loss_type (str) – ‘rslvq’ (Robust Soft LVQ loss, default) or ‘nllr’ (negative log-likelihood ratio).

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; synchronize every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32 while the forward and backward passes run in the lower precision, which can give up to roughly 2x speedup and about half the memory use on GPUs with fast low-precision arithmetic. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
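The lookahead option wraps an inner optimizer with slow and fast weights. The sketch below shows the general lookahead rule around plain gradient descent on a 1-D quadratic; it is not prosemble's implementation (which wraps an optax optimizer), and lookahead_gd is a hypothetical helper.

```python
import numpy as np

def lookahead_gd(grad_fn, w0, lr=0.1, sync_period=6, slow_step_size=0.5, n_steps=60):
    """Gradient descent wrapped in lookahead (hypothetical helper)."""
    slow = np.asarray(w0, dtype=float)
    fast = slow.copy()
    for t in range(1, n_steps + 1):
        fast = fast - lr * grad_fn(fast)       # inner ("fast") optimizer step
        if t % sync_period == 0:
            # Move slow weights part of the way toward the fast weights,
            # then restart the fast weights from the new slow weights
            slow = slow + slow_step_size * (fast - slow)
            fast = slow.copy()
    return slow

# Minimize (w - 3)^2, whose gradient is 2 * (w - 3)
w = lookahead_gd(lambda w: 2.0 * (w - 3.0), w0=[0.0])
```

With these defaults (sync_period=6, slow_step_size=0.5), each sync closes about 37% of the remaining gap to the optimum, so ten syncs bring w to about 2.97.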

predict(X)[source]

Predict class labels as the most probable class.

predict_proba(X)[source]

Predict class probabilities via Gaussian mixture.

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from n_prototypes_per_class times the number of classes.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

class prosemble.models.SiameseGLVQ(hidden_sizes=None, latent_dim=2, activation='sigmoid', beta=10.0, bb_lr=None, both_path_gradients=True, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

Siamese GLVQ — GLVQ with a learned embedding network.

Both inputs and prototypes are transformed through the same MLP backbone before computing squared Euclidean distances.
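A minimal NumPy sketch of the Siamese distance: unlike LVQMLN, prototypes stay in input space and are embedded with the same backbone weights as the inputs. The single sigmoid layer standing in for the backbone and the helper names are assumptions. In a JAX implementation, both_path_gradients=False could plausibly correspond to wrapping the prototype embedding in jax.lax.stop_gradient, which this NumPy sketch does not show.

```python
import numpy as np

def backbone(X, W, b):
    """Stand-in shared backbone: a single sigmoid layer (assumption)."""
    return 1.0 / (1.0 + np.exp(-(X @ W + b)))

def siamese_distances(X, prototypes, W, b):
    fx = backbone(X, W, b)            # embed the inputs
    fw = backbone(prototypes, W, b)   # embed the prototypes with the SAME weights
    return ((fx[:, None, :] - fw[None, :, :]) ** 2).sum(axis=-1)

rng = np.random.default_rng(1)
W, b = rng.normal(size=(4, 2)), np.zeros(2)   # 4 input features -> latent_dim 2
prototypes = rng.normal(size=(3, 4))          # prototypes stay in the 4-D input space
X = rng.normal(size=(5, 4))
D = siamese_distances(X, prototypes, W, b)
```

Because prototypes remain in input space, they can be inspected directly as input-like vectors after training.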

Parameters:
  • hidden_sizes (list of int) – Hidden layer sizes for the backbone MLP.

  • latent_dim (int) – Dimension of the embedding space.

  • activation (str) – Activation function for the backbone MLP. Supported values: ‘sigmoid’, ‘relu’, ‘tanh’, ‘leaky_relu’, ‘selu’.

  • beta (float) – Transfer function parameter for GLVQ loss.

  • bb_lr (float, optional) – Separate learning rate for the backbone network. If None, uses the same lr as prototypes. Default: None.

  • both_path_gradients (bool) – If True, compute gradients through both input and prototype paths. If False, prototype path gradients are stopped. Default: True.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; synchronize every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32 while the forward and backward passes run in the lower precision, which can give up to roughly 2x speedup and about half the memory use on GPUs with fast low-precision arithmetic. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
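The class_weight='balanced' option above is described only as "inversely proportional to class frequencies". The sketch below assumes the scikit-learn convention n_samples / (n_classes * count); prosemble's exact normalization may differ, and balanced_class_weight is a hypothetical helper.

```python
import numpy as np

def balanced_class_weight(y):
    """Assumed scikit-learn-style balanced weights: n_samples / (n_classes * count)."""
    classes, counts = np.unique(y, return_counts=True)
    weights = len(y) / (len(classes) * counts)
    return dict(zip(classes.tolist(), weights.tolist()))

y = np.array([0, 0, 0, 1])
cw = balanced_class_weight(y)
# Expanding to per-sample weights, e.g. for a weighted mean of the per-sample loss
sample_w = np.array([cw[label] for label in y.tolist()])
```

Under this convention the minority class gets weight 2.0 and the majority class 2/3, and the per-sample weights sum to n_samples, so the overall loss scale is unchanged.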

predict(X)[source]

Predict class labels via winner-takes-all competition.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

predict_proba(X)[source]

Predict class probabilities via softmin of distances.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

proba

Return type:

array of shape (n_samples, n_classes)
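A sketch of turning distances into class probabilities with a softmin. How several prototypes of the same class are combined is not specified above; this sketch assumes each class is represented by its closest prototype, and softmin_proba is a hypothetical helper.

```python
import numpy as np

def softmin_proba(d, proto_labels, n_classes):
    """Softmin over per-class minimum distances (assumed aggregation)."""
    # Closest prototype distance per class, shape (n_samples, n_classes)
    d_class = np.stack([d[:, proto_labels == c].min(axis=1) for c in range(n_classes)], axis=1)
    logits = -d_class
    logits = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    e = np.exp(logits)
    return e / e.sum(axis=1, keepdims=True)

d = np.array([[1.0, 1.0 + np.log(3.0), 4.0]])   # distances to three prototypes
proto_labels = np.array([0, 1, 1])
proba = softmin_proba(d, proto_labels, n_classes=2)
```

Class 0 is closer by \(\ln 3\), so it receives three times the probability of class 1 (0.75 vs 0.25); the argmax of the probabilities agrees with the winner-takes-all label.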

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from n_prototypes_per_class times the number of classes.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

class prosemble.models.SiameseGMLVQ(hidden_sizes=None, latent_dim=2, omega_dim=None, activation='sigmoid', beta=10.0, bb_lr=None, both_path_gradients=True, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

Siamese GMLVQ — GMLVQ with a learned embedding network.

Both inputs and prototypes are transformed through the same MLP, then distances are computed using a learned \(\Omega\) matrix in the latent space:

\[d = \|\Omega(f(x) - f(w))\|^2\]
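The distance above can be computed for all sample/prototype pairs with broadcasting. This NumPy sketch assumes the embeddings f(x) and f(w) are already computed and that Omega has shape (omega_dim, latent_dim); omega_distance is a hypothetical helper.

```python
import numpy as np

def omega_distance(fx, fw, omega):
    """d = ||Omega (f(x) - f(w))||^2 for every sample/prototype pair."""
    diff = fx[:, None, :] - fw[None, :, :]   # (n_samples, n_prototypes, latent_dim)
    proj = diff @ omega.T                     # (n_samples, n_prototypes, omega_dim)
    return (proj ** 2).sum(axis=-1)

fx = np.array([[1.0, 2.0]])                  # embedded inputs, latent_dim = 2
fw = np.array([[0.0, 0.0], [1.0, 1.0]])      # embedded prototypes
omega = np.array([[1.0, 1.0]])               # omega_dim = 1: a rank-1 projection
D = omega_distance(fx, fw, omega)
```

Setting omega_dim below latent_dim, as here, restricts the metric to a low-rank subspace; with \(\Omega = I\) the distance reduces to the squared Euclidean distance used by SiameseGLVQ.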
Parameters:
  • hidden_sizes (list of int) – Hidden layer sizes for the backbone MLP.

  • latent_dim (int) – Dimension of the backbone output (embedding space).

  • omega_dim (int, optional) – Omega mapping dimension (number of rows in Omega). If None, uses latent_dim (square matrix). Default: None.

  • activation (str) – Activation function for the backbone MLP. Supported values: ‘sigmoid’, ‘relu’, ‘tanh’, ‘leaky_relu’, ‘selu’.

  • beta (float) – Transfer function parameter for GLVQ loss.

  • bb_lr (float, optional) – Separate learning rate for the backbone network. If None, uses the same lr as prototypes. Default: None.

  • both_path_gradients (bool) – If True, compute gradients through both input and prototype paths. If False, prototype path gradients are stopped. Default: True.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; synchronize every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32 while the forward and backward passes run in the lower precision, which can give up to roughly 2x speedup and about half the memory use on GPUs with fast low-precision arithmetic. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
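Why float16 needs loss scaling can be shown with NumPy dtypes alone. Because gradients are linear in the loss, multiplying the loss by a constant scale S multiplies every gradient by S. The sketch applies the scale directly to a tiny gradient value to show underflow and recovery; the scale 2^16 is an assumed example value, not necessarily what prosemble uses.

```python
import numpy as np

LOSS_SCALE = np.float32(2.0 ** 16)   # assumed static scale factor

true_grad = 1e-8
# Cast directly to float16: below the smallest subnormal, so it underflows to 0
naive = np.float16(true_grad)
# Scaled first: about 6.6e-4, comfortably representable in float16
scaled = np.float16(true_grad * LOSS_SCALE)
# Unscale in float32 (master-weight precision) before the optimizer update
recovered = np.float32(scaled) / LOSS_SCALE
```

bfloat16 has the same exponent range as float32, which is why the description above limits loss scaling to float16.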

predict(X)[source]

Predict class labels via winner-takes-all competition.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from n_prototypes_per_class times the number of classes.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict_proba(X)

Predict class probabilities via softmin of distances.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

proba

Return type:

array of shape (n_samples, n_classes)

class prosemble.models.SiameseGTLVQ(hidden_sizes=None, latent_dim=4, subspace_dim=2, activation='sigmoid', beta=10.0, bb_lr=None, both_path_gradients=True, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

Siamese GTLVQ — GTLVQ with a learned embedding network.

Both inputs and prototypes are transformed through the same MLP, then tangent distances are computed in the latent space using per-prototype subspace bases.
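A sketch of a tangent distance in latent space, assuming the standard GTLVQ form \(d = \|(I - UU^\top)(f(x) - f(w))\|^2\), where U is the prototype's orthonormal basis of shape (latent_dim, subspace_dim). That is, the distance from f(x) to the affine subspace through f(w). The helper tangent_distance is hypothetical.

```python
import numpy as np

def tangent_distance(fx, fw, U):
    """Squared distance from fx to the affine subspace fw + span(U); U has orthonormal columns."""
    diff = fx - fw
    # Remove the component of diff that lies inside the tangent subspace
    residual = diff - U @ (U.T @ diff)
    return float(residual @ residual)

U = np.eye(3)[:, :2]                  # basis spanning the first two latent axes
d = tangent_distance(np.array([1.0, 2.0, 3.0]), np.zeros(3), U)

# A random orthonormal basis (latent_dim=3, subspace_dim=2) via QR decomposition
rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.normal(size=(3, 2)))
d_in_subspace = tangent_distance(Q @ np.array([0.3, -1.2]), np.zeros(3), Q)
```

Any displacement along the prototype's subspace costs nothing, so each prototype effectively represents a whole affine subspace of latent space (subspace_dim dimensions) rather than a single point.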

Parameters:
  • hidden_sizes (list of int) – Hidden layer sizes for the backbone MLP.

  • latent_dim (int) – Dimension of the backbone output (embedding space).

  • subspace_dim (int) – Tangent subspace dimension per prototype. Each prototype gets a learned orthonormal basis of this rank in latent space.

  • activation (str) – Activation function for the backbone MLP. Supported values: ‘sigmoid’, ‘relu’, ‘tanh’, ‘leaky_relu’, ‘selu’.

  • beta (float) – Transfer function parameter for GLVQ loss.

  • bb_lr (float, optional) – Separate learning rate for the backbone network. If None, uses the same lr as prototypes. Default: None.

  • both_path_gradients (bool) – If True, compute gradients through both input and prototype paths. If False, prototype path gradients are stopped. Default: True.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; synchronize the slow weights every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor between slow and fast weights). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32, while the forward and backward passes run in the lower precision, which can give up to about 2x speed and about half the memory use on GPUs with native support for that dtype. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
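The ‘balanced’ option of class_weight can be illustrated with a short NumPy sketch. It assumes the scikit-learn convention n_samples / (n_classes * count_c); the helper balanced_class_weights is hypothetical, and prosemble's exact normalization may differ.

```python
import numpy as np


def balanced_class_weights(y):
    """Weights inversely proportional to class frequency.

    Assumption: scikit-learn's convention n_samples / (n_classes * count_c);
    prosemble may normalize differently.
    """
    y = np.asarray(y)
    classes, counts = np.unique(y, return_counts=True)
    weights = len(y) / (len(classes) * counts)
    return {int(c): float(w) for c, w in zip(classes, weights)}


print(balanced_class_weights(np.array([0, 0, 0, 1])))
# {0: 0.6666666666666666, 1: 2.0}
```

The rare class receives the larger weight, so its samples contribute more to the loss.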

predict(X)[source]

Predict class labels via winner-takes-all competition.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)
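Winner-takes-all prediction assigns each sample the label of its closest prototype. A minimal NumPy sketch, assuming squared Euclidean distance (the documented default distance_fn); wta_predict is a hypothetical helper, not prosemble's implementation:

```python
import numpy as np


def wta_predict(X, prototypes, proto_labels):
    """Winner-takes-all: each sample gets the label of its nearest prototype."""
    X = np.asarray(X, dtype=float)
    P = np.asarray(prototypes, dtype=float)
    # Squared Euclidean distances, shape (n_samples, n_prototypes)
    d = ((X[:, None, :] - P[None, :, :]) ** 2).sum(axis=-1)
    return np.asarray(proto_labels)[np.argmin(d, axis=1)]


protos = np.array([[0.0, 0.0], [1.0, 1.0]])
labels = np.array([0, 1])
print(wta_predict([[0.1, 0.2], [0.9, 0.8]], protos, labels))  # [0 1]
```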

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self
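The GLVQ cost that fit minimizes, with per-sample weights, can be sketched as follows. It reuses the relative distance mu = (d_plus - d_minus) / (d_plus + d_minus) from the GLVQ section and assumes transfer_fn is a sigmoid with steepness beta (the documented default is identity) and that sample_weight enters as a weighted mean normalized by the sum of the weights. Both choices, and the helper glvq_loss, are illustrative assumptions.

```python
import numpy as np


def glvq_loss(X, y, prototypes, proto_labels, beta=10.0, sample_weight=None):
    """Weighted GLVQ cost with an assumed sigmoid transfer (illustrative only)."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y)
    P = np.asarray(prototypes, dtype=float)
    pl = np.asarray(proto_labels)
    d = ((X[:, None, :] - P[None, :, :]) ** 2).sum(axis=-1)
    same = pl[None, :] == y[:, None]
    d_plus = np.where(same, d, np.inf).min(axis=1)   # closest correct prototype
    d_minus = np.where(same, np.inf, d).min(axis=1)  # closest wrong prototype
    mu = (d_plus - d_minus) / (d_plus + d_minus)      # in [-1, 1]; negative = correct
    cost = 1.0 / (1.0 + np.exp(-beta * mu))           # sigmoid transfer
    w = np.ones(len(X)) if sample_weight is None else np.asarray(sample_weight, dtype=float)
    return float((w * cost).sum() / w.sum())


protos = [[0.0, 0.0], [1.0, 1.0]]
print(glvq_loss([[0.0, 0.0], [0.0, 0.0]], [0, 1], protos, [0, 1], sample_weight=[1.0, 0.0]))
```

Setting a sample's weight to zero removes it from the average entirely.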

predict_proba(X)

Predict class probabilities via softmin of distances.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

proba

Return type:

array of shape (n_samples, n_classes)
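Softmin turns distances into probabilities, giving closer prototypes larger weights. This sketch assumes squared Euclidean distance and pools prototype probabilities by summing within each class; prosemble may aggregate per class differently (for example, using only the closest prototype of each class), so treat softmin_proba as a hypothetical illustration.

```python
import numpy as np


def softmin_proba(X, prototypes, proto_labels, n_classes):
    X = np.asarray(X, dtype=float)
    P = np.asarray(prototypes, dtype=float)
    d = ((X[:, None, :] - P[None, :, :]) ** 2).sum(axis=-1)
    # Softmin over prototypes = softmax of -d; subtract the row minimum for stability
    z = -(d - d.min(axis=1, keepdims=True))
    w = np.exp(z)
    w /= w.sum(axis=1, keepdims=True)
    # Pool prototype probabilities into their classes (assumed aggregation)
    proba = np.zeros((len(X), n_classes))
    for k, c in enumerate(proto_labels):
        proba[:, c] += w[:, k]
    return proba


print(softmin_proba([[0.5, 0.5]], [[0.0, 0.0], [1.0, 1.0]], [0, 1], 2))  # [[0.5 0.5]]
```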

Image LVQ

class prosemble.models.ImageGLVQ(input_shape=(28, 28, 1), channels=None, kernel_sizes=None, latent_dim=32, activation='relu', beta=10.0, bb_lr=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

Image GLVQ — GLVQ with a CNN embedding network.

Both input images and prototype images are passed through the same CNN backbone before computing squared Euclidean distances.
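The weight-sharing idea can be sketched with a hypothetical stand-in for the CNN: a single dense layer with ReLU applied to flattened images. The helpers embed and backbone_sq_distances and the random weights are illustrative only; the point is that inputs and prototypes pass through the same parameters before the squared Euclidean distance is taken.

```python
import numpy as np

rng = np.random.default_rng(0)


def embed(images, weights, bias):
    """Hypothetical stand-in for the CNN backbone: flatten, affine map, ReLU."""
    flat = images.reshape(len(images), -1)
    return np.maximum(flat @ weights + bias, 0.0)


def backbone_sq_distances(images, proto_images, weights, bias):
    z_x = embed(images, weights, bias)        # (n_samples, latent_dim)
    z_w = embed(proto_images, weights, bias)  # (n_prototypes, latent_dim), same parameters
    return ((z_x[:, None, :] - z_w[None, :, :]) ** 2).sum(axis=-1)


height, width, channels, latent_dim = 28, 28, 1, 32
weights = rng.normal(scale=0.05, size=(height * width * channels, latent_dim))
bias = np.zeros(latent_dim)
images = rng.random((4, height, width, channels))
proto_images = rng.random((2, height, width, channels))
print(backbone_sq_distances(images, proto_images, weights, bias).shape)  # (4, 2)
```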

Parameters:
  • input_shape (tuple) – Shape of input images as (height, width, channels).

  • channels (list of int) – CNN output channels per convolutional layer, e.g. [16, 32].

  • kernel_sizes (list of int) – Kernel sizes per convolutional layer, e.g. [3, 3].

  • latent_dim (int) – Dimension of the CNN embedding space.

  • activation (str) – Activation function for the CNN backbone. Supported values: ‘relu’, ‘sigmoid’, ‘tanh’, ‘leaky_relu’, ‘selu’.

  • beta (float) – Transfer function parameter for GLVQ loss.

  • bb_lr (float, optional) – Separate learning rate for the backbone network. If None, uses the same lr as prototypes. Default: None.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; synchronize the slow weights every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor between slow and fast weights). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32, while the forward and backward passes run in the lower precision, which can give up to about 2x speed and about half the memory use on GPUs with native support for that dtype. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
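The lookahead wrapper keeps a set of slow weights alongside the fast weights updated by the inner optimizer. The NumPy sketch below uses plain SGD as the inner optimizer on a quadratic; sync_period and slow_step_size mirror the dict keys above. They also match the argument names of optax.lookahead, but whether prosemble wraps that function is an assumption, and lookahead_sgd is a hypothetical helper.

```python
import numpy as np


def lookahead_sgd(grad_fn, params, lr=0.1, n_steps=60, sync_period=6, slow_step_size=0.5):
    slow = np.array(params, dtype=float)
    fast = slow.copy()
    for step in range(1, n_steps + 1):
        fast = fast - lr * grad_fn(fast)  # inner (fast) optimizer: plain SGD
        if step % sync_period == 0:
            # Move slow weights part of the way toward the fast weights
            slow = slow + slow_step_size * (fast - slow)
            fast = slow.copy()  # restart the fast weights from the slow weights
    return slow


# Minimize 0.5 * ||x - 3||^2, whose gradient is x - 3
print(lookahead_sgd(lambda x: x - 3.0, np.zeros(2), n_steps=600))
```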

predict(X)[source]

Predict class labels via winner-takes-all competition.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

predict_proba(X)[source]

Predict class probabilities via softmin of distances.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

proba

Return type:

array of shape (n_samples, n_classes)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

class prosemble.models.ImageGMLVQ(input_shape=(28, 28, 1), channels=None, kernel_sizes=None, latent_dim=32, omega_dim=None, activation='relu', beta=10.0, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

Image GMLVQ — GMLVQ with a CNN embedding network.

Like ImageGLVQ but with a learned \(\Omega\) matrix in latent space:

\[d = \|\Omega(\text{CNN}(x) - \text{CNN}(w))\|^2\]
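The distance above can be written directly in NumPy. Here z_x and z_w are stand-ins for CNN(x) and CNN(w), and omega has shape (omega_dim, latent_dim), matching the omega_dim description; omega_sq_distance is a hypothetical helper.

```python
import numpy as np


def omega_sq_distance(z_x, z_w, omega):
    """d = ||Omega (z_x - z_w)||^2 for every sample/prototype pair.

    z_x: (n_samples, latent_dim) embedded inputs
    z_w: (n_prototypes, latent_dim) embedded prototypes
    omega: (omega_dim, latent_dim)
    """
    diff = z_x[:, None, :] - z_w[None, :, :]  # (n_samples, n_prototypes, latent_dim)
    proj = diff @ omega.T                     # (n_samples, n_prototypes, omega_dim)
    return (proj ** 2).sum(axis=-1)


# A rank-1 Omega (omega_dim=1) that sums the two latent coordinates
print(omega_sq_distance(np.array([[1.0, 2.0]]), np.array([[0.0, 0.0]]), np.array([[1.0, 1.0]])))  # [[9.]]
```

With omega equal to the identity, the expression reduces to the plain squared Euclidean distance used by ImageGLVQ.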
Parameters:
  • input_shape (tuple) – Shape of input images as (height, width, channels).

  • channels (list of int) – CNN output channels per convolutional layer, e.g. [16, 32].

  • kernel_sizes (list of int) – Kernel sizes per convolutional layer, e.g. [3, 3].

  • latent_dim (int) – Dimension of the CNN embedding space.

  • omega_dim (int, optional) – Omega mapping dimension (number of rows in Omega). If None, uses latent_dim (square matrix). Default: None.

  • activation (str) – Activation function for the CNN backbone. Supported values: ‘relu’, ‘sigmoid’, ‘tanh’, ‘leaky_relu’, ‘selu’.

  • beta (float) – Transfer function parameter for GLVQ loss.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; synchronize the slow weights every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor between slow and fast weights). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32, while the forward and backward passes run in the lower precision, which can give up to about 2x speed and about half the memory use on GPUs with native support for that dtype. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
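ema_decay blends each new parameter value into a running average. A minimal sketch, assuming the average starts from the first parameter value and is updated once per optimizer step (prosemble's exact schedule may differ); ema_track is a hypothetical helper.

```python
import numpy as np


def ema_track(param_history, ema_decay=0.999):
    """Exponential moving average over a sequence of parameter snapshots."""
    ema = np.array(param_history[0], dtype=float)
    for p in param_history[1:]:
        ema = ema_decay * ema + (1.0 - ema_decay) * np.asarray(p, dtype=float)
    return ema


print(ema_track([0.0, 1.0, 1.0], ema_decay=0.5))  # 0.75
```

With values like 0.999, the average effectively spans on the order of 1 / (1 - ema_decay) = 1000 recent steps, which smooths out step-to-step noise.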

predict(X)[source]

Predict class labels via winner-takes-all competition.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict_proba(X)

Predict class probabilities via softmin of distances.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

proba

Return type:

array of shape (n_samples, n_classes)

class prosemble.models.ImageGTLVQ(input_shape=(28, 28, 1), channels=None, kernel_sizes=None, latent_dim=32, subspace_dim=2, activation='relu', beta=10.0, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

Image GTLVQ — GTLVQ with a CNN embedding network.

Like ImageGLVQ but with per-prototype tangent subspace bases in latent space:

\[d = \|P_k(\text{CNN}(x) - \text{CNN}(w_k))\|^2\]
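The formula above leaves P_k implicit. In the usual tangent-distance construction, P_k = I - W_k W_k^T projects onto the orthogonal complement of prototype k's subspace, where W_k has subspace_dim orthonormal columns. The sketch below works under that assumption, with z_x and z_w standing in for CNN(x) and CNN(w_k); tangent_sq_distance is a hypothetical helper.

```python
import numpy as np


def tangent_sq_distance(z_x, z_w, bases):
    """Squared distance to each prototype's affine tangent subspace.

    z_x: (n, latent_dim); z_w: (k, latent_dim)
    bases: (k, latent_dim, subspace_dim) with orthonormal columns per prototype
    """
    diff = z_x[:, None, :] - z_w[None, :, :]             # (n, k, latent_dim)
    coeffs = np.einsum("nkd,kds->nks", diff, bases)      # coordinates in each subspace
    in_span = np.einsum("nks,kds->nkd", coeffs, bases)   # part of diff inside the subspace
    resid = diff - in_span                               # (I - W_k W_k^T) applied to diff
    return (resid ** 2).sum(axis=-1)


# Orthonormal basis for one prototype via reduced QR
rng = np.random.default_rng(0)
q, _ = np.linalg.qr(rng.normal(size=(4, 2)))  # q: (latent_dim=4, subspace_dim=2)
print(tangent_sq_distance(rng.normal(size=(3, 4)), np.zeros((1, 4)), q[None]).shape)  # (3, 1)
```

Movement of the input along a prototype's learned subspace does not increase its distance to that prototype; only the residual outside the subspace counts.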
Parameters:
  • input_shape (tuple) – Shape of input images as (height, width, channels).

  • channels (list of int) – CNN output channels per convolutional layer, e.g. [16, 32].

  • kernel_sizes (list of int) – Kernel sizes per convolutional layer, e.g. [3, 3].

  • latent_dim (int) – Dimension of the CNN embedding space.

  • subspace_dim (int) – Tangent subspace dimension per prototype. Each prototype gets a learned orthonormal basis of this rank in latent space.

  • activation (str) – Activation function for the CNN backbone. Supported values: ‘relu’, ‘sigmoid’, ‘tanh’, ‘leaky_relu’, ‘selu’.

  • beta (float) – Transfer function parameter for GLVQ loss.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; synchronize the slow weights every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor between slow and fast weights). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32, while the forward and backward passes run in the lower precision, which can give up to about 2x speed and about half the memory use on GPUs with native support for that dtype. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
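With equal-sized micro-batches and a mean-reduced loss, averaging the gradients of gradient_accumulation_steps micro-batches gives the same gradient as one batch of batch_size * gradient_accumulation_steps samples. The sketch checks this on a linear least-squares model; whether prosemble averages or sums the accumulated gradients is an assumption here, and mse_grad is a hypothetical helper.

```python
import numpy as np


def mse_grad(w, X, y):
    """Gradient of 0.5 * mean((X @ w - y)**2) for a linear model."""
    return X.T @ (X @ w - y) / len(X)


rng = np.random.default_rng(0)
X = rng.normal(size=(32, 3))
y = rng.normal(size=32)
w = np.zeros(3)

batch_size, accumulation_steps = 8, 4
acc = np.zeros_like(w)
for i in range(accumulation_steps):
    sl = slice(i * batch_size, (i + 1) * batch_size)
    acc += mse_grad(w, X[sl], y[sl])
acc /= accumulation_steps  # average of the micro-batch gradients

full = mse_grad(w, X, y)  # one batch of 32 = batch_size * accumulation_steps
print(np.allclose(acc, full))  # True
```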

predict(X)[source]

Predict class labels via winner-takes-all competition.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict_proba(X)

Predict class probabilities via softmin of distances.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

proba

Return type:

array of shape (n_samples, n_classes)

Classification-By-Components

class prosemble.models.CBC(n_components=5, n_classes=2, sigma=1.0, margin=0.3, components_initializer=None, reasonings_initializer=None, similarity_fn=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

Classification-By-Components.

Components detect patterns in the input via a similarity measure; reasoning matrices then determine how each detection contributes evidence for or against each class.
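One common formulation of CBC reasoning: component j yields a detection d_j in [0, 1] from a Gaussian similarity, and each class has non-negative positive (pos) and negative (neg) reasoning weights per component. The class score is sum_j (d_j * pos_j + (1 - d_j) * neg_j) / sum_j (pos_j + neg_j). The Gaussian scaling exp(-||x - component||^2 / (2 * sigma^2)) and the split into separate pos and neg matrices are assumptions about prosemble's internals; the helper names are hypothetical.

```python
import numpy as np


def gaussian_detection(X, components, sigma=1.0):
    """Detection probabilities in (0, 1], shape (n_samples, n_components)."""
    d2 = ((X[:, None, :] - components[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-d2 / (2.0 * sigma ** 2))


def cbc_proba(detections, pos, neg):
    """pos, neg: (n_components, n_classes) non-negative reasoning weights."""
    # A detected component supports classes via pos; an absent one via neg
    numer = detections @ pos + (1.0 - detections) @ neg
    denom = (pos + neg).sum(axis=0)
    return numer / denom


components = np.array([[0.0, 0.0]])
pos = np.array([[1.0, 0.0]])  # detecting the component is evidence for class 0
neg = np.array([[0.0, 1.0]])  # its absence is evidence for class 1
X = np.array([[0.0, 0.0], [100.0, 0.0]])
print(cbc_proba(gaussian_detection(X, components), pos, neg))  # [[1. 0.] [0. 1.]]
```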

Parameters:
  • n_components (int) – Number of components (analogous to prototypes, but classless).

  • n_classes (int) – Number of output classes.

  • sigma (float) – Bandwidth for Gaussian similarity in component detection.

  • margin (float) – Margin for the margin loss.

  • components_initializer (callable, optional) – Initializer for component vectors. Signature: (X, key, n_components) -> components. Default: None (selects random training samples).

  • reasonings_initializer (callable, optional) – Initializer for the reasoning matrix. Signature: (n_components, n_classes, key) -> reasonings. Default: None (initializes near-uniform with small noise).

  • similarity_fn (callable, optional) – Similarity function for component detection. Default: None (uses Gaussian similarity).

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘components’] to freeze the components and only train reasonings. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; synchronize the slow weights every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor between slow and fast weights). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32, while the forward and backward passes run in the lower precision, which can give up to about 2x speed and about half the memory use on GPUs with native support for that dtype. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
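The margin parameter enters a hinge-style loss on the class probabilities. A sketch assuming the common CBC form max(0, max over wrong classes of p_c - p_y + margin), returned per sample; margin_loss is a hypothetical helper, and prosemble's reduction over samples may differ.

```python
import numpy as np


def margin_loss(proba, y, margin=0.3):
    """Hinge on the gap between the best wrong class and the true class."""
    proba = np.asarray(proba, dtype=float)
    y = np.asarray(y)
    idx = np.arange(len(y))
    p_true = proba[idx, y]
    wrong = proba.copy()
    wrong[idx, y] = -np.inf  # exclude the true class from the max
    p_wrong = wrong.max(axis=1)
    return np.maximum(p_wrong - p_true + margin, 0.0)


print(margin_loss([[0.9, 0.1], [0.6, 0.4]], [0, 0]))
```

The first sample beats its best wrong class by more than the margin and incurs no loss; the second wins by only 0.2 and is penalized by about 0.1.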

predict(X)[source]

Predict class labels.

predict_proba(X)[source]

Predict class probabilities via CBC reasoning.

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

class prosemble.models.ImageCBC(input_shape=(28, 28, 1), channels=None, kernel_sizes=None, latent_dim=32, n_components=5, n_classes=2, sigma=1.0, activation='relu', margin=0.3, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

Image CBC — CBC with a CNN embedding network.

Both input images and component images pass through the same CNN backbone. Detection similarity and reasoning matrices then produce class probabilities.

Parameters:
  • input_shape (tuple) – Shape of input images as (height, width, channels).

  • channels (list of int) – CNN output channels per convolutional layer, e.g. [16, 32].

  • kernel_sizes (list of int) – Kernel sizes per convolutional layer, e.g. [3, 3].

  • latent_dim (int) – Dimension of the CNN embedding space.

  • n_components (int) – Number of components (classless prototypes).

  • n_classes (int) – Number of output classes.

  • sigma (float) – Bandwidth for Gaussian similarity in component detection.

  • activation (str) – Activation function for the CNN backbone. Supported values: ‘relu’, ‘sigmoid’, ‘tanh’, ‘leaky_relu’, ‘selu’.

  • margin (float) – Margin for the margin loss.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train components. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable a lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6), the number of steps k between synchronizations, and ‘slow_step_size’ (float, default 0.5), the interpolation factor. Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32 while the forward and backward passes run in the lower precision, which on GPUs can give up to roughly 2x speed and about half the memory use. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
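The class_weight description only states that ‘balanced’ weights are inversely proportional to class frequencies. The sketch below assumes the scikit-learn convention n_samples / (n_classes * count_c), which is one common normalization; prosemble's constant may differ. The helper name `balanced_class_weights` is hypothetical.

```python
from collections import Counter


def balanced_class_weights(y):
    """Weights inversely proportional to class frequency.

    Assumed normalization (scikit-learn style): n_samples / (n_classes * count_c),
    so a perfectly balanced dataset gets weight 1.0 for every class.
    """
    counts = Counter(y)
    n_samples, n_classes = len(y), len(counts)
    return {c: n_samples / (n_classes * n) for c, n in sorted(counts.items())}


y = [0, 0, 0, 0, 0, 0, 1, 1, 2, 2]
weights = balanced_class_weights(y)
# Each sample's loss term would then be multiplied by weights[y_i].
per_sample = [weights[label] for label in y]
```

With this normalization the per-sample weights sum to n_samples, so the overall loss scale matches unweighted training.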

predict(X)[source]

Predict class labels via winner-takes-all competition.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

predict_proba(X)[source]

Predict class probabilities via softmin of distances.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

proba

Return type:

array of shape (n_samples, n_classes)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores the parameters that achieved the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self
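How sample_weight enters the loss is not spelled out here. A common choice, assumed in this sketch, is a weighted mean of the per-sample losses normalized by the total weight; class weights can be folded in by multiplying each sample's weight by the weight of its class. The function `weighted_mean_loss` is hypothetical.

```python
def weighted_mean_loss(losses, sample_weight=None, class_weight=None, y=None):
    """Weighted mean of per-sample losses (assumed normalization by total weight)."""
    n = len(losses)
    w = [1.0] * n if sample_weight is None else list(sample_weight)
    if class_weight is not None:
        # Fold per-class weights into the per-sample weights.
        w = [wi * class_weight[label] for wi, label in zip(w, y)]
    total = sum(w)
    return sum(wi * li for wi, li in zip(w, losses)) / total


losses = [0.2, 0.4, 1.0]
plain = weighted_mean_loss(losses)
weighted = weighted_mean_loss(losses, sample_weight=[1.0, 1.0, 2.0])
# The same emphasis expressed through class weights instead of sample weights.
via_classes = weighted_mean_loss(losses, class_weight={0: 1.0, 1: 2.0}, y=[0, 0, 1])
```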

Supervised Neural Gas

class prosemble.models.SRNG(beta=10.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, lr_ratio=0.5, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

Supervised Relevance Neural Gas.

Combines three key ideas:

  • GLVQ loss: \((d^+ - d^-) / (d^+ + d^-)\) for margin-based classification

  • Neural Gas cooperation: all same-class prototypes participate in the loss, weighted by rank via \(\exp(-\text{rank} / \gamma)\)

  • Relevance weighting: per-feature relevances \(\lambda_j\), learned during training, weight the distance \(d(x, w) = \sum_j \lambda_j (x_j - w_j)^2\)

The neighborhood range \(\gamma\) decays during training from \(\gamma_{\text{init}}\) to \(\gamma_{\text{final}}\). When \(\gamma \to 0\), SRNG recovers standard GRLVQ.
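For a single sample, the SRNG loss can be sketched as follows. Assumptions, not taken from prosemble's source: ranks start at 0 for the closest correct prototype, the rank weights are normalized to sum to 1, the transfer function is the identity, and the distance is the relevance-weighted squared Euclidean distance. The helpers `relevance_sq_dist` and `srng_sample_loss` are hypothetical.

```python
import math


def relevance_sq_dist(x, w, lam):
    """Relevance-weighted squared Euclidean distance sum_j lam_j (x_j - w_j)^2."""
    return sum(l * (a - b) ** 2 for l, a, b in zip(lam, x, w))


def srng_sample_loss(x, label, prototypes, proto_labels, lam, gamma):
    """NG-weighted GLVQ loss for one sample (identity transfer function)."""
    dists = [relevance_sq_dist(x, w, lam) for w in prototypes]
    d_plus = sorted(d for d, c in zip(dists, proto_labels) if c == label)
    d_minus = min(d for d, c in zip(dists, proto_labels) if c != label)
    # Rank 0 = closest correct prototype; weights decay as exp(-rank / gamma).
    weights = [math.exp(-rank / gamma) for rank in range(len(d_plus))]
    mus = [(dp - d_minus) / (dp + d_minus) for dp in d_plus]
    return sum(h * mu for h, mu in zip(weights, mus)) / sum(weights)


prototypes = [[0.0, 0.0], [1.0, 0.0], [4.0, 0.0]]
proto_labels = [0, 0, 1]
lam = [0.5, 0.5]
x = [0.2, 0.0]
loss_sharp = srng_sample_loss(x, 0, prototypes, proto_labels, lam, gamma=1e-6)
loss_soft = srng_sample_loss(x, 0, prototypes, proto_labels, lam, gamma=1.0)
```

With a tiny gamma only the closest correct prototype keeps a nonzero weight, which is the GRLVQ limit described for SRNG; a larger gamma blends in the second-ranked prototype.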

Parameters:
  • beta (float) – Transfer function steepness parameter for sigmoid shaping.

  • gamma_init (float, optional) – Initial neighborhood range for NG cooperation. Default: half the maximum number of prototypes per class.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma. Default: computed from max_iter so that gamma reaches gamma_final by the end of training.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (their gradients are zeroed). The available group names depend on the model; e.g. [‘backbone’] freezes the backbone in models that have one. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable a lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6), the number of steps k between synchronizations, and ‘slow_step_size’ (float, default 0.5), the interpolation factor. Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32 while the forward and backward passes run in the lower precision, which on GPUs can give up to roughly 2x speed and about half the memory use. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
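If gamma decays multiplicatively, gamma_t = gamma_init * gamma_decay**t, then the factor that lands on gamma_final after max_iter steps is (gamma_final / gamma_init) ** (1 / max_iter). This sketch of that arithmetic assumes one decay step per iteration; whether the library counts max_iter or max_iter - 1 steps is not stated. The helper names are hypothetical.

```python
def default_gamma_decay(gamma_init, gamma_final, max_iter):
    """Multiplicative factor with gamma_init * factor**max_iter == gamma_final."""
    return (gamma_final / gamma_init) ** (1.0 / max_iter)


def gamma_schedule(gamma_init, gamma_decay, max_iter):
    """Neighborhood range at iterations 0..max_iter."""
    return [gamma_init * gamma_decay ** t for t in range(max_iter + 1)]


n_prototypes_per_class = 4
gamma_init = n_prototypes_per_class / 2  # documented default: half the prototypes per class
decay = default_gamma_decay(gamma_init, 0.01, max_iter=100)
gammas = gamma_schedule(gamma_init, decay, max_iter=100)
```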

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores the parameters that achieved the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict(X)

Predict class labels via winner-takes-all competition.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

predict_proba(X)

Predict class probabilities via softmin of distances.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

proba

Return type:

array of shape (n_samples, n_classes)
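The two prediction rules can be sketched for plain squared Euclidean distance. Assumption: predict_proba first reduces each class to the distance of its closest prototype and then applies a softmin (a softmax of negated distances); the library may aggregate per-class distances differently, and SRNG itself would use its relevance-weighted distance. All function names are hypothetical.

```python
import math


def sq_euclidean(x, w):
    return sum((a - b) ** 2 for a, b in zip(x, w))


def predict_wta(x, prototypes, proto_labels):
    """Winner-takes-all: label of the closest prototype."""
    dists = [sq_euclidean(x, w) for w in prototypes]
    return proto_labels[dists.index(min(dists))]


def predict_proba_softmin(x, prototypes, proto_labels, n_classes):
    """Softmin over per-class minimum distances (assumed aggregation)."""
    class_d = [min(sq_euclidean(x, w) for w, c in zip(prototypes, proto_labels) if c == k)
               for k in range(n_classes)]
    m = min(class_d)  # shift for numerical stability; does not change the result
    exps = [math.exp(-(d - m)) for d in class_d]
    total = sum(exps)
    return [e / total for e in exps]


prototypes = [[0.0, 0.0], [2.0, 0.0], [0.0, 3.0]]
proto_labels = [0, 1, 1]
x = [0.5, 0.0]
label = predict_wta(x, prototypes, proto_labels)
proba = predict_proba_softmin(x, prototypes, proto_labels, n_classes=2)
```

Under this aggregation the argmax of the softmin probabilities always agrees with the winner-takes-all label.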

class prosemble.models.SMNG(latent_dim=None, beta=10.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, lr_ratio=0.5, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

Supervised Matrix Neural Gas.

Combines three key ideas:

  • GLVQ loss: \((d^+ - d^-) / (d^+ + d^-)\) for margin-based classification

  • Neural Gas cooperation: all same-class prototypes participate in the loss, weighted by rank via \(\exp(-\text{rank} / \gamma)\)

  • Global \(\Omega\) projection:

    \[d(x, w) = \|\Omega(x - w)\|^2\]

    learns feature correlations and a discriminative subspace

The neighborhood range \(\gamma\) decays during training from \(\gamma_{\text{init}}\) to \(\gamma_{\text{final}}\). When \(\gamma \to 0\), SMNG recovers standard GMLVQ.
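The global Ω distance is a matrix-vector product followed by a squared norm. In this pure-Python sketch Ω has shape (latent_dim, n_features); with Ω = I it reduces to squared Euclidean distance, and a latent_dim smaller than n_features gives a low-rank projection. The implied relevance matrix is Λ = ΩᵀΩ. The helper name is hypothetical.

```python
def omega_sq_dist(x, w, omega):
    """d(x, w) = ||Omega (x - w)||^2 with Omega of shape (latent_dim, n_features)."""
    diff = [a - b for a, b in zip(x, w)]
    projected = [sum(o * d for o, d in zip(row, diff)) for row in omega]
    return sum(p * p for p in projected)


x, w = [1.0, 2.0, 0.0], [0.0, 0.0, 0.0]
identity = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
# latent_dim=1: project onto the direction "first feature minus second feature".
omega_1d = [[1.0, -1.0, 0.0]]
d_identity = omega_sq_dist(x, w, identity)
d_projected = omega_sq_dist(x, w, omega_1d)
```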

Parameters:
  • latent_dim (int, optional) – Dimensionality of the \(\Omega\) projection space. If None, the input dimensionality is used.

  • beta (float) – Transfer function steepness parameter for sigmoid shaping.

  • gamma_init (float, optional) – Initial neighborhood range for NG cooperation. Default: half the maximum number of prototypes per class.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma. Default: computed from max_iter so that gamma reaches gamma_final by the end of training.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (their gradients are zeroed). The available group names depend on the model; e.g. [‘backbone’] freezes the backbone in models that have one. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable a lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6), the number of steps k between synchronizations, and ‘slow_step_size’ (float, default 0.5), the interpolation factor. Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32 while the forward and backward passes run in the lower precision, which on GPUs can give up to roughly 2x speed and about half the memory use. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
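The lookahead option follows the Lookahead optimizer scheme: a fast optimizer takes sync_period steps, then the slow weights move a fraction slow_step_size toward the fast weights, and the fast weights restart from the slow ones. This sketch uses plain gradient descent on a 1-D quadratic as the fast optimizer, an assumption for illustration only (prosemble wraps its configured optimizer). The function `lookahead_sgd` is hypothetical.

```python
def lookahead_sgd(grad_fn, w0, lr=0.1, sync_period=6, slow_step_size=0.5, n_steps=60):
    """Lookahead around plain gradient descent on a scalar parameter."""
    slow = fast = w0
    for step in range(1, n_steps + 1):
        fast -= lr * grad_fn(fast)  # inner (fast) optimizer step
        if step % sync_period == 0:
            # Slow weights interpolate toward the fast weights; fast restarts from slow.
            slow += slow_step_size * (fast - slow)
            fast = slow
    return slow


# Minimize f(w) = (w - 3)^2, whose gradient is 2 * (w - 3).
w_final = lookahead_sgd(lambda w: 2.0 * (w - 3.0), w0=0.0)
```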

predict(X)[source]

Predict class labels using the learned \(\Omega\) distance.

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores the parameters that achieved the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict_proba(X)

Predict class probabilities via softmin of distances.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

proba

Return type:

array of shape (n_samples, n_classes)

class prosemble.models.SLNG(latent_dim=None, beta=10.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, lr_ratio=0.5, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

Supervised Localized Matrix Neural Gas.

Combines three key ideas:

  • GLVQ loss: \((d^+ - d^-) / (d^+ + d^-)\) for margin-based classification

  • Neural Gas cooperation: all same-class prototypes participate in the loss, weighted by rank via \(\exp(-\text{rank} / \gamma)\)

  • Per-prototype \(\Omega_k\):

    \[d(x, w_k) = \|\Omega_k(x - w_k)\|^2\]

    learns local metrics adapted to each prototype’s region

The neighborhood range \(\gamma\) decays during training from \(\gamma_{\text{init}}\) to \(\gamma_{\text{final}}\). When \(\gamma \to 0\), SLNG recovers standard LGMLVQ.
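With per-prototype matrices, each prototype measures distance through its own Ω_k, so nearest-prototype prediction compares distances computed under different local metrics. In this minimal pure-Python sketch every Ω_k has shape (latent_dim, n_features); the helper names and toy values are hypothetical.

```python
def local_omega_sq_dist(x, w, omega_k):
    """d(x, w_k) = ||Omega_k (x - w_k)||^2 for one prototype's own matrix."""
    diff = [a - b for a, b in zip(x, w)]
    return sum(sum(o * d for o, d in zip(row, diff)) ** 2 for row in omega_k)


def predict_local(x, prototypes, omegas, proto_labels):
    """Label of the prototype with the smallest local distance."""
    dists = [local_omega_sq_dist(x, w, om) for w, om in zip(prototypes, omegas)]
    return proto_labels[dists.index(min(dists))]


prototypes = [[0.0, 0.0], [3.0, 0.0]]
proto_labels = [0, 1]
# Prototype 0 ignores the second feature; prototype 1 uses the plain Euclidean metric.
omegas_local = [[[1.0, 0.0], [0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]]]
omegas_euclid = [[[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]]]
x = [2.0, 2.0]
```

For this input the local metrics pick class 0 (distances 4 vs 5), while a shared Euclidean metric would pick class 1 (distances 8 vs 5).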

Parameters:
  • latent_dim (int, optional) – Latent space dimensionality per prototype. If None, the input dimensionality is used.

  • beta (float) – Transfer function steepness parameter for sigmoid shaping.

  • gamma_init (float, optional) – Initial neighborhood range for NG cooperation. Default: half the maximum number of prototypes per class.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma. Default: computed from max_iter so that gamma reaches gamma_final by the end of training.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (their gradients are zeroed). The available group names depend on the model; e.g. [‘backbone’] freezes the backbone in models that have one. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable a lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6), the number of steps k between synchronizations, and ‘slow_step_size’ (float, default 0.5), the interpolation factor. Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32 while the forward and backward passes run in the lower precision, which on GPUs can give up to roughly 2x speed and about half the memory use. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
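The ema_decay option keeps a shadow copy of the parameters, updated as ema = ema_decay * ema + (1 - ema_decay) * param, that replaces the parameters after training. This sketch assumes the shadow copy starts at the initial parameters, is updated once per optimizer step, and uses no bias correction; `ema_track` is a hypothetical helper on a scalar parameter.

```python
def ema_track(param_history, ema_decay=0.9):
    """Exponential moving average over a sequence of parameter values."""
    ema = param_history[0]  # assumed: start from the initial parameters
    for p in param_history[1:]:
        ema = ema_decay * ema + (1.0 - ema_decay) * p
    return ema


# A parameter that jumps from 0 to 1 and stays there for 10 steps.
history = [0.0] + [1.0] * 10
smoothed = ema_track(history, ema_decay=0.9)
```

The average effectively spans about 1 / (1 - ema_decay) steps, so the typical values 0.999 and 0.9999 smooth over roughly 1,000 and 10,000 steps, which mainly pays off for long training runs.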

predict(X)[source]

Predict class labels using the per-prototype \(\Omega_k\) distances.

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores the parameters that achieved the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict_proba(X)

Predict class probabilities via softmin of distances.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

proba

Return type:

array of shape (n_samples, n_classes)

class prosemble.models.STNG(subspace_dim=2, beta=10.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, lr_ratio=0.5, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

Supervised Tangent Neural Gas.

Combines three key ideas:

  • GLVQ loss: \((d^+ - d^-) / (d^+ + d^-)\) for margin-based classification

  • Neural Gas cooperation: all same-class prototypes participate in the loss, weighted by rank via \(\exp(-\text{rank} / \gamma)\)

  • Tangent subspaces:

    \[d(x, w_k) = \|(I - \Omega_k \Omega_k^T)(x - w_k)\|^2\]

    measures distance in the orthogonal complement of each prototype’s learned invariance subspace

The neighborhood range \(\gamma\) decays during training from \(\gamma_{\text{init}}\) to \(\gamma_{\text{final}}\). When \(\gamma \to 0\), STNG recovers standard GTLVQ.
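The tangent distance removes the part of x - w_k that lies in the span of Ω_k before measuring length. Assuming Ω_k has orthonormal columns (shape n_features × subspace_dim), (I - Ω_kΩ_kᵀ)v can be computed as v - Ω_k(Ω_kᵀv) without forming the identity matrix. A pure-Python sketch with a hypothetical helper name:

```python
def tangent_sq_dist(x, w, omega):
    """||(I - Omega Omega^T)(x - w)||^2 for Omega with orthonormal columns.

    omega is a list of n_features rows, each of length subspace_dim.
    """
    v = [a - b for a, b in zip(x, w)]
    subspace_dim = len(omega[0])
    # Coordinates of v in the tangent subspace: Omega^T v.
    coords = [sum(omega[i][j] * v[i] for i in range(len(v))) for j in range(subspace_dim)]
    # Residual after removing the tangent component: v - Omega (Omega^T v).
    residual = [v[i] - sum(omega[i][j] * coords[j] for j in range(subspace_dim))
                for i in range(len(v))]
    return sum(r * r for r in residual)


w = [0.0, 0.0, 0.0]
# One tangent direction (subspace_dim=1) along the first feature axis.
omega = [[1.0], [0.0], [0.0]]
d_a = tangent_sq_dist([5.0, 1.0, 2.0], w, omega)
d_b = tangent_sq_dist([-3.0, 1.0, 2.0], w, omega)
```

Moving the input along the tangent direction leaves the distance unchanged, which is the invariance the learned subspace is meant to capture.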

Parameters:
  • subspace_dim (int) – Dimension of each prototype’s tangent subspace.

  • beta (float) – Transfer function steepness parameter for sigmoid shaping.

  • gamma_init (float, optional) – Initial neighborhood range for NG cooperation. Default: half the maximum number of prototypes per class.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma. Default: computed from max_iter so that gamma reaches gamma_final by the end of training.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (their gradients are zeroed). The available group names depend on the model; e.g. [‘backbone’] freezes the backbone in models that have one. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable a lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6), the number of steps k between synchronizations, and ‘slow_step_size’ (float, default 0.5), the interpolation factor. Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32 while the forward and backward passes run in the lower precision, which on GPUs can give up to roughly 2x speed and about half the memory use. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
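With gradient_accumulation_steps = k, gradients from k consecutive mini-batches are combined before a single optimizer update, giving an effective batch size of batch_size * k. This sketch assumes the accumulated gradients are averaged; in that case equal-sized micro-batches reproduce the full-batch mean gradient exactly. The toy loss and helper names are hypothetical.

```python
def mean_grad(batch, w):
    """Mean gradient of the loss 0.5 * (w - x)^2 over a batch (gradient: w - x)."""
    return sum(w - x for x in batch) / len(batch)


def accumulated_grad(data, w, batch_size, accumulation_steps):
    """Average of mean gradients over `accumulation_steps` consecutive micro-batches."""
    grads = [mean_grad(data[i * batch_size:(i + 1) * batch_size], w)
             for i in range(accumulation_steps)]
    return sum(grads) / accumulation_steps


data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
w = 0.0
# batch_size=2 with 4 accumulation steps behaves like one batch of 8 samples.
g_acc = accumulated_grad(data, w, batch_size=2, accumulation_steps=4)
g_full = mean_grad(data, w)
```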

predict(X)[source]

Predict class labels using the tangent distance.

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores the parameters that achieved the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict_proba(X)

Predict class probabilities via softmin of distances.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

proba

Return type:

array of shape (n_samples, n_classes)

Cross-Entropy Neural Gas

class prosemble.models.CELVQ_NG(gamma_init=None, gamma_final=0.01, gamma_decay=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)

Cross-Entropy LVQ with Neural Gas neighborhood cooperation.

For each class, prototypes are ranked by distance and weighted by \(\exp(-\text{rank} / \gamma)\). The NG-weighted class distances are negated, so that closer classes score higher, and serve as logits for a cross-entropy loss over all classes simultaneously.
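The NumPy sketch below shows how NG-weighted class distances can feed a cross-entropy loss. It assumes the rank-weighted distances are normalized by the sum of the weights; that normalization and the helper names are assumptions, not taken from prosemble's source.

```python
import numpy as np

def ng_class_distances(d, proto_labels, n_classes, gamma):
    """NG-weighted distance per class (sketch; normalization is an assumption).

    d: (n_samples, n_prototypes) distances; proto_labels: (n_prototypes,).
    """
    out = np.empty((d.shape[0], n_classes))
    for c in range(n_classes):
        # Distances to this class's prototypes, sorted so column r has rank r.
        dc = np.sort(d[:, proto_labels == c], axis=1)
        ranks = np.arange(dc.shape[1])
        h = np.exp(-ranks / gamma)  # rank weights exp(-rank / gamma)
        out[:, c] = (dc * h).sum(axis=1) / h.sum()
    return out

def cross_entropy_from_distances(class_d, y):
    """Mean cross-entropy with negated class distances as logits."""
    logits = -class_d
    logits = logits - logits.max(axis=1, keepdims=True)
    log_p = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -log_p[np.arange(len(y)), y].mean()

d = np.array([[0.1, 0.9, 3.0, 4.0]])   # one sample, four prototypes
proto_labels = np.array([0, 0, 1, 1])
wide = ng_class_distances(d, proto_labels, 2, gamma=1.0)
narrow = ng_class_distances(d, proto_labels, 2, gamma=0.01)
loss = cross_entropy_from_distances(narrow, np.array([0]))
```

As gamma shrinks, the weights concentrate on the closest prototype of each class, and the class distance approaches the per-class minimum.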

Parameters:
  • gamma_init (float, optional) – Initial neighborhood range for NG cooperation. Default: half the maximum number of prototypes per class.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma. Default: computed from max_iter so gamma reaches gamma_final.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys:
      - ‘sync_period’: int (default 6) – sync the slow and fast weights every sync_period steps.
      - ‘slow_step_size’: float (default 0.5) – interpolation factor for the slow-weight update.
    Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed-precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32, while the forward and backward passes run in the lower-precision dtype. On GPUs with hardware support for that dtype, this can speed up training (up to roughly 2x) and reduce activation memory (up to roughly half). Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
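The default gamma_decay is described only as "computed from max_iter so gamma reaches gamma_final". For a per-step multiplicative factor, a natural rate satisfying that description is \(r = (\gamma_{\text{final}} / \gamma_{\text{init}})^{1/\text{max\_iter}}\). The sketch below uses this formula as an assumption; it is not copied from prosemble's source.

```python
def ng_gamma_schedule(gamma_init, gamma_final, max_iter):
    """Per-step decay factor and resulting gamma trajectory (sketch)."""
    # Geometric rate chosen so that gamma_init * decay**max_iter == gamma_final.
    decay = (gamma_final / gamma_init) ** (1.0 / max_iter)
    gammas = [gamma_init * decay ** t for t in range(max_iter + 1)]
    return decay, gammas

# Hypothetical setup: 4 prototypes per class -> default gamma_init = 4 / 2 = 2.0.
decay, gammas = ng_gamma_schedule(gamma_init=2.0, gamma_final=0.01, max_iter=100)
```

Here decay is about 0.948, so gamma shrinks by roughly 5% per step and reaches 0.01 after 100 steps.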

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores the parameters that achieved the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict(X)

Predict class labels via winner-takes-all competition.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)
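Winner-takes-all prediction assigns each sample the label of its closest prototype. The minimal NumPy sketch below assumes squared Euclidean distance, which is the documented default distance_fn. The helper name wta_predict is hypothetical.

```python
import numpy as np

def wta_predict(X, prototypes, proto_labels):
    """Winner-takes-all: each sample gets the label of its nearest prototype."""
    # Squared Euclidean distances, shape (n_samples, n_prototypes).
    d = ((X[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)
    return proto_labels[np.argmin(d, axis=1)]

protos = np.array([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]])
labels = np.array([0, 0, 1])
pred = wta_predict(np.array([[0.9, 1.2], [4.0, 6.0]]), protos, labels)
```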

predict_proba(X)

Predict class probabilities via softmin of distances.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

proba

Return type:

array of shape (n_samples, n_classes)

class prosemble.models.MCELVQ_NG(latent_dim=None, gamma_init=None, gamma_final=0.01, gamma_decay=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)

Matrix Cross-Entropy LVQ with Neural Gas neighborhood cooperation.

Combines three key ideas:

  • Cross-entropy loss: softmax over all-class NG-weighted distances

  • Neural Gas cooperation: all same-class prototypes participate, weighted by rank via \(\exp(-\text{rank} / \gamma)\)

  • Global \(\Omega\) projection:

    \[d(x, w) = \|\Omega(x - w)\|^2\]

    learns feature correlations and a discriminative subspace

The neighborhood range \(\gamma\) decays during training from \(\gamma_{\text{init}}\) to \(\gamma_{\text{final}}\). When \(\gamma \to 0\), MCELVQ-NG recovers a matrix CELVQ.
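The global \(\Omega\) distance can be computed for all sample/prototype pairs at once by projecting the differences. The NumPy sketch below uses a hypothetical helper name. It also checks the identity \(\|\Omega(x - w)\|^2 = (x - w)^T \Lambda (x - w)\) with \(\Lambda = \Omega^T \Omega\), which shows why learning \(\Omega\) amounts to learning a positive semi-definite relevance matrix. A latent_dim smaller than the input dimension corresponds to a rectangular \(\Omega\).

```python
import numpy as np

def omega_distances(X, prototypes, omega):
    """Squared distances ||Omega (x - w)||^2 for every (sample, prototype) pair.

    X: (n, d), prototypes: (k, d), omega: (latent_dim, d).
    """
    diff = X[:, None, :] - prototypes[None, :, :]      # (n, k, d)
    proj = diff @ omega.T                               # (n, k, latent_dim)
    return (proj ** 2).sum(axis=-1)                     # (n, k)

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 4))
W = rng.normal(size=(3, 4))
omega = rng.normal(size=(2, 4))   # latent_dim = 2 < input dim = 4
D = omega_distances(X, W, omega)

# Same distances via the relevance matrix Lambda = Omega^T Omega.
Lam = omega.T @ omega
diff = X[:, None, :] - W[None, :, :]
D_lam = np.einsum("nkd,de,nke->nk", diff, Lam, diff)
```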

Parameters:
  • latent_dim (int, optional) – Dimensionality of the \(\Omega\) projection space. If None, uses input dim.

  • gamma_init (float, optional) – Initial neighborhood range for NG cooperation. Default: half the maximum number of prototypes per class.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma. Default: computed from max_iter so gamma reaches gamma_final.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys:
      - ‘sync_period’: int (default 6) – sync the slow and fast weights every sync_period steps.
      - ‘slow_step_size’: float (default 0.5) – interpolation factor for the slow-weight update.
    Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed-precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32, while the forward and backward passes run in the lower-precision dtype. On GPUs with hardware support for that dtype, this can speed up training (up to roughly 2x) and reduce activation memory (up to roughly half). Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
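With ema_decay set, an exponential moving average of the parameters is tracked, and the smoothed values replace the trained ones afterwards. The sketch below applies the standard EMA update to a hypothetical sequence of noisy parameter snapshots to show the smoothing effect. Initializing the average from the first snapshot is an assumption, and this is not prosemble's code.

```python
import numpy as np

def ema_track(param_history, ema_decay=0.99):
    """Exponential moving average of a sequence of parameter snapshots."""
    ema = np.asarray(param_history[0], dtype=float).copy()
    for params in param_history[1:]:
        # ema <- decay * ema + (1 - decay) * current parameters
        ema = ema_decay * ema + (1.0 - ema_decay) * np.asarray(params, dtype=float)
    return ema

rng = np.random.default_rng(0)
# Hypothetical training trajectory: parameters jitter around a fixed point.
history = [np.array([1.0, -1.0]) + 0.1 * rng.normal(size=2) for _ in range(2000)]
smoothed = ema_track(history, ema_decay=0.99)
```

The averaging window is roughly 1 / (1 - ema_decay) steps. The typical values 0.999 and 0.9999 therefore only pay off when training runs for many more steps than that.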

predict(X)

Predict class labels using the learned \(\Omega\) distance.

predict_proba(X)

Predict calibrated probabilities using \(\Omega\)-transformed distances.

Uses NG-weighted pooling with the learned \(\Omega\) metric, matching the training objective exactly.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

proba

Return type:

array of shape (n_samples, n_classes)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores the parameters that achieved the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

class prosemble.models.LCELVQ_NG(latent_dim=None, gamma_init=None, gamma_final=0.01, gamma_decay=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)

Localized Matrix Cross-Entropy LVQ with Neural Gas cooperation.

Combines three key ideas:

  • Cross-entropy loss: softmax over all-class NG-weighted distances

  • Neural Gas cooperation: all same-class prototypes participate, weighted by rank via \(\exp(-\text{rank} / \gamma)\)

  • Per-prototype \(\Omega_k\):

    \[d(x, w_k) = \|\Omega_k(x - w_k)\|^2\]

    learns local metrics adapted to each prototype’s region

The neighborhood range \(\gamma\) decays during training from \(\gamma_{\text{init}}\) to \(\gamma_{\text{final}}\). When \(\gamma \to 0\), LCELVQ-NG recovers a localized matrix CELVQ.
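With per-prototype matrices, each prototype uses its own projection, so the stack of \(\Omega_k\) is indexed alongside the prototypes. The NumPy sketch below uses hypothetical names. In the example, the first prototype uses an identity \(\Omega\), which recovers squared Euclidean distance. The second uses a rank-one \(\Omega\) that only looks at the first feature.

```python
import numpy as np

def local_omega_distances(X, prototypes, omegas):
    """Squared distances ||Omega_k (x - w_k)||^2 with one Omega per prototype.

    X: (n, d), prototypes: (k, d), omegas: (k, latent_dim, d).
    """
    diff = X[:, None, :] - prototypes[None, :, :]            # (n, k, d)
    proj = np.einsum("kld,nkd->nkl", omegas, diff)           # (n, k, latent_dim)
    return (proj ** 2).sum(axis=-1)                          # (n, k)

rng = np.random.default_rng(1)
X = rng.normal(size=(6, 3))
W = rng.normal(size=(2, 3))
omegas = np.stack([np.eye(3), np.diag([2.0, 0.0, 0.0])])   # one Omega per prototype
D = local_omega_distances(X, W, omegas)
```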

Parameters:
  • latent_dim (int, optional) – Latent space dimensionality per prototype. If None, uses input dim.

  • gamma_init (float, optional) – Initial neighborhood range for NG cooperation. Default: half the maximum number of prototypes per class.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma. Default: computed from max_iter so gamma reaches gamma_final.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys:
      - ‘sync_period’: int (default 6) – sync the slow and fast weights every sync_period steps.
      - ‘slow_step_size’: float (default 0.5) – interpolation factor for the slow-weight update.
    Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed-precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32, while the forward and backward passes run in the lower-precision dtype. On GPUs with hardware support for that dtype, this can speed up training (up to roughly 2x) and reduce activation memory (up to roughly half). Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
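The ‘balanced’ setting of class_weight is described as weights inversely proportional to class frequencies. One common convention, used by scikit-learn for class_weight='balanced', is n_samples / (n_classes * count_c). This reference does not state prosemble's exact normalization, so the constant in the sketch below is an assumption. The helper name is hypothetical.

```python
import numpy as np

def balanced_class_weights(y):
    """Map each class label to n_samples / (n_classes * count) (sketch)."""
    classes, counts = np.unique(y, return_counts=True)
    weights = len(y) / (len(classes) * counts)
    return {int(c): float(w) for c, w in zip(classes, weights)}

y = np.array([0] * 80 + [1] * 20)
weights = balanced_class_weights(y)
```

With these weights, both classes contribute equally to the total loss: 80 × 0.625 = 20 × 2.5 = 50.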

predict(X)

Predict class labels using the per-prototype \(\Omega_k\) distances.

predict_proba(X)

Predict calibrated probabilities using per-prototype \(\Omega_k\) distances.

Uses per-class min pooling with the learned local \(\Omega_k\) metrics, which matches the NG-weighted training objective in the \(\gamma \to 0\) limit.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

proba

Return type:

array of shape (n_samples, n_classes)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores the parameters that achieved the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

class prosemble.models.TCELVQ_NG(subspace_dim=2, gamma_init=None, gamma_final=0.01, gamma_decay=None, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)

Tangent Cross-Entropy LVQ with Neural Gas neighborhood cooperation.

Combines three key ideas:

  • Cross-entropy loss: softmax over all-class NG-weighted distances

  • Neural Gas cooperation: all same-class prototypes participate, weighted by rank via \(\exp(-\text{rank} / \gamma)\)

  • Tangent subspaces:

    \[d(x, w_k) = \|(I - \Omega_k \Omega_k^T)(x - w_k)\|^2\]

    measures distance in the orthogonal complement of each prototype’s learned invariance subspace

The neighborhood range \(\gamma\) decays during training from \(\gamma_{\text{init}}\) to \(\gamma_{\text{final}}\). When \(\gamma \to 0\), TCELVQ-NG recovers a tangent CELVQ.
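\(\Omega_k \Omega_k^T\) is an orthogonal projector only when \(\Omega_k\) (of shape features × subspace_dim) has orthonormal columns. The tangent distance formula above therefore presumes that orthonormality. The NumPy sketch below computes the distance using a hypothetical helper name. It imposes orthonormality with a QR decomposition; this reference does not state how prosemble maintains orthonormality during training.

```python
import numpy as np

def tangent_distances(X, prototypes, omegas):
    """Squared distance to each prototype's affine subspace w_k + span(Omega_k).

    X: (n, d), prototypes: (k, d), omegas: (k, d, s) with orthonormal columns.
    """
    diff = X[:, None, :] - prototypes[None, :, :]            # (n, k, d)
    coeffs = np.einsum("kds,nkd->nks", omegas, diff)          # Omega_k^T (x - w_k)
    in_span = np.einsum("kds,nks->nkd", omegas, coeffs)       # Omega_k Omega_k^T (x - w_k)
    residual = diff - in_span                                 # (I - Omega_k Omega_k^T)(x - w_k)
    return (residual ** 2).sum(axis=-1)                       # (n, k)

rng = np.random.default_rng(2)
d, s = 5, 2
# Orthonormal basis for one prototype's tangent subspace via reduced QR.
omega, _ = np.linalg.qr(rng.normal(size=(d, s)))
w = np.zeros((1, d))
x_in = (omega @ np.array([3.0, -1.0]))[None, :]   # lies inside the subspace
D_in = tangent_distances(x_in, w, omega[None])
x_out = rng.normal(size=(1, d))
D_out = tangent_distances(x_out, w, omega[None])
```

Any displacement along the learned subspace costs nothing, so x_in has (numerically) zero distance even though it is far from the prototype in Euclidean terms.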

Parameters:
  • subspace_dim (int) – Dimension of each prototype’s tangent subspace.

  • gamma_init (float, optional) – Initial neighborhood range for NG cooperation. Default: half the maximum number of prototypes per class.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma. Default: computed from max_iter so gamma reaches gamma_final.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys:
      - ‘sync_period’: int (default 6) – sync the slow and fast weights every sync_period steps.
      - ‘slow_step_size’: float (default 0.5) – interpolation factor for the slow-weight update.
    Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed-precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32, while the forward and backward passes run in the lower-precision dtype. On GPUs with hardware support for that dtype, this can speed up training (up to roughly 2x) and reduce activation memory (up to roughly half). Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
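With gradient_accumulation_steps, gradients from several mini-batches are combined before one optimizer update, giving an effective batch size of batch_size * gradient_accumulation_steps. The sketch below assumes the accumulated gradients are averaged; that is a common convention, but this reference does not say whether prosemble sums or averages. Under that assumption, and for a mean loss over equal-sized mini-batches, the result equals the full-batch gradient exactly. The sketch checks this on a hypothetical least-squares loss.

```python
import numpy as np

def grad_mse(w, X, y):
    """Gradient of the mean squared error 0.5 * mean((X @ w - y)**2)."""
    return X.T @ (X @ w - y) / len(y)

rng = np.random.default_rng(0)
X, y, w = rng.normal(size=(32, 3)), rng.normal(size=32), np.zeros(3)
batch_size, accumulation_steps = 8, 4   # effective batch size 8 * 4 = 32

# Accumulate gradients over 4 mini-batches, then average them for one update.
acc = np.zeros_like(w)
for i in range(accumulation_steps):
    sl = slice(i * batch_size, (i + 1) * batch_size)
    acc += grad_mse(w, X[sl], y[sl])
accumulated = acc / accumulation_steps
full_batch = grad_mse(w, X, y)
```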

predict(X)

Predict class labels using the tangent distance.

predict_proba(X)

Predict calibrated probabilities using tangent distances.

Uses per-class min pooling with the tangent distance, which matches the NG-weighted training objective in the \(\gamma \to 0\) limit.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

proba

Return type:

array of shape (n_samples, n_classes)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores the parameters that achieved the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

One-Class GLVQ

class prosemble.models.OCGLVQ(n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)

One-Class Generalized Learning Vector Quantization.

Combines GLVQ’s \(\mu\)-based hypothesis testing with per-prototype visibility thresholds \(\theta_k\) for one-class classification.

Parameters:
  • n_prototypes (int) – Number of prototypes for the target class.

  • target_label (int, optional) – Which label is the target (normal) class. Default: auto-detect as the most frequent class.

  • beta (float) – Sigmoid steepness for the transfer function. Default: 10.0.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys:
      - ‘sync_period’: int (default 6) – sync the slow and fast weights every sync_period steps.
      - ‘slow_step_size’: float (default 0.5) – interpolation factor for the slow-weight update.
    Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed-precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32, while the forward and backward passes run in the lower-precision dtype. On GPUs with hardware support for that dtype, this can speed up training (up to roughly 2x) and reduce activation memory (up to roughly half). Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
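Static loss scaling works in three steps. It multiplies the loss, and hence every gradient, by a fixed constant before the float16 backward pass. It then divides the gradients by the same constant in float32 afterwards. This keeps small values above float16's underflow limit, since the smallest positive float16 is about 6e-8. The NumPy sketch below isolates the numeric effect on a single hypothetical gradient value. The scale 2**10 is an arbitrary choice, not prosemble's default.

```python
import numpy as np

loss_scale = np.float32(2.0 ** 10)  # hypothetical static scale
true_grad = np.float32(1e-8)        # a gradient too small for float16

# Without scaling, casting to float16 flushes the value to zero.
unscaled_fp16 = np.float16(true_grad)

# With static scaling: scale up before the low-precision cast,
# then divide by the same constant after converting back to float32.
scaled_fp16 = np.float16(true_grad * loss_scale)
recovered = np.float32(scaled_fp16) / loss_scale
```

bfloat16 shares float32's exponent range, which is why this kind of scaling is typically only needed for float16.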

thetas_

Learned per-prototype visibility thresholds.

Type:

array of shape (n_prototypes,)

decision_function(X)

Compute target-likeness scores.

Scores near 1.0 indicate the target class; scores near 0.0 indicate an outlier. The decision boundary lies at score = 0.5, where \(d = \theta\).

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

scores

Return type:

array of shape (n_samples,)
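The description above fixes only the endpoints of the score and the boundary at \(d = \theta\). One scoring rule with exactly those properties is a sigmoid of the margin between the nearest prototype's threshold \(\theta_k\) and its distance \(d_k\), with \(\beta\) as the steepness. This is an assumption for illustration: prosemble may normalize the margin differently. The function target_scores is hypothetical.

```python
import numpy as np

def target_scores(X, prototypes, thetas, beta=10.0):
    """Hypothetical one-class score: sigmoid(beta * (theta_k - d_k)) for the nearest prototype k."""
    # Squared Euclidean distances, shape (n_samples, n_prototypes)
    d = ((X[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)
    k = d.argmin(axis=1)                      # index of the nearest prototype per sample
    d_min = d[np.arange(len(X)), k]
    margin = thetas[k] - d_min                # positive inside the prototype's threshold
    return 1.0 / (1.0 + np.exp(-beta * margin))

prototypes = np.array([[0.0, 0.0], [5.0, 5.0]])
thetas = np.array([1.0, 1.0])                 # per-prototype thresholds, like thetas_
X = np.array([[0.0, 0.0], [1.0, 0.0], [3.0, 0.0]])
print(target_scores(X, prototypes, thetas))   # high, exactly 0.5 (d == theta), near 0
```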

predict(X)[source]

Predict target or non-target labels.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)[source]

Predict with a reject option for uncertain samples.

Samples with scores in [lower, upper) are rejected.

Parameters:
  • X (array-like of shape (n_samples, n_features))

  • upper (float) – Scores >= upper are classified as target. Default: 0.5.

  • lower (float, optional) – Scores < lower are classified as non-target. Default: same as upper (no rejection zone).

  • reject_label (int) – Label for rejected samples. Default: -1.

Returns:

labels

Return type:

array of shape (n_samples,)
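The reject rule above can be sketched with NumPy as a pure function of the scores. The sketch assumes target and non-target are encoded as 1 and 0. The output encoding of predict is not documented here, so target_label and non_target_label are hypothetical parameters.

```python
import numpy as np

def predict_with_reject_sketch(scores, upper=0.5, lower=None, reject_label=-1,
                               target_label=1, non_target_label=0):
    """Threshold scores into target / non-target / reject, following the documented rules."""
    if lower is None:
        lower = upper                              # no rejection zone
    labels = np.full(scores.shape, reject_label)   # default: rejected (lower <= s < upper)
    labels[scores >= upper] = target_label
    labels[scores < lower] = non_target_label
    return labels

scores = np.array([0.95, 0.55, 0.40, 0.05])
print(predict_with_reject_sketch(scores, upper=0.7, lower=0.3))  # [ 1 -1 -1  0]
print(predict_with_reject_sketch(scores))                        # [1 1 0 0]
```

With the defaults (lower=None), every score is either >= upper or < lower, so nothing is rejected and the result matches plain thresholding at 0.5.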

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

class prosemble.models.OCGRLVQ(n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

One-Class GRLVQ with per-feature relevance weighting.

Learns which features are most important for distinguishing target from non-target data in a one-class setting.

Parameters:
  • n_prototypes (int) – Number of prototypes for the target class.

  • target_label (int, optional) – Target (normal) class label. Default: auto-detect.

  • beta (float) – Sigmoid steepness. Default: 10.0.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6: synchronize every k steps) and ‘slow_step_size’ (float, default 0.5: interpolation factor). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
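To show what the lookahead keys sync_period and slow_step_size control, here is a minimal NumPy sketch of the standard lookahead scheme wrapped around plain SGD. It assumes prosemble's wrapper follows this scheme. The function lookahead_sgd and the quadratic objective are hypothetical.

```python
import numpy as np

def lookahead_sgd(grad_fn, w0, lr=0.1, sync_period=6, slow_step_size=0.5, n_steps=60):
    """Fast weights take sync_period SGD steps; then the slow weights move a fraction
    slow_step_size toward them, and the fast weights restart from the slow weights."""
    slow = np.array(w0, dtype=float)
    fast = slow.copy()
    for step in range(1, n_steps + 1):
        fast = fast - lr * grad_fn(fast)         # inner (fast) optimizer step
        if step % sync_period == 0:              # synchronize every k steps
            slow = slow + slow_step_size * (fast - slow)
            fast = slow.copy()
    return slow

# Minimize f(w) = ||w - 3||^2, whose gradient is 2 * (w - 3)
w = lookahead_sgd(lambda w: 2.0 * (w - 3.0), w0=[0.0, 0.0])
print(w)  # close to [3, 3]
```

Setting slow_step_size=1.0 makes the slow weights copy the fast ones at every sync, which reduces the scheme to the inner optimizer alone.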

relevances_

Learned per-feature relevance weights (softmax-normalized).

Type:

array of shape (n_features,)
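relevances_ holds softmax-normalized per-feature weights \(\lambda_j\). These enter a relevance-weighted squared distance \(d_\lambda(x, w) = \sum_j \lambda_j (x_j - w_j)^2\). The sketch below assumes \(\lambda = \mathrm{softmax}(\text{logits})\), with unconstrained logits as the trainable parameters. That parametrization and both function names are assumptions.

```python
import numpy as np

def softmax(z):
    z = z - z.max()                          # shift for numerical stability
    e = np.exp(z)
    return e / e.sum()

def relevance_sq_distance(x, w, relevance_logits):
    """Hypothetical: sum_j lambda_j * (x_j - w_j)**2 with lambda = softmax(logits)."""
    lam = softmax(relevance_logits)          # non-negative, sums to 1
    return float(np.sum(lam * (x - w) ** 2))

x = np.array([1.0, 10.0])
w = np.array([0.0, 0.0])
print(relevance_sq_distance(x, w, np.zeros(2)))          # 50.5: uniform lambda = [0.5, 0.5]
print(relevance_sq_distance(x, w, np.array([5.0, -5.0])))  # ~1.0: feature 1 nearly ignored
```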

decision_function(X)[source]

Compute scores using relevance-weighted distances.

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict(X)

Predict target or non-target labels.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)

Predict with a reject option for uncertain samples.

Samples with scores in [lower, upper) are rejected.

Parameters:
  • X (array-like of shape (n_samples, n_features))

  • upper (float) – Scores >= upper are classified as target. Default: 0.5.

  • lower (float, optional) – Scores < lower are classified as non-target. Default: same as upper (no rejection zone).

  • reject_label (int) – Label for rejected samples. Default: -1.

Returns:

labels

Return type:

array of shape (n_samples,)

class prosemble.models.OCGMLVQ(latent_dim=None, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

One-Class GMLVQ with global Omega projection.

Learns a global linear projection \(\Omega\) that captures feature correlations for one-class classification.

Parameters:
  • latent_dim (int, optional) – Dimensionality of the projected space. Default: n_features.

  • n_prototypes (int) – Number of prototypes for the target class.

  • target_label (int, optional) – Target (normal) class label. Default: auto-detect.

  • beta (float) – Sigmoid steepness. Default: 10.0.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6: synchronize every k steps) and ‘slow_step_size’ (float, default 0.5: interpolation factor). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
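An EMA with decay \(\alpha\) averages over roughly \(1/(1-\alpha)\) steps. That matters for the typical ema_decay values above when max_iter is left at its default of 100. The following sketch assumes the EMA starts from the initial parameters and has no bias correction; prosemble may do either differently. The helpers ema_update and run_ema are hypothetical.

```python
import numpy as np

def ema_update(ema, params, decay):
    """One EMA step: ema <- decay * ema + (1 - decay) * params."""
    return decay * ema + (1.0 - decay) * params

def run_ema(param_trajectory, decay):
    """Track the EMA along a sequence of parameter snapshots, starting from the first one."""
    ema = np.array(param_trajectory[0], dtype=float)
    for params in param_trajectory[1:]:
        ema = ema_update(ema, np.asarray(params, dtype=float), decay)
    return ema

# Parameters jump from 0 to 1 and stay there for 100 steps
trajectory = [0.0] + [1.0] * 100
print(run_ema(trajectory, decay=0.9))    # about 1.0 (horizon ~10 steps)
print(run_ema(trajectory, decay=0.999))  # about 0.095 (horizon ~1000 steps)
```

Under these assumptions, a decay of 0.999 over 100 steps leaves the smoothed parameters mostly at their starting values. Long runs or smaller decays are needed for the EMA to track training.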

omega_

Learned projection matrix.

Type:

array of shape (n_features, latent_dim)
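Given omega_ of shape (n_features, latent_dim), a natural reading of the Omega-projected distance is \(d_\Omega(x, w) = \lVert (x - w)\,\Omega \rVert^2 = (x - w)^\top \Omega \Omega^\top (x - w)\). The following NumPy sketch uses that convention, which is an assumption consistent with the documented shape. omega_sq_distance is a hypothetical name.

```python
import numpy as np

def omega_sq_distance(X, prototypes, omega):
    """d(x, w) = ||(x - w) @ Omega||^2 with Omega of shape (n_features, latent_dim)."""
    diff = X[:, None, :] - prototypes[None, :, :]   # (n_samples, n_prototypes, n_features)
    projected = diff @ omega                        # (n_samples, n_prototypes, latent_dim)
    return (projected ** 2).sum(axis=-1)            # (n_samples, n_prototypes)

X = np.array([[1.0, 2.0, 3.0]])
prototypes = np.zeros((2, 3))
omega = np.array([[1.0], [0.0], [0.0]])   # latent_dim=1: keep only feature 0
print(omega_sq_distance(X, prototypes, omega))   # [[1. 1.]]
```

With Omega equal to the identity, this reduces to the plain squared Euclidean distance.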

decision_function(X)[source]

Compute scores using Omega-projected distances.

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict(X)

Predict target or non-target labels.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)

Predict with a reject option for uncertain samples.

Samples with scores in [lower, upper) are rejected.

Parameters:
  • X (array-like of shape (n_samples, n_features))

  • upper (float) – Scores >= upper are classified as target. Default: 0.5.

  • lower (float, optional) – Scores < lower are classified as non-target. Default: same as upper (no rejection zone).

  • reject_label (int) – Label for rejected samples. Default: -1.

Returns:

labels

Return type:

array of shape (n_samples,)

class prosemble.models.OCLGMLVQ(latent_dim=None, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

One-Class LGMLVQ with per-prototype Omega projections.

Each prototype learns its own local metric via \(\Omega_k\), allowing different prototypes to attend to different feature subspaces.

Parameters:
  • latent_dim (int, optional) – Dimensionality of each projected space. Default: n_features.

  • n_prototypes (int) – Number of prototypes for the target class.

  • target_label (int, optional) – Target (normal) class label. Default: auto-detect.

  • beta (float) – Sigmoid steepness. Default: 10.0.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6: synchronize every k steps) and ‘slow_step_size’ (float, default 0.5: interpolation factor). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
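The gradient_accumulation_steps entry states that the effective batch size is batch_size * gradient_accumulation_steps. The sketch below shows why, assuming micro-batch gradients are averaged (not summed) and batches are equal-sized. The least-squares loss and mse_grad are hypothetical stand-ins for the model's loss.

```python
import numpy as np

def mse_grad(w, X, y):
    """Gradient of 0.5 * mean((X @ w - y)**2) with respect to w."""
    return X.T @ (X @ w - y) / len(y)

rng = np.random.default_rng(0)
X = rng.normal(size=(32, 3))
y = rng.normal(size=32)
w = np.zeros(3)

batch_size, accum_steps = 8, 4            # effective batch = 8 * 4 = 32
accumulated = np.zeros_like(w)
for i in range(accum_steps):
    sl = slice(i * batch_size, (i + 1) * batch_size)
    accumulated += mse_grad(w, X[sl], y[sl])
accumulated /= accum_steps                # average the micro-batch gradients

full = mse_grad(w, X, y)                  # one gradient over all 32 samples
print(np.allclose(accumulated, full))     # True
```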

omegas_

Learned per-prototype projection matrices.

Type:

array of shape (n_prototypes, n_features, latent_dim)
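With omegas_ of shape (n_prototypes, n_features, latent_dim), each prototype \(k\) uses its own distance \(d_k(x) = \lVert (x - w_k)\,\Omega_k \rVert^2\). The following einsum sketch assumes the same row-vector convention as the global case. local_omega_sq_distance is a hypothetical name.

```python
import numpy as np

def local_omega_sq_distance(X, prototypes, omegas):
    """d_k(x) = ||(x - w_k) @ omegas[k]||^2 for every sample and prototype."""
    diff = X[:, None, :] - prototypes[None, :, :]          # (n, K, n_features)
    projected = np.einsum('nkf,kfl->nkl', diff, omegas)    # (n, K, latent_dim)
    return (projected ** 2).sum(axis=-1)                   # (n, K)

X = np.array([[3.0, 4.0]])
prototypes = np.zeros((2, 2))
omegas = np.array([[[1.0], [0.0]],    # prototype 0 attends to feature 0
                   [[0.0], [1.0]]])   # prototype 1 attends to feature 1
print(local_omega_sq_distance(X, prototypes, omegas))   # [[ 9. 16.]]
```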

decision_function(X)[source]

Compute scores using per-prototype Omega-projected distances.

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict(X)

Predict target or non-target labels.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)

Predict with a reject option for uncertain samples.

Samples with scores in [lower, upper) are rejected.

Parameters:
  • X (array-like of shape (n_samples, n_features))

  • upper (float) – Scores >= upper are classified as target. Default: 0.5.

  • lower (float, optional) – Scores < lower are classified as non-target. Default: same as upper (no rejection zone).

  • reject_label (int) – Label for rejected samples. Default: -1.

Returns:

labels

Return type:

array of shape (n_samples,)

class prosemble.models.OCGTLVQ(subspace_dim=2, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

One-Class GTLVQ with per-prototype tangent subspaces.

Each prototype learns an orthonormal basis \(\Omega_k\) that defines directions of local invariance. Only the distance orthogonal to this subspace is used for classification.

Parameters:
  • subspace_dim (int) – Dimensionality of each tangent subspace. Default: 2.

  • n_prototypes (int) – Number of prototypes for the target class.

  • target_label (int, optional) – Target (normal) class label. Default: auto-detect.

  • beta (float) – Sigmoid steepness. Default: 10.0.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6: synchronize every k steps) and ‘slow_step_size’ (float, default 0.5: interpolation factor). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
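The mixed_precision entry mentions static loss scaling for float16. The NumPy sketch below shows the underflow it guards against. In real training the loss is multiplied by the scale before differentiation, so the chain rule scales every gradient; here a single gradient value is scaled directly. The scale 2**15 is a hypothetical choice, not prosemble's value.

```python
import numpy as np

grad = 1e-8                              # a tiny gradient value
print(np.float16(grad))                  # 0.0: underflows in float16

scale = 2.0 ** 15                        # hypothetical static loss scale
scaled = np.float16(grad * scale)        # what a float16 backward pass would hold
recovered = np.float32(scaled) / scale   # unscale in float32 before the update
print(recovered)                         # close to 1e-8
```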

omegas_

Learned per-prototype orthonormal tangent bases.

Type:

array of shape (n_prototypes, n_features, subspace_dim)
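Because each basis in omegas_ is orthonormal, the distance orthogonal to the tangent subspace is \(\lVert (I - \Omega_k \Omega_k^\top)(x - w_k) \rVert^2\). The following sketch checks both limiting cases. A displacement inside the subspace costs nothing, and a unit displacement orthogonal to it costs 1. QR is used only to build a valid orthonormal basis for the demo. tangent_sq_distance is a hypothetical name.

```python
import numpy as np

def tangent_sq_distance(x, w, basis):
    """Squared distance orthogonal to span(basis); basis has orthonormal columns."""
    diff = x - w
    residual = diff - basis @ (basis.T @ diff)    # drop the component inside the subspace
    return float(residual @ residual)

rng = np.random.default_rng(0)
basis, _ = np.linalg.qr(rng.normal(size=(3, 2)))  # (n_features=3, subspace_dim=2)
w = np.zeros(3)

along = basis @ np.array([2.0, -1.0])             # displacement inside the tangent subspace
normal = np.cross(basis[:, 0], basis[:, 1])       # unit vector orthogonal to the subspace
print(tangent_sq_distance(w + along, w, basis))   # ~0.0: invariant direction
print(tangent_sq_distance(w + normal, w, basis))  # ~1.0
```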

decision_function(X)[source]

Compute scores using tangent distances.

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict(X)

Predict target or non-target labels.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)

Predict with a reject option for uncertain samples.

Samples with scores in [lower, upper) are rejected.

Parameters:
  • X (array-like of shape (n_samples, n_features))

  • upper (float) – Scores >= upper are classified as target. Default: 0.5.

  • lower (float, optional) – Scores < lower are classified as non-target. Default: same as upper (no rejection zone).

  • reject_label (int) – Label for rejected samples. Default: -1.

Returns:

labels

Return type:

array of shape (n_samples,)

One-Class GLVQ with Neural Gas

class prosemble.models.OCGLVQ_NG(gamma_init=None, gamma_final=0.01, gamma_decay=None, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)[source]

One-Class GLVQ with Neural Gas neighborhood cooperation.

All prototypes participate in the loss, weighted by their distance rank via exp(-rank / gamma). Uses squared Euclidean distance.
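The rank weighting exp(-rank / gamma) described above can be computed with a double argsort. Rank 0 is the closest prototype. This sketch does not normalize the weights; whether prosemble does is not stated here. neural_gas_weights is a hypothetical name.

```python
import numpy as np

def neural_gas_weights(distances, gamma):
    """exp(-rank / gamma) per prototype, with rank 0 for the closest prototype."""
    ranks = np.argsort(np.argsort(distances, axis=-1), axis=-1)   # rank of each prototype
    return np.exp(-ranks / gamma)

d = np.array([[0.3, 0.1, 0.7]])           # squared distances to 3 prototypes
print(neural_gas_weights(d, gamma=1.0))   # e^-1, 1, e^-2 for ranks 1, 0, 2
print(neural_gas_weights(d, gamma=0.01))  # only the winner keeps a non-negligible weight
```

As gamma shrinks toward gamma_final, the weighting approaches winner-takes-all, so cooperation among prototypes fades out over training.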

Parameters:
  • gamma_init (float, optional) – Initial neighborhood range. Default: n_prototypes / 2.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step multiplicative decay for gamma. Default: computed from max_iter so gamma reaches gamma_final.

  • n_prototypes (int) – Number of prototypes for the target class.

  • target_label (int, optional) – Target (normal) class label. Default: auto-detect.

  • beta (float) – Sigmoid steepness. Default: 10.0.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6: synchronize the slow weights every k steps) and ‘slow_step_size’ (float, default 0.5: interpolation factor). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed-precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32 while the forward and backward passes run in the lower precision; on GPU this can give roughly 2x speed and about half the memory use. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
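The ‘balanced’ option for class_weight is described only as "inversely proportional to class frequencies". A minimal sketch, assuming prosemble follows the scikit-learn convention \(w_c = n_{samples} / (n_{classes} \cdot n_c)\); the helper name balanced_class_weights is hypothetical and the library's exact normalization may differ.

```python
import numpy as np

def balanced_class_weights(y):
    """Hypothetical helper: per-class weights inversely proportional to class frequency.

    Assumes the scikit-learn convention w_c = n_samples / (n_classes * n_c);
    prosemble's exact normalization may differ.
    """
    y = np.asarray(y)
    classes, counts = np.unique(y, return_counts=True)
    weights = len(y) / (len(classes) * counts)
    return {int(c): float(w) for c, w in zip(classes, weights)}

# Class 0 is twice as frequent as classes 1 and 2, so it gets half their weight.
y = np.array([0, 0, 0, 0, 1, 1, 2, 2])
print(balanced_class_weights(y))
```

Under this convention every class contributes the same total weight, so the weights summed over all samples equal n_samples.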

gamma_

Final gamma value after training.

Type:

float

decision_function(X)

Compute target-likeness scores.

Scores near 1.0 indicate the target class; scores near 0.0 indicate an outlier. The decision boundary is at score = 0.5, where the distance \(d\) equals the threshold \(\theta\).

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

scores

Return type:

array of shape (n_samples,)
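The docstring fixes only the boundary (score 0.5 at \(d = \theta\)). One scoring rule with that property is a logistic sigmoid of \(\beta(\theta - d)\). The minimal sketch below assumes that \(d\) is the squared Euclidean distance to the nearest prototype and that \(\theta\) is a single scalar threshold. The function name target_scores and the argument theta are hypothetical, not prosemble's API.

```python
import numpy as np

def target_scores(X, prototypes, theta, beta=10.0):
    """Sketch of a one-class score: sigmoid(beta * (theta - d_min)).

    Assumptions (not confirmed by the docs): d_min is the squared Euclidean
    distance to the nearest prototype and theta is a scalar threshold.
    """
    X = np.asarray(X, dtype=float)
    diffs = X[:, None, :] - prototypes[None, :, :]   # (n_samples, n_prototypes, n_features)
    d = np.sum(diffs ** 2, axis=-1)                   # squared Euclidean distances
    d_min = d.min(axis=1)                             # nearest-prototype distance
    return 1.0 / (1.0 + np.exp(-beta * (theta - d_min)))

prototypes = np.array([[0.0, 0.0], [3.0, 3.0]])
# Close to a prototype, exactly on the boundary (d_min == theta), and far away.
X = np.array([[0.1, 0.0], [1.0, 0.0], [0.0, 2.5]])
scores = target_scores(X, prototypes, theta=1.0)
```

With this form, beta controls how sharply the score moves from 1 to 0 as a sample crosses the threshold.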

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number that n_prototypes_per_class would produce.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict(X)

Predict target or non-target labels.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)

Predict with a reject option for uncertain samples.

Samples with scores in [lower, upper) are rejected.

Parameters:
  • X (array-like of shape (n_samples, n_features))

  • upper (float) – Scores >= upper are classified as target. Default: 0.5.

  • lower (float, optional) – Scores < lower are classified as non-target. Default: same as upper (no rejection zone).

  • reject_label (int) – Label for rejected samples. Default: -1.

Returns:

labels

Return type:

array of shape (n_samples,)
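The reject rule above can be applied to precomputed scores as follows. This is a minimal sketch. The target and non-target label values 1 and 0 are hypothetical placeholders, because the docs do not say which labels predict returns.

```python
import numpy as np

def labels_with_reject(scores, upper=0.5, lower=None, reject_label=-1,
                       target=1, non_target=0):
    """Sketch of the documented reject rule applied to precomputed scores.

    scores >= upper          -> target
    scores <  lower          -> non_target
    lower <= scores < upper  -> reject_label
    The target/non_target values 1/0 are hypothetical placeholders.
    """
    scores = np.asarray(scores, dtype=float)
    if lower is None:
        lower = upper  # empty interval [upper, upper): nothing is rejected
    labels = np.full(scores.shape, reject_label, dtype=int)
    labels[scores >= upper] = target
    labels[scores < lower] = non_target
    return labels

scores = np.array([0.95, 0.6, 0.4, 0.1])
print(labels_with_reject(scores).tolist())                        # [1, 1, 0, 0]
print(labels_with_reject(scores, upper=0.7, lower=0.3).tolist())  # [1, -1, -1, 0]
```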

class prosemble.models.OCGRLVQ_NG(gamma_init=None, gamma_final=0.01, gamma_decay=None, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)

One-Class GRLVQ with Neural Gas neighborhood cooperation.

Learns per-feature relevance weights with NG rank-weighted loss.
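Neural Gas cooperation ranks the prototypes by distance for each sample and weights them with \(h(k) = \exp(-k/\gamma)\), where rank \(k = 0\) is the closest prototype. The minimal sketch below computes these weights. Normalizing them to sum to 1 per sample is an assumption made here; the docs do not say how prosemble combines them into the loss.

```python
import numpy as np

def ng_rank_weights(distances, gamma):
    """Neural Gas neighborhood weights h(k) = exp(-k / gamma).

    distances: (n_samples, n_prototypes). Each prototype's rank k (0 = closest)
    is computed per sample. Normalizing the weights to sum to 1 per sample
    is an assumption made for this sketch.
    """
    ranks = np.argsort(np.argsort(distances, axis=1), axis=1)  # rank of each prototype
    h = np.exp(-ranks / gamma)
    return h / h.sum(axis=1, keepdims=True)

d = np.array([[0.5, 2.0, 1.0]])           # prototype 0 closest, then 2, then 1
print(ng_rank_weights(d, gamma=1.5).round(3))   # all prototypes share the update
print(ng_rank_weights(d, gamma=0.01).round(3))  # effectively winner-take-all
```

A large gamma spreads the update over all prototypes. As gamma shrinks toward a small gamma_final, only the closest prototype effectively learns.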

Parameters:
  • gamma_init (float, optional) – Initial neighborhood range. Default: n_prototypes / 2.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step multiplicative decay for gamma. Default: computed from max_iter so gamma reaches gamma_final.

  • n_prototypes (int) – Number of prototypes for the target class.

  • target_label (int, optional) – Target (normal) class label. Default: auto-detect.

  • beta (float) – Sigmoid steepness. Default: 10.0.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6: synchronize the slow weights every k steps) and ‘slow_step_size’ (float, default 0.5: interpolation factor). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed-precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32 while the forward and backward passes run in the lower precision; on GPU this can give roughly 2x speed and about half the memory use. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
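The gamma parameters describe a multiplicative schedule. The minimal sketch below assumes \(\gamma_t = \gamma_{init} \cdot \text{decay}^t\), with the default decay chosen so that \(\gamma\) reaches gamma_final after max_iter steps. This is a geometric schedule under our own assumption. The function name default_gamma_schedule is hypothetical, and prosemble's exact step count (for example, per epoch versus per mini-batch) may differ.

```python
def default_gamma_schedule(n_prototypes=3, gamma_init=None, gamma_final=0.01,
                           gamma_decay=None, max_iter=100):
    """Sketch of a geometric gamma schedule gamma_t = gamma_init * gamma_decay**t.

    Assumes the default decay is chosen so that gamma_init * decay**max_iter
    equals gamma_final; prosemble's exact step count may differ.
    """
    if gamma_init is None:
        gamma_init = n_prototypes / 2          # documented default
    if gamma_decay is None:
        # Solve gamma_init * decay**max_iter == gamma_final for decay.
        gamma_decay = (gamma_final / gamma_init) ** (1.0 / max_iter)
    return [gamma_init * gamma_decay ** t for t in range(max_iter + 1)]

gammas = default_gamma_schedule()
print(gammas[0], gammas[-1])   # starts at n_prototypes / 2, ends near gamma_final
```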

relevances_

Learned per-feature relevance weights.

Type:

array of shape (n_features,)

gamma_

Final gamma value after training.

Type:

float

decision_function(X)

Compute scores using relevance-weighted distances.
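In GRLVQ, the relevance-weighted distance scales each squared feature difference by its relevance, \(d_\lambda(x, w) = \sum_j \lambda_j (x_j - w_j)^2\). The minimal sketch below assumes that relevances_ plays the role of \(\lambda\). Whether prosemble normalizes \(\lambda\) to sum to 1 is not stated here. The function name relevance_distance is hypothetical.

```python
import numpy as np

def relevance_distance(X, prototypes, relevances):
    """Relevance-weighted squared Euclidean distance, as in GRLVQ:
    d(x, w) = sum_j lambda_j * (x_j - w_j)**2.
    Returns shape (n_samples, n_prototypes).
    """
    diffs = X[:, None, :] - prototypes[None, :, :]          # (s, p, f)
    return np.einsum("spf,f->sp", diffs ** 2, relevances)    # weight each feature

X = np.array([[1.0, 5.0]])
W = np.array([[0.0, 0.0]])
lam = np.array([1.0, 0.0])   # second feature judged irrelevant
print(relevance_distance(X, W, lam))   # only feature 0 contributes
```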

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number that n_prototypes_per_class would produce.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict(X)

Predict target or non-target labels.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)

Predict with a reject option for uncertain samples.

Samples with scores in [lower, upper) are rejected.

Parameters:
  • X (array-like of shape (n_samples, n_features))

  • upper (float) – Scores >= upper are classified as target. Default: 0.5.

  • lower (float, optional) – Scores < lower are classified as non-target. Default: same as upper (no rejection zone).

  • reject_label (int) – Label for rejected samples. Default: -1.

Returns:

labels

Return type:

array of shape (n_samples,)

class prosemble.models.OCGMLVQ_NG(gamma_init=None, gamma_final=0.01, gamma_decay=None, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)

One-Class GMLVQ with Neural Gas neighborhood cooperation.

Learns a global Omega projection with NG rank-weighted loss.

Parameters:
  • latent_dim (int, optional) – Dimensionality of the projected space. Default: n_features.

  • gamma_init (float, optional) – Initial neighborhood range. Default: n_prototypes / 2.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step multiplicative decay for gamma. Default: computed from max_iter so gamma reaches gamma_final.

  • n_prototypes (int) – Number of prototypes for the target class.

  • target_label (int, optional) – Target (normal) class label. Default: auto-detect.

  • beta (float) – Sigmoid steepness. Default: 10.0.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6: synchronize the slow weights every k steps) and ‘slow_step_size’ (float, default 0.5: interpolation factor). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed-precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32 while the forward and backward passes run in the lower precision; on GPU this can give roughly 2x speed and about half the memory use. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
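The lookahead option is documented only through its two keys. The minimal sketch below shows the Lookahead update rule those keys control. Here it wraps plain gradient descent on the toy loss \(\|w\|^2/2\). Using plain gradient descent as the inner optimizer and the function name lookahead_sgd are assumptions for illustration; prosemble wraps its configured optimizer instead.

```python
import numpy as np

def lookahead_sgd(grad_fn, w0, lr=0.1, n_steps=24, sync_period=6, slow_step_size=0.5):
    """Sketch of the Lookahead wrapper around plain gradient descent.

    The fast weights take `sync_period` inner steps; then the slow weights move
    a fraction `slow_step_size` toward them and the fast weights are reset.
    The inner optimizer (plain GD here) is an assumption of this sketch.
    """
    slow = np.array(w0, dtype=float)
    fast = slow.copy()
    for step in range(1, n_steps + 1):
        fast = fast - lr * grad_fn(fast)            # inner (fast) update
        if step % sync_period == 0:                 # synchronization step
            slow = slow + slow_step_size * (fast - slow)
            fast = slow.copy()
    return slow

# Minimize f(w) = ||w||^2 / 2, whose gradient is w.
w = lookahead_sgd(lambda w: w, [4.0, -2.0])
print(w)
```

The slow weights only move at synchronization steps. Each time, they move halfway (slow_step_size=0.5) toward the point the fast weights reached, which damps oscillations of the inner optimizer.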

omega_

Learned projection matrix.

Type:

array of shape (n_features, latent_dim)

gamma_

Final gamma value after training.

Type:

float

decision_function(X)

Compute scores using Omega-projected distances.
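GMLVQ measures distances after projecting the difference vector with \(\Omega\): \(d(x, w) = \|(x - w)\,\Omega\|^2 = (x - w)^\top \Omega\Omega^\top (x - w)\). The minimal sketch below uses the documented omega_ shape (n_features, latent_dim). The function name omega_distance is hypothetical.

```python
import numpy as np

def omega_distance(X, prototypes, omega):
    """Squared distance in the projected space, GMLVQ-style:
    d(x, w) = ||(x - w) @ omega||^2 = (x - w)^T (omega omega^T) (x - w),
    with omega of shape (n_features, latent_dim) as documented for omega_.
    """
    diffs = X[:, None, :] - prototypes[None, :, :]   # (n_samples, n_protos, n_features)
    projected = diffs @ omega                          # (n_samples, n_protos, latent_dim)
    return np.sum(projected ** 2, axis=-1)

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 3))
W = rng.normal(size=(2, 3))
omega = rng.normal(size=(3, 2))            # projects 3 features to latent_dim=2
D = omega_distance(X, W, omega)
```

With omega set to the identity matrix, this reduces to the squared Euclidean distance.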

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number that n_prototypes_per_class would produce.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict(X)

Predict target or non-target labels.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)

Predict with a reject option for uncertain samples.

Samples with scores in [lower, upper) are rejected.

Parameters:
  • X (array-like of shape (n_samples, n_features))

  • upper (float) – Scores >= upper are classified as target. Default: 0.5.

  • lower (float, optional) – Scores < lower are classified as non-target. Default: same as upper (no rejection zone).

  • reject_label (int) – Label for rejected samples. Default: -1.

Returns:

labels

Return type:

array of shape (n_samples,)

class prosemble.models.OCLGMLVQ_NG(gamma_init=None, gamma_final=0.01, gamma_decay=None, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)

One-Class LGMLVQ with Neural Gas neighborhood cooperation.

Learns per-prototype local Omega projections with NG rank-weighted loss.

Parameters:
  • latent_dim (int, optional) – Dimensionality of each projected space. Default: n_features.

  • gamma_init (float, optional) – Initial neighborhood range. Default: n_prototypes / 2.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step multiplicative decay for gamma. Default: computed from max_iter so gamma reaches gamma_final.

  • n_prototypes (int) – Number of prototypes for the target class.

  • target_label (int, optional) – Target (normal) class label. Default: auto-detect.

  • beta (float) – Sigmoid steepness. Default: 10.0.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6: synchronize the slow weights every k steps) and ‘slow_step_size’ (float, default 0.5: interpolation factor). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed-precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32 while the forward and backward passes run in the lower precision; on GPU this can give roughly 2x speed and about half the memory use. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
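The ema_decay option keeps a running average \(\bar\theta \leftarrow \text{decay}\cdot\bar\theta + (1 - \text{decay})\cdot\theta\) of the parameters and swaps it in after training. The minimal sketch below initializes the average at the first value without bias correction, which is an assumption. The function name ema_track is hypothetical. The example uses decay=0.9 rather than the typical 0.999 so that the smoothing shows up within 200 steps.

```python
import numpy as np

def ema_track(param_history, decay=0.9):
    """Exponential moving average of a parameter trajectory:
    ema <- decay * ema + (1 - decay) * param, initialized at the first value.
    Initialization without bias correction is an assumption of this sketch.
    """
    ema = np.array(param_history[0], dtype=float)
    for p in param_history[1:]:
        ema = decay * ema + (1.0 - decay) * np.asarray(p, dtype=float)
    return ema

# A parameter oscillating between 0.5 and 1.5 is smoothed toward its mean 1.0.
history = [np.array([1.0 + (0.5 if t % 2 else -0.5)]) for t in range(200)]
smoothed = ema_track(history, decay=0.9)
```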

omegas_

Learned per-prototype projection matrices.

Type:

array of shape (n_prototypes, n_features, latent_dim)

gamma_

Final gamma value after training.

Type:

float

decision_function(X)

Compute scores using per-prototype Omega-projected distances.
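Localized GMLVQ gives each prototype its own projection, so the distance to prototype \(k\) is \(\|(x - w_k)\,\Omega_k\|^2\). The minimal sketch below follows the documented omegas_ shape (n_prototypes, n_features, latent_dim). The function name local_omega_distance is hypothetical.

```python
import numpy as np

def local_omega_distance(X, prototypes, omegas):
    """LGMLVQ-style distance: prototype k uses its own projection omegas[k]:
    d_k(x) = ||(x - w_k) @ omegas[k]||^2.
    omegas has shape (n_prototypes, n_features, latent_dim) as documented.
    """
    diffs = X[:, None, :] - prototypes[None, :, :]              # (s, k, f)
    projected = np.einsum("skf,kfl->skl", diffs, omegas)       # (s, k, l)
    return np.sum(projected ** 2, axis=-1)                      # (s, k)

rng = np.random.default_rng(1)
X = rng.normal(size=(5, 3))
W = rng.normal(size=(2, 3))
omegas = rng.normal(size=(2, 3, 2))     # one 3 -> 2 projection per prototype
D = local_omega_distance(X, W, omegas)
```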

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number that n_prototypes_per_class would produce.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict(X)

Predict target or non-target labels.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)

Predict with a reject option for uncertain samples.

Samples with scores in [lower, upper) are rejected.

Parameters:
  • X (array-like of shape (n_samples, n_features))

  • upper (float) – Scores >= upper are classified as target. Default: 0.5.

  • lower (float, optional) – Scores < lower are classified as non-target. Default: same as upper (no rejection zone).

  • reject_label (int) – Label for rejected samples. Default: -1.

Returns:

labels

Return type:

array of shape (n_samples,)

class prosemble.models.OCGTLVQ_NG(gamma_init=None, gamma_final=0.01, gamma_decay=None, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)

One-Class GTLVQ with Neural Gas neighborhood cooperation.

Learns per-prototype tangent subspaces with NG rank-weighted loss.

Parameters:
  • subspace_dim (int) – Dimensionality of each tangent subspace. Default: 2.

  • gamma_init (float, optional) – Initial neighborhood range. Default: n_prototypes / 2.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step multiplicative decay for gamma. Default: computed from max_iter so gamma reaches gamma_final.

  • n_prototypes (int) – Number of prototypes for the target class.

  • target_label (int, optional) – Target (normal) class label. Default: auto-detect.

  • beta (float) – Sigmoid steepness. Default: 10.0.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6: synchronize the slow weights every k steps) and ‘slow_step_size’ (float, default 0.5: interpolation factor). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed-precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32 while the forward and backward passes run in the lower precision; on GPU this can give roughly 2x speed and about half the memory use. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
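The effective batch size of batch_size * gradient_accumulation_steps holds exactly when the accumulated gradients of a mean loss are averaged over equal-sized micro-batches. Averaging, rather than summing, is an assumption about prosemble. The minimal sketch below checks this identity on a toy least-squares loss. The names mse_grad, batch_size and accum_steps are illustrative only.

```python
import numpy as np

def mse_grad(w, X, y):
    """Gradient of the mean squared error 0.5 * mean((X @ w - y)**2) w.r.t. w."""
    return X.T @ (X @ w - y) / len(y)

rng = np.random.default_rng(0)
X = rng.normal(size=(32, 3))
y = rng.normal(size=32)
w = np.zeros(3)

batch_size, accum_steps = 8, 4        # effective batch size 8 * 4 = 32
grads = [mse_grad(w, X[i:i + batch_size], y[i:i + batch_size])
         for i in range(0, batch_size * accum_steps, batch_size)]
accumulated = np.mean(grads, axis=0)  # one optimizer update would use this average
full_batch = mse_grad(w, X, y)        # gradient over all 32 samples at once
```

Accumulation trades wall-clock time for memory. Only batch_size samples need to be held at once, yet each update sees the same gradient as a single batch of 32.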

omegas_

Learned per-prototype orthonormal tangent bases.

Type:

array of shape (n_prototypes, n_features, subspace_dim)

gamma_

Final gamma value after training.

Type:

float

decision_function(X)

Compute scores using tangent distances.
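A tangent distance ignores any displacement that lies inside the prototype's tangent subspace. With an orthonormal basis \(U_k\), it is \(\|(I - U_k U_k^\top)(x - w_k)\|^2\). The minimal sketch below uses the documented omegas_ shape (n_prototypes, n_features, subspace_dim). The function name tangent_distance is hypothetical.

```python
import numpy as np

def tangent_distance(X, prototypes, bases):
    """GTLVQ-style tangent distance: squared norm of the part of (x - w_k)
    that lies outside prototype k's tangent subspace span(bases[k]).
    bases: (n_prototypes, n_features, subspace_dim) with orthonormal columns.
    """
    diffs = X[:, None, :] - prototypes[None, :, :]              # (s, k, f)
    coords = np.einsum("skf,kfd->skd", diffs, bases)           # subspace coordinates
    in_subspace = np.einsum("skd,kfd->skf", coords, bases)     # projection onto subspace
    residual = diffs - in_subspace                              # orthogonal remainder
    return np.sum(residual ** 2, axis=-1)

# One prototype at the origin whose tangent subspace is the x-axis in 3-D.
W = np.zeros((1, 3))
bases = np.array([[[1.0], [0.0], [0.0]]])                      # (1, 3, 1)
X = np.array([[5.0, 0.0, 0.0], [5.0, 2.0, 0.0]])
print(tangent_distance(X, W, bases))   # first point lies in the subspace
```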

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number that n_prototypes_per_class would produce.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict(X)

Predict target or non-target labels.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)

Predict with a reject option for uncertain samples.

Samples with scores in [lower, upper) are rejected.

Parameters:
  • X (array-like of shape (n_samples, n_features))

  • upper (float) – Scores >= upper are classified as target. Default: 0.5.

  • lower (float, optional) – Scores < lower are classified as non-target. Default: same as upper (no rejection zone).

  • reject_label (int) – Label for rejected samples. Default: -1.

Returns:

labels

Return type:

array of shape (n_samples,)

One-Class RSLVQ

class prosemble.models.OCRSLVQ(sigma=1.0, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)

One-Class Robust Soft LVQ.

Combines one-class threshold detection with probabilistic soft-weighting of all prototypes via Gaussian mixture responsibilities.

All prototypes contribute to the one-class decision via Gaussian proximity weights computed from standard (unprojected) distances, squared Euclidean by default.
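To make the soft-weighting idea concrete, here is a minimal NumPy sketch of one way Gaussian responsibilities and per-prototype thresholds could combine into a score in [0, 1]. The exact formula is not given in this reference, so the sketch assumes responsibilities p_k(x) proportional to exp(-d_k / (2 sigma^2)) and a per-prototype vote sigmoid(beta * (theta_k - d_k)), averaged with weights p_k(x). The helper name and the threshold values are hypothetical.

```python
import numpy as np

def soft_one_class_scores(X, prototypes, thetas, sigma=1.0, beta=10.0):
    """Hypothetical soft-weighted one-class score in [0, 1] (assumed formula)."""
    X = np.asarray(X, dtype=float)
    # Squared Euclidean distances, shape (n_samples, n_prototypes)
    d = ((X[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)
    # Gaussian responsibilities p_k(x): softmax of -d / (2 sigma^2), shifted for stability
    logits = -d / (2.0 * sigma ** 2)
    logits = logits - logits.max(axis=1, keepdims=True)
    p = np.exp(logits)
    p = p / p.sum(axis=1, keepdims=True)
    # Per-prototype vote sigmoid(beta * (theta_k - d_k)), written via tanh to avoid overflow
    votes = 0.5 * (1.0 + np.tanh(0.5 * beta * (thetas[None, :] - d)))
    return (p * votes).sum(axis=1)

prototypes = np.array([[0.0, 0.0], [5.0, 0.0]])
thetas = np.array([1.0, 1.0])  # hypothetical learned thresholds (cf. thetas_)
scores = soft_one_class_scores([[0.1, 0.0], [10.0, 10.0]], prototypes, thetas)
# scores[0] is close to 1.0 (inside a threshold), scores[1] close to 0.0 (outlier)
```

Under this assumed form, sigma controls how many prototypes share responsibility for a sample, while beta sets how sharply each threshold separates accepted from rejected distances.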

Parameters:
  • sigma (float) – Bandwidth of Gaussian mixture for prototype weighting.

  • n_prototypes (int) – Number of prototypes for the target class. Default: 3.

  • target_label (int, optional) – Target (normal) class label. Default: auto-detect.

  • beta (float) – Sigmoid steepness. Default: 10.0.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed-precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32 while the forward/backward pass runs in lower precision, which can give up to ~2x speedup and roughly half the memory use on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
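class_weight='balanced' is described only as inversely proportional to class frequencies. A common concrete choice, used by scikit-learn, is w_c = n_samples / (n_classes * n_c); the sketch below assumes that convention, and prosemble may normalize differently. The helper name balanced_class_weights is hypothetical.

```python
import numpy as np

def balanced_class_weights(y):
    """Hypothetical helper: w_c = n_samples / (n_classes * count_c).

    This is scikit-learn's 'balanced' convention; prosemble's exact
    normalization is an assumption here.
    """
    y = np.asarray(y)
    classes, counts = np.unique(y, return_counts=True)
    weights = len(y) / (len(classes) * counts)
    return {c.item(): float(w) for c, w in zip(classes, weights)}

# Three normal samples and one outlier: the rare class is up-weighted
weights = balanced_class_weights([0, 0, 0, 1])
# weights[0] is 2/3 and weights[1] is 2.0
```

The result has the same form as an explicit dict such as {0: 1.0, 1: 2.0}, so it can be inspected before being passed as class_weight.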

thetas_

Learned per-prototype acceptance thresholds.

Type:

array of shape (n_prototypes,)

decision_function(X)

Compute target-likeness scores using soft-weighted distances.

Scores near 1.0 indicate target class, near 0.0 indicate outlier.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

scores

Return type:

array of shape (n_samples,)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number that n_prototypes_per_class would produce.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted params) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict(X)

Predict target or non-target labels.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)

Predict with a reject option for uncertain samples.

Samples with scores in [lower, upper) are rejected.

Parameters:
  • X (array-like of shape (n_samples, n_features))

  • upper (float) – Scores >= upper are classified as target. Default: 0.5.

  • lower (float, optional) – Scores < lower are classified as non-target. Default: same as upper (no rejection zone).

  • reject_label (int) – Label for rejected samples. Default: -1.

Returns:

labels

Return type:

array of shape (n_samples,)

class prosemble.models.OCMRSLVQ(sigma=1.0, latent_dim=None, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)

One-Class Matrix Robust Soft LVQ.

Combines one-class threshold detection with a learned global Omega projection matrix and probabilistic soft-weighting of all prototypes.

All prototypes contribute to the one-class decision via Gaussian proximity weights, with distances computed in the Omega-projected space.
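Distances in the Omega-projected space are usually defined, as in matrix LVQ (GMLVQ), by \(d(x, w) = \lVert (x - w)\Omega \rVert^2\), which equals \((x - w)\Lambda(x - w)^\top\) with \(\Lambda = \Omega\Omega^\top\). The NumPy sketch below assumes that definition and an Omega of shape (n_features, latent_dim), matching the omega_ attribute; the helper name is hypothetical.

```python
import numpy as np

def omega_sq_distances(X, prototypes, omega):
    """Hypothetical helper: d(x, w) = ||(x - w) @ omega||^2 for every sample/prototype pair.

    X: (n_samples, n_features); prototypes: (n_prototypes, n_features);
    omega: (n_features, latent_dim), the same shape as the omega_ attribute.
    """
    diff = X[:, None, :] - prototypes[None, :, :]  # (n_samples, n_prototypes, n_features)
    proj = diff @ omega                             # (n_samples, n_prototypes, latent_dim)
    return (proj ** 2).sum(axis=-1)                 # (n_samples, n_prototypes)

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 3))
W = rng.normal(size=(2, 3))
omega = rng.normal(size=(3, 2))  # project 3 features down to latent_dim = 2
d = omega_sq_distances(X, W, omega)
```

With omega set to the identity matrix the sketch reduces to plain squared Euclidean distance, which is a quick sanity check when comparing against the non-matrix variants.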

Parameters:
  • sigma (float) – Bandwidth of Gaussian mixture for prototype weighting.

  • latent_dim (int, optional) – Dimensionality of the projected space. Default: n_features.

  • n_prototypes (int) – Number of prototypes for the target class. Default: 3.

  • target_label (int, optional) – Target (normal) class label. Default: auto-detect.

  • beta (float) – Sigmoid steepness. Default: 10.0.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed-precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32 while the forward/backward pass runs in lower precision, which can give up to ~2x speedup and roughly half the memory use on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
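ema_decay keeps an exponential moving average of the parameters. The standard update rule is sketched below in NumPy; whether prosemble applies it per optimizer step or per epoch, and how it initializes the average, are assumptions not covered by this reference. The helper name ema_update is hypothetical.

```python
import numpy as np

def ema_update(ema_params, params, decay):
    """One EMA step: ema <- decay * ema + (1 - decay) * params."""
    return decay * ema_params + (1.0 - decay) * params

# Track parameters that jump from 0 to 1 and then stay fixed
ema = np.zeros(2)
params = np.ones(2)
for _ in range(10):
    ema = ema_update(ema, params, decay=0.9)
# After n updates the average has closed a fraction 1 - decay**n of the gap
```

This also shows why decays such as 0.999 suit long runs: the average responds slowly, effectively spanning on the order of 1/(1 - decay) updates, so with few iterations it stays close to its starting value.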

omega_

Learned projection matrix.

Type:

array of shape (n_features, latent_dim)

thetas_

Learned per-prototype acceptance thresholds.

Type:

array of shape (n_prototypes,)

decision_function(X)

Compute target-likeness scores using soft-weighted Omega distances.

Scores near 1.0 indicate target class, near 0.0 indicate outlier.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

scores

Return type:

array of shape (n_samples,)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number that n_prototypes_per_class would produce.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted params) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict(X)

Predict target or non-target labels.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)

Predict with a reject option for uncertain samples.

Samples with scores in [lower, upper) are rejected.

Parameters:
  • X (array-like of shape (n_samples, n_features))

  • upper (float) – Scores >= upper are classified as target. Default: 0.5.

  • lower (float, optional) – Scores < lower are classified as non-target. Default: same as upper (no rejection zone).

  • reject_label (int) – Label for rejected samples. Default: -1.

Returns:

labels

Return type:

array of shape (n_samples,)

class prosemble.models.OCLMRSLVQ(sigma=1.0, latent_dim=None, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)

One-Class Localized Matrix Robust Soft LVQ.

Each prototype \(k\) has its own \(\Omega_k\) matrix for local metric adaptation, combined with probabilistic soft-weighting and one-class threshold detection.
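With one matrix per prototype, the distance to prototype k becomes \(d_k(x) = \lVert (x - w_k)\Omega_k \rVert^2\), assuming the same convention as the global-Omega case. A NumPy sketch using a stacked array of shape (n_prototypes, n_features, latent_dim), matching the omegas_ attribute; the helper name is hypothetical.

```python
import numpy as np

def local_omega_sq_distances(X, prototypes, omegas):
    """Hypothetical helper: d_k(x) = ||(x - w_k) @ omegas[k]||^2, one matrix per prototype.

    omegas: (n_prototypes, n_features, latent_dim), the same shape as omegas_.
    """
    diff = X[:, None, :] - prototypes[None, :, :]   # (n_samples, n_prototypes, n_features)
    proj = np.einsum("nkd,kdl->nkl", diff, omegas)   # project each difference with its own Omega_k
    return (proj ** 2).sum(axis=-1)                   # (n_samples, n_prototypes)

rng = np.random.default_rng(1)
X = rng.normal(size=(4, 3))
W = rng.normal(size=(2, 3))
omegas = rng.normal(size=(2, 3, 2))
d = local_omega_sq_distances(X, W, omegas)
```

Because each prototype carries its own matrix, the parameter count grows with n_prototypes * n_features * latent_dim, which is one reason to keep latent_dim small.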

Parameters:
  • sigma (float) – Bandwidth of Gaussian mixture for prototype weighting.

  • latent_dim (int, optional) – Latent space dimensionality per prototype. Default: n_features.

  • n_prototypes (int) – Number of prototypes for the target class. Default: 3.

  • target_label (int, optional) – Target (normal) class label. Default: auto-detect.

  • beta (float) – Sigmoid steepness. Default: 10.0.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed-precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32 while the forward/backward pass runs in lower precision, which can give up to ~2x speedup and roughly half the memory use on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
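The two lookahead keys follow the Lookahead scheme: an inner optimizer advances "fast" weights for sync_period steps, then the "slow" weights move toward them by slow_step_size and the fast weights restart from the slow ones. The sketch below wraps plain gradient descent in NumPy to show only these mechanics, corresponding to lookahead={'sync_period': 6, 'slow_step_size': 0.5}; prosemble presumably wraps the configured optimizer instead, and the helper name lookahead_sgd is hypothetical.

```python
import numpy as np

def lookahead_sgd(grad_fn, w0, lr=0.1, n_steps=60, sync_period=6, slow_step_size=0.5):
    """Hypothetical helper: Lookahead around plain gradient descent."""
    slow = np.asarray(w0, dtype=float)
    fast = slow.copy()
    for step in range(1, n_steps + 1):
        fast = fast - lr * grad_fn(fast)                  # inner (fast) optimizer step
        if step % sync_period == 0:                       # every sync_period steps:
            slow = slow + slow_step_size * (fast - slow)  # move slow weights toward fast ones
            fast = slow.copy()                            # and restart fast weights from them
    return slow

def grad(w):
    # Gradient of f(w) = ||w - 3||^2
    return 2.0 * (w - 3.0)

w = lookahead_sgd(grad, w0=[0.0, 0.0])
```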

omegas_

Learned per-prototype projection matrices.

Type:

array of shape (n_prototypes, n_features, latent_dim)

thetas_

Learned per-prototype acceptance thresholds.

Type:

array of shape (n_prototypes,)

decision_function(X)

Compute target-likeness scores using local Omega distances.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

scores

Return type:

array of shape (n_samples,)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number that n_prototypes_per_class would produce.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted params) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict(X)

Predict target or non-target labels.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)

Predict with a reject option for uncertain samples.

Samples with scores in [lower, upper) are rejected.

Parameters:
  • X (array-like of shape (n_samples, n_features))

  • upper (float) – Scores >= upper are classified as target. Default: 0.5.

  • lower (float, optional) – Scores < lower are classified as non-target. Default: same as upper (no rejection zone).

  • reject_label (int) – Label for rejected samples. Default: -1.

Returns:

labels

Return type:

array of shape (n_samples,)

One-Class RSLVQ with Neural Gas

class prosemble.models.OCRSLVQ_NG(sigma=1.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)

One-Class RSLVQ with Neural Gas neighborhood cooperation.

Combines soft Gaussian mixture responsibilities with Neural Gas (NG) rank-based cooperation. Uses standard (unprojected) distances, squared Euclidean by default.
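Neural Gas cooperation typically weights the prototype of rank r (0 = closest) by exp(-r / gamma), so a large gamma lets several prototypes share each update and a small gamma approaches winner-take-all. How prosemble combines these rank weights with the Gaussian responsibilities is not specified here; the sketch assumes a product followed by per-sample renormalization, and the helper names are hypothetical.

```python
import numpy as np

def ng_weights(d, gamma):
    """Neural Gas rank weights: the prototype of rank r (0 = closest) gets exp(-r / gamma)."""
    ranks = np.argsort(np.argsort(d, axis=1), axis=1)
    return np.exp(-ranks / gamma)

def combined_weights(d, sigma, gamma):
    """Hypothetical combination: Gaussian responsibilities times NG rank weights, renormalized."""
    logits = -d / (2.0 * sigma ** 2)
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    w = np.exp(logits) * ng_weights(d, gamma)
    return w / w.sum(axis=1, keepdims=True)

d = np.array([[1.0, 4.0, 2.0]])                      # squared distances of one sample to 3 prototypes
wide = combined_weights(d, sigma=1.0, gamma=1.5)     # early training: several prototypes share weight
narrow = combined_weights(d, sigma=1.0, gamma=0.01)  # late training: weight concentrates on the winner
```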

Parameters:
  • sigma (float) – Bandwidth of Gaussian mixture for prototype weighting.

  • gamma_init (float, optional) – Initial neighborhood range. Default: n_prototypes / 2.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step multiplicative decay for gamma. Default: computed from max_iter so gamma reaches gamma_final.

  • n_prototypes (int) – Number of prototypes for the target class. Default: 3.

  • target_label (int, optional) – Target (normal) class label. Default: auto-detect.

  • beta (float) – Sigmoid steepness. Default: 10.0.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6; sync every k steps) and ‘slow_step_size’ (float, default 0.5; interpolation factor). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed-precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32 while the forward/backward pass runs in lower precision, which can give up to ~2x speedup and roughly half the memory use on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
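The default gamma_decay follows from the description: a per-step multiplicative decay that takes gamma_init to gamma_final in max_iter steps must satisfy gamma_init * gamma_decay**max_iter = gamma_final, so gamma_decay = (gamma_final / gamma_init)**(1 / max_iter). The sketch below computes it under the assumption that one step means one training iteration; the helper name is hypothetical.

```python
import numpy as np

def default_gamma_schedule(n_prototypes, max_iter, gamma_final=0.01, gamma_init=None):
    """Hypothetical helper: geometric decay so that gamma_init * decay**max_iter == gamma_final."""
    if gamma_init is None:
        gamma_init = n_prototypes / 2  # documented default for gamma_init
    decay = (gamma_final / gamma_init) ** (1.0 / max_iter)
    return gamma_init, decay

gamma_init, decay = default_gamma_schedule(n_prototypes=3, max_iter=100)
gammas = gamma_init * decay ** np.arange(101)  # gamma at steps 0..max_iter
# gamma_init == 1.5, decay is about 0.951, gammas[-1] is about 0.01
```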

thetas_

Learned per-prototype acceptance thresholds.

Type:

array of shape (n_prototypes,)

gamma_

Final gamma value after training.

Type:

float

decision_function(X)

Compute target-likeness scores using combined Gaussian+NG weights.

Uses final (converged) gamma for NG modulation at inference time.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

scores – Scores near 1.0 indicate target class, near 0.0 indicate outlier.

Return type:

array of shape (n_samples,)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number that n_prototypes_per_class would produce.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted params) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict(X)

Predict target or non-target labels.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)

Predict with a reject option for uncertain samples.

Samples with scores in [lower, upper) are rejected.

Parameters:
  • X (array-like of shape (n_samples, n_features))

  • upper (float) – Scores >= upper are classified as target. Default: 0.5.

  • lower (float, optional) – Scores < lower are classified as non-target. Default: same as upper (no rejection zone).

  • reject_label (int) – Label for rejected samples. Default: -1.

Returns:

labels

Return type:

array of shape (n_samples,)

class prosemble.models.OCMRSLVQ_NG(sigma=1.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)

One-Class Matrix RSLVQ with Neural Gas neighborhood cooperation.

Learns a global Omega projection with combined Gaussian + NG rank-weighted loss for one-class classification.

Parameters:
  • sigma (float) – Bandwidth of Gaussian mixture for prototype weighting.

  • latent_dim (int, optional) – Dimensionality of the projected space. Default: n_features.

  • gamma_init (float, optional) – Initial neighborhood range. Default: n_prototypes / 2.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step multiplicative decay for gamma. Default: computed from max_iter so gamma reaches gamma_final.

  • n_prototypes (int) – Number of prototypes for the target class. Default: 3.

  • target_label (int, optional) – Target (normal) class label. Default: auto-detect.

  • beta (float) – Sigmoid steepness. Default: 10.0.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enables the Lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6: number of fast steps k between synchronizations) and ‘slow_step_size’ (float, default 0.5: interpolation factor for the slow weights). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
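The ema_decay option keeps a running average of the parameters alongside training. A minimal pure-Python sketch of the usual EMA update follows. It assumes the average is initialised with the first parameter values and uses no bias correction; the helper name ema_update is hypothetical and prosemble's internal bookkeeping may differ.

```python
def ema_update(ema_params, params, decay):
    """One EMA step (sketch): ema <- decay * ema + (1 - decay) * param, elementwise."""
    if not 0.0 < decay < 1.0:
        raise ValueError("ema_decay must satisfy 0 < ema_decay < 1")
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema_params, params)]


# Simulated trajectory of a 2-element parameter vector over four training steps.
trajectory = [[0.0, 1.0], [1.0, 1.0], [1.0, 3.0], [2.0, 3.0]]

# Assumption: the average starts from the first parameter values.
ema = list(trajectory[0])
for step_params in trajectory[1:]:
    ema = ema_update(ema, step_params, decay=0.9)

print(ema)  # smoothed values, lagging behind the latest parameters [2.0, 3.0]
```

With decays such as 0.999 the average reacts slowly, so it mainly smooths out noise from the last few hundred or thousand updates.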

omega_

Learned projection matrix.

Type:

array of shape (n_features, latent_dim)

thetas_

Learned per-prototype acceptance thresholds.

Type:

array of shape (n_prototypes,)

gamma_

Final gamma value after training.

Type:

float

decision_function(X)

Compute target-likeness scores using combined Gaussian+NG weights.

Uses the final (decayed) gamma_ value for NG modulation at inference time.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

scores – Scores near 1.0 indicate target class, near 0.0 indicate outlier.

Return type:

array of shape (n_samples,)
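This page does not specify how the Gaussian and NG weights are combined, so the sketch below covers only the NG part. It uses the standard Neural Gas rank weighting \(h_k = \exp(-r_k / \gamma)\) and assumes rank 0 belongs to the closest prototype. The sketch shows why a small final gamma_ makes the weighting behave almost like winner-take-all. The function name is hypothetical.

```python
import math


def ng_rank_weights(distances, gamma):
    """Neural Gas neighbourhood weights exp(-rank / gamma) (sketch).

    Rank 0 is the closest prototype, rank 1 the next closest, and so on.
    """
    order = sorted(range(len(distances)), key=lambda k: distances[k])
    ranks = [0] * len(distances)
    for rank, k in enumerate(order):
        ranks[k] = rank
    return [math.exp(-r / gamma) for r in ranks]


d = [0.8, 0.1, 0.5]                      # squared distances from one sample to 3 prototypes
soft = ng_rank_weights(d, gamma=1.5)     # large gamma: every prototype contributes
hard = ng_rank_weights(d, gamma=0.01)    # small gamma: only the winner (prototype 1) keeps weight 1.0
print(soft)
print(hard)
```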

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number of prototypes the model would otherwise create.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict(X)

Predict target or non-target labels.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)

Predict with a reject option for uncertain samples.

Samples with scores in [lower, upper) are rejected.

Parameters:
  • X (array-like of shape (n_samples, n_features))

  • upper (float) – Scores >= upper are classified as target. Default: 0.5.

  • lower (float, optional) – Scores < lower are classified as non-target. Default: same as upper (no rejection zone).

  • reject_label (int) – Label for rejected samples. Default: -1.

Returns:

labels

Return type:

array of shape (n_samples,)
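The thresholding rule described above can be sketched in plain Python on precomputed scores, for example the output of decision_function. The helper name is hypothetical. The target_label/non_target_label values in the usage lines are placeholders. The check that lower does not exceed upper is an added assumption that does not come from this page.

```python
def predict_with_reject_sketch(scores, target_label, non_target_label,
                               upper=0.5, lower=None, reject_label=-1):
    """Apply the documented reject rule to a list of scores (sketch)."""
    if lower is None:
        lower = upper  # no rejection zone: behaves like a single threshold
    if lower > upper:
        raise ValueError("lower must not exceed upper")  # assumption, not from the docs
    labels = []
    for s in scores:
        if s >= upper:
            labels.append(target_label)       # confident target
        elif s < lower:
            labels.append(non_target_label)   # confident outlier
        else:                                 # lower <= s < upper
            labels.append(reject_label)
    return labels


scores = [0.95, 0.55, 0.35, 0.05]
plain = predict_with_reject_sketch(scores, 1, 0)                         # [1, 1, 0, 0]
banded = predict_with_reject_sketch(scores, 1, 0, upper=0.7, lower=0.3)  # [1, -1, -1, 0]
print(plain, banded)
```

Widening the [lower, upper) band trades coverage for reliability: fewer samples receive a label, but those that do lie further from the decision threshold.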

class prosemble.models.OCLMRSLVQ_NG(sigma=1.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)[source]

One-Class Local Matrix RSLVQ with Neural Gas neighborhood cooperation.

Learns per-prototype \(\Omega_k\) projection matrices with combined Gaussian + NG rank-weighted loss for one-class classification.

Parameters:
  • sigma (float) – Bandwidth of Gaussian mixture for prototype weighting.

  • latent_dim (int, optional) – Latent space dimensionality per prototype. Default: n_features.

  • gamma_init (float, optional) – Initial neighborhood range. Default: n_prototypes / 2.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step multiplicative decay for gamma. Default: computed from max_iter so gamma reaches gamma_final.

  • n_prototypes (int) – Number of prototypes for the target class.

  • target_label (int, optional) – Target (normal) class label. Default: auto-detect.

  • beta (float) – Sigmoid steepness. Default: 10.0.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enables the Lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6: number of fast steps k between synchronizations) and ‘slow_step_size’ (float, default 0.5: interpolation factor for the slow weights). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
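The gamma schedule documented above can be written out as a worked example. The default gamma_init = n_prototypes / 2 comes from the parameter list. The decay rule \(\gamma_{\text{decay}} = (\gamma_{\text{final}} / \gamma_{\text{init}})^{1/\text{max\_iter}}\) is an assumption: it is the simplest choice that makes gamma reach gamma_final after max_iter multiplicative steps. The helper name is hypothetical.

```python
def default_gamma_schedule(n_prototypes, max_iter, gamma_final=0.01,
                           gamma_init=None, gamma_decay=None):
    """Return gamma at steps 0..max_iter (sketch of a multiplicative decay)."""
    if gamma_init is None:
        gamma_init = n_prototypes / 2  # documented default
    if gamma_decay is None:
        # Assumed rule: gamma_init * gamma_decay**max_iter == gamma_final.
        gamma_decay = (gamma_final / gamma_init) ** (1.0 / max_iter)
    return [gamma_init * gamma_decay ** t for t in range(max_iter + 1)]


gammas = default_gamma_schedule(n_prototypes=3, max_iter=100)
print(gammas[0], gammas[-1])  # starts at 1.5, ends at (approximately) 0.01
```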

omegas_

Learned per-prototype projection matrices.

Type:

array of shape (n_prototypes, n_features, latent_dim)
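The sketch below shows what a per-prototype matrix means in practice. Each prototype measures distance in its own projected space. It assumes the stored shape (n_features, latent_dim) means the difference vector is projected as a row vector, \(d_k(x) = \lVert (x - w_k)\,\Omega_k \rVert^2\). Both the convention and the helper name are illustrative and not taken from prosemble's source.

```python
def local_omega_distance(x, prototype, omega):
    """Squared distance ||(x - w) @ Omega||^2 for one prototype (sketch).

    omega is a nested list of shape (n_features, latent_dim).
    """
    diff = [xi - wi for xi, wi in zip(x, prototype)]
    latent_dim = len(omega[0])
    projected = [sum(diff[i] * omega[i][j] for i in range(len(diff)))
                 for j in range(latent_dim)]
    return sum(p * p for p in projected)


x = [1.0, 2.0]
prototypes = [[0.0, 0.0], [1.0, 1.0]]
omegas = [
    [[1.0], [0.0]],  # prototype 0 only "sees" feature 0 (latent_dim = 1)
    [[0.0], [2.0]],  # prototype 1 only sees feature 1, scaled by 2
]
dists = [local_omega_distance(x, w, om) for w, om in zip(prototypes, omegas)]
print(dists)  # [1.0, 4.0]
```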

thetas_

Learned per-prototype acceptance thresholds.

Type:

array of shape (n_prototypes,)

gamma_

Final gamma value after training.

Type:

float

decision_function(X)

Compute target-likeness scores using combined Gaussian+NG weights.

Uses the final (decayed) gamma_ value for NG modulation at inference time.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

scores – Scores near 1.0 indicate target class, near 0.0 indicate outlier.

Return type:

array of shape (n_samples,)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number of prototypes the model would otherwise create.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict(X)

Predict target or non-target labels.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)

Predict with a reject option for uncertain samples.

Samples with scores in [lower, upper) are rejected.

Parameters:
  • X (array-like of shape (n_samples, n_features))

  • upper (float) – Scores >= upper are classified as target. Default: 0.5.

  • lower (float, optional) – Scores < lower are classified as non-target. Default: same as upper (no rejection zone).

  • reject_label (int) – Label for rejected samples. Default: -1.

Returns:

labels

Return type:

array of shape (n_samples,)

SVQ-OCC

class prosemble.models.SVQOCC(n_prototypes=3, target_label=None, alpha=0.5, cost_function='contrastive', response_type='gaussian', sigma=0.1, gamma_resp=1.0, nu=1.0, lambda_init=None, lambda_final=0.01, lambda_decay=None, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

Supervised Vector Quantization One-Class Classification.

Combines Neural Gas representation learning with per-prototype visibility parameters \(\theta_k\) for one-class classification.

Parameters:
  • n_prototypes (int) – Number of prototypes for the target class.

  • target_label (int, optional) – Which label is the target (normal) class. Default: auto-detect as the most frequent class.

  • alpha (float) – Balance between representation (R) and classification (C) cost. E = alpha * R + (1 - alpha) * C. Default: 0.5.

  • cost_function (str) – Classification cost variant: ‘contrastive’, ‘brier’, ‘cross_entropy’. Default: ‘contrastive’.

  • response_type (str) – Response probability model: ‘gaussian’, ‘student_t’, ‘uniform’. Default: ‘gaussian’.

  • sigma (float) – Temperature (width) of the sigmoid used as a differentiable approximation of the Heaviside step function. Smaller values give a sharper boundary. Default: 0.1.

  • gamma_resp (float) – Response bandwidth for Gaussian probabilistic assignment. Default: 1.0.

  • nu (float) – Degrees of freedom for Student-t response. Default: 1.0.

  • lambda_init (float, optional) – Initial NG neighborhood range. Default: n_prototypes / 2.

  • lambda_final (float) – Final NG neighborhood range. Default: 0.01.

  • lambda_decay (float, optional) – Per-step multiplicative decay for lambda. Default: computed from max_iter.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enables the Lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6: number of fast steps k between synchronizations) and ‘slow_step_size’ (float, default 0.5: interpolation factor for the slow weights). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
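The sketch below illustrates the role of sigma. It uses the common approximation \(H(z) \approx \operatorname{sigmoid}(z / \sigma)\). Applying it to \(z = \theta_k - d\) (visibility threshold minus distance) is an assumption made for illustration, and the helper name is hypothetical. The sigmoid is written in a numerically stable form so that very small sigma does not overflow.

```python
import math


def soft_heaviside(z, sigma=0.1):
    """Differentiable step sigmoid(z / sigma); approaches H(z) as sigma -> 0 (sketch)."""
    t = z / sigma
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)  # stable branch for large negative t
    return e / (1.0 + e)


theta, d = 1.0, 0.9  # hypothetical threshold and distance: the sample lies just inside
for s in (1.0, 0.1, 0.01):
    print(s, soft_heaviside(theta - d, s))  # moves towards 1.0 as sigma shrinks
```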

predict(X)[source]

Predict target or non-target labels.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels – target_label for target, non_target_label for outliers.

Return type:

array of shape (n_samples,)

predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)[source]

Predict with a reject option for uncertain samples.

Samples with scores between lower and upper are rejected (labeled reject_label) instead of being forced into a class.

Parameters:
  • X (array-like of shape (n_samples, n_features))

  • upper (float) – Scores >= upper are classified as target. Default: 0.5.

  • lower (float, optional) – Scores < lower are classified as non-target. Scores in [lower, upper) are rejected. Default: same as upper (no rejection zone, equivalent to predict).

  • reject_label (int) – Label for rejected samples. Default: -1.

Returns:

labels

Return type:

array of shape (n_samples,)

decision_function(X)[source]

Compute summed responsibility scores.

Scores near 1.0 indicate target class, near 0.0 indicate outlier.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

scores

Return type:

array of shape (n_samples,)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number of prototypes the model would otherwise create.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

class prosemble.models.SVQOCC_R(n_prototypes=3, target_label=None, alpha=0.5, cost_function='contrastive', response_type='gaussian', sigma=0.1, gamma_resp=1.0, nu=1.0, lambda_init=None, lambda_final=0.01, lambda_decay=None, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

Relevance-Weighted SVQ-OCC.

Extends SVQ-OCC with per-feature relevance weighting (like GRLVQ). Learns which features are most important for distinguishing target from non-target data.

Parameters:
  • n_prototypes (int) – Number of prototypes for the target class.

  • target_label (int, optional) – Which label is the target (normal) class. Default: auto-detect as the most frequent class.

  • alpha (float) – Balance between representation (R) and classification (C) cost. E = alpha * R + (1 - alpha) * C. Default: 0.5.

  • cost_function (str) – Classification cost variant: ‘contrastive’, ‘brier’, ‘cross_entropy’. Default: ‘contrastive’.

  • response_type (str) – Response probability model: ‘gaussian’, ‘student_t’, ‘uniform’. Default: ‘gaussian’.

  • sigma (float) – Temperature (width) of the sigmoid used as a differentiable approximation of the Heaviside step function. Smaller values give a sharper boundary. Default: 0.1.

  • gamma_resp (float) – Response bandwidth for Gaussian probabilistic assignment. Default: 1.0.

  • nu (float) – Degrees of freedom for Student-t response. Default: 1.0.

  • lambda_init (float, optional) – Initial NG neighborhood range. Default: n_prototypes / 2.

  • lambda_final (float) – Final NG neighborhood range. Default: 0.01.

  • lambda_decay (float, optional) – Per-step multiplicative decay for lambda. Default: computed from max_iter.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enables the Lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6: number of fast steps k between synchronizations) and ‘slow_step_size’ (float, default 0.5: interpolation factor for the slow weights). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training. ‘float16’ or ‘bfloat16’. Master weights stay in float32; forward/backward pass runs in lower precision for ~2x speed and ~half memory on GPU. Float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
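This page says only that class_weight='balanced' uses weights "inversely proportional to class frequencies". The sketch below assumes the widely used convention n_samples / (n_classes * count_c), as in scikit-learn. The exact constant prosemble uses is not confirmed here, and the helper name is hypothetical.

```python
from collections import Counter


def balanced_class_weight(y):
    """Inverse-frequency weights n_samples / (n_classes * count_c) (sketch)."""
    counts = Counter(y)
    n_samples, n_classes = len(y), len(counts)
    return {c: n_samples / (n_classes * n) for c, n in counts.items()}


y = [0, 0, 0, 0, 0, 0, 1, 1]           # 6 target samples, 2 outliers
weights = balanced_class_weight(y)
print(weights)                          # roughly {0: 0.667, 1: 2.0}
sample_w = [weights[label] for label in y]
```

Under this convention the per-sample weights sum to n_samples, so the overall loss scale stays comparable to unweighted training while the rare class counts more per sample.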

relevances_

Learned per-feature relevance weights (softmax-normalized).

Type:

array of shape (n_features,)

See also

SVQOCC

Base SVQ-OCC model.

decision_function(X)[source]

Compute scores using relevance-weighted distances.
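The sketch below shows a relevance-weighted distance built from softmax-normalized relevances (cf. relevances_). It follows the GRLVQ-style form \(d_\lambda(x, w) = \sum_j \lambda_j (x_j - w_j)^2\). The raw relevance parameters and helper names are hypothetical, and the sketch is not prosemble's implementation.

```python
import math


def softmax(logits):
    """Numerically stable softmax: non-negative weights that sum to 1."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def relevance_distance(x, w, relevances):
    """sum_j lambda_j * (x_j - w_j)**2 (sketch)."""
    return sum(r * (xi - wi) ** 2 for r, xi, wi in zip(relevances, x, w))


raw = [2.0, 0.0, 0.0]              # hypothetical unconstrained relevance parameters
lam = softmax(raw)                 # feature 0 receives most of the relevance mass
x, w = [1.0, 3.0, 0.0], [0.0, 0.0, 0.0]
d = relevance_distance(x, w, lam)  # the large gap on feature 1 is down-weighted
print(lam, d)
```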

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number of prototypes the model would otherwise create.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state. Uses current prototypes (and other fitted params) as starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict(X)

Predict target or non-target labels.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels – target_label for target, non_target_label for outliers.

Return type:

array of shape (n_samples,)

predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)

Predict with a reject option for uncertain samples.

Samples with scores between lower and upper are rejected (labeled reject_label) instead of being forced into a class.

Parameters:
  • X (array-like of shape (n_samples, n_features))

  • upper (float) – Scores >= upper are classified as target. Default: 0.5.

  • lower (float, optional) – Scores < lower are classified as non-target. Scores in [lower, upper) are rejected. Default: same as upper (no rejection zone, equivalent to predict).

  • reject_label (int) – Label for rejected samples. Default: -1.

Returns:

labels

Return type:

array of shape (n_samples,)

class prosemble.models.SVQOCC_M(latent_dim=None, n_prototypes=3, target_label=None, alpha=0.5, cost_function='contrastive', response_type='gaussian', sigma=0.1, gamma_resp=1.0, nu=1.0, lambda_init=None, lambda_final=0.01, lambda_decay=None, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

Matrix SVQ-OCC with global Omega projection.

Learns a global linear projection \(\Omega\) that captures feature correlations for one-class classification.

Parameters:
  • latent_dim (int, optional) – Dimensionality of the projected space. Default: n_features.

  • n_prototypes (int) – Number of prototypes for the target class.

  • target_label (int, optional) – Which label is the target (normal) class. Default: auto-detect as the most frequent class.

  • alpha (float) – Balance between representation (R) and classification (C) cost. E = alpha * R + (1 - alpha) * C. Default: 0.5.

  • cost_function (str) – Classification cost variant: ‘contrastive’, ‘brier’, ‘cross_entropy’. Default: ‘contrastive’.

  • response_type (str) – Response probability model: ‘gaussian’, ‘student_t’, ‘uniform’. Default: ‘gaussian’.

  • sigma (float) – Temperature (width) of the sigmoid used as a differentiable approximation of the Heaviside step function. Smaller values give a sharper boundary. Default: 0.1.

  • gamma_resp (float) – Response bandwidth for Gaussian probabilistic assignment. Default: 1.0.

  • nu (float) – Degrees of freedom for Student-t response. Default: 1.0.

  • lambda_init (float, optional) – Initial NG neighborhood range. Default: n_prototypes / 2.

  • lambda_final (float) – Final NG neighborhood range. Default: 0.01.

  • lambda_decay (float, optional) – Per-step multiplicative decay for lambda. Default: computed from max_iter.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6: sync the slow weights every sync_period steps) and ‘slow_step_size’ (float, default 0.5: interpolation factor). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32, while the forward and backward passes run in the lower precision. On supported GPUs this can give up to roughly 2x speed and about half the memory use. float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
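Two of the options above, ema_decay and class_weight='balanced', follow well-known formulas. The sketch below assumes the textbook EMA update and the scikit-learn-style balanced weights n_samples / (n_classes * count). prosemble's internals may differ, and the helpers ema_update and balanced_class_weights are hypothetical, not part of the library.

```python
import jax
import jax.numpy as jnp


def ema_update(ema_params, params, decay=0.999):
    """One EMA step, leaf by leaf: ema <- decay * ema + (1 - decay) * params."""
    return jax.tree_util.tree_map(
        lambda e, p: decay * e + (1.0 - decay) * p, ema_params, params
    )


def balanced_class_weights(y):
    """Weights inversely proportional to class frequency (assumed sklearn-style formula)."""
    classes, counts = jnp.unique(y, return_counts=True)
    weights = y.shape[0] / (classes.shape[0] * counts)
    return {int(c): float(w) for c, w in zip(classes, weights)}


# Toy parameter pytree standing in for prototypes and a projection matrix.
params = {"prototypes": jnp.ones((2, 3)), "omega": jnp.eye(3)}
ema = jax.tree_util.tree_map(jnp.zeros_like, params)
for _ in range(3):
    ema = ema_update(ema, params, decay=0.5)
# Starting from zero, three steps with decay 0.5 cover 1 - 0.5**3 of the gap.
print(float(ema["prototypes"][0, 0]))  # 0.875

y = jnp.array([0, 0, 0, 1])
print(balanced_class_weights(y))  # approximately {0: 0.667, 1: 2.0}
```

The rarer class gets the larger weight, which is the behavior the class_weight description asks for.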

omega_

Learned projection matrix.

Type:

array of shape (n_features, latent_dim)

See also

SVQOCC

Base SVQ-OCC model.

decision_function(X)

Compute scores using Omega-projected distances.
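How the distances are turned into scores is not specified here. As a sketch, assuming the usual matrix-LVQ distance d(x, w) = ||(x - w) Ω||² with omega_ of shape (n_features, latent_dim), the projected distances could be computed as below. omega_distances is a hypothetical helper, not a prosemble function.

```python
import jax.numpy as jnp


def omega_distances(X, prototypes, omega):
    """Squared distances ||(x - w) @ Omega||^2 for every sample/prototype pair.

    X: (n_samples, n_features); prototypes: (n_prototypes, n_features);
    omega: (n_features, latent_dim). Returns (n_samples, n_prototypes).
    """
    # The projection is linear, so project samples and prototypes once each.
    Xp = X @ omega                      # (n_samples, latent_dim)
    Wp = prototypes @ omega             # (n_prototypes, latent_dim)
    diff = Xp[:, None, :] - Wp[None, :, :]
    return jnp.sum(diff ** 2, axis=-1)


X = jnp.array([[1.0, 0.0], [0.0, 2.0]])
W = jnp.array([[0.0, 0.0]])
omega = jnp.array([[1.0], [0.0]])  # latent_dim=1: keeps only the first feature
print(omega_distances(X, W, omega))  # shape (2, 1): 1.0 for sample 0, 0.0 for sample 1
```

Projecting before subtracting avoids materializing an (n_samples, n_prototypes, n_features) difference tensor.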

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number that n_prototypes_per_class would produce.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a freshly initialized optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict(X)

Predict target or non-target labels.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels – target_label for target, non_target_label for outliers.

Return type:

array of shape (n_samples,)

predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)

Predict with a reject option for uncertain samples.

Samples with scores between lower and upper are rejected (labeled reject_label) instead of being forced into a class.

Parameters:
  • X (array-like of shape (n_samples, n_features))

  • upper (float) – Scores >= upper are classified as target. Default: 0.5.

  • lower (float, optional) – Scores < lower are classified as non-target. Scores in [lower, upper) are rejected. Default: same as upper (no rejection zone, equivalent to predict).

  • reject_label (int) – Label for rejected samples. Default: -1.

Returns:

labels

Return type:

array of shape (n_samples,)
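The thresholding rule described for predict_with_reject can be written as a pure function of the scores. This sketch assumes the scores are a 1-D array as returned by decision_function, and that target_label and non_target_label are plain integers. reject_rule is a hypothetical helper, not a prosemble function.

```python
import jax.numpy as jnp


def reject_rule(scores, target_label, non_target_label, upper=0.5, lower=None,
                reject_label=-1):
    """Three-way decision: >= upper -> target, < lower -> non-target, else reject."""
    if lower is None:
        lower = upper  # empty rejection band: plain thresholding at `upper`
    return jnp.where(
        scores >= upper,
        target_label,
        jnp.where(scores < lower, non_target_label, reject_label),
    )


scores = jnp.array([0.9, 0.55, 0.35, 0.1])
print(reject_rule(scores, target_label=1, non_target_label=0, upper=0.6, lower=0.3))
# labels: 1, -1, -1, 0 (the two middle scores fall in [0.3, 0.6) and are rejected)
```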

class prosemble.models.SVQOCC_LM(latent_dim=None, n_prototypes=3, target_label=None, alpha=0.5, cost_function='contrastive', response_type='gaussian', sigma=0.1, gamma_resp=1.0, nu=1.0, lambda_init=None, lambda_final=0.01, lambda_decay=None, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)

Local Matrix SVQ-OCC with per-prototype Omega projections.

Each prototype learns its own local metric via \(\Omega_k\), allowing different prototypes to attend to different feature subspaces.
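As a sketch of what "its own local metric" means, assume each prototype k uses d_k(x) = ||(x - w_k) Ω_k||² with omegas_ of shape (n_prototypes, n_features, latent_dim). Under that assumption, the distances could be computed with a single einsum. local_omega_distances is a hypothetical helper, and prosemble may compute this differently.

```python
import jax.numpy as jnp


def local_omega_distances(X, prototypes, omegas):
    """d_k(x) = ||(x - w_k) @ Omega_k||^2 with one Omega per prototype.

    X: (n_samples, n_features); prototypes: (n_prototypes, n_features);
    omegas: (n_prototypes, n_features, latent_dim). Returns (n_samples, n_prototypes).
    """
    diff = X[:, None, :] - prototypes[None, :, :]       # (n, K, n_features)
    # Project each difference with the matrix that belongs to its prototype.
    proj = jnp.einsum("nkd,kdl->nkl", diff, omegas)      # (n, K, latent_dim)
    return jnp.sum(proj ** 2, axis=-1)


X = jnp.array([[1.0, 2.0]])
W = jnp.zeros((2, 2))
omegas = jnp.stack([
    jnp.array([[1.0], [0.0]]),  # prototype 0 only "sees" feature 0
    jnp.array([[0.0], [1.0]]),  # prototype 1 only "sees" feature 1
])
print(local_omega_distances(X, W, omegas))  # distances 1.0 and 4.0
```

The same sample is close to one prototype and far from the other purely because the two prototypes attend to different feature subspaces.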

Parameters:
  • latent_dim (int, optional) – Dimensionality of each projected space. Default: n_features.

  • n_prototypes (int) – Number of prototypes for the target class.

  • target_label (int, optional) – Which label is the target (normal) class. Default: auto-detect as the most frequent class.

  • alpha (float) – Balance between representation (R) and classification (C) cost. E = alpha * R + (1 - alpha) * C. Default: 0.5.

  • cost_function (str) – Classification cost variant: ‘contrastive’, ‘brier’, ‘cross_entropy’. Default: ‘contrastive’.

  • response_type (str) – Response probability model: ‘gaussian’, ‘student_t’, ‘uniform’. Default: ‘gaussian’.

  • sigma (float) – Temperature of the sigmoid used as a differentiable approximation of the Heaviside step function. Smaller values give a sharper boundary. Default: 0.1.

  • gamma_resp (float) – Response bandwidth for Gaussian probabilistic assignment. Default: 1.0.

  • nu (float) – Degrees of freedom for Student-t response. Default: 1.0.

  • lambda_init (float, optional) – Initial Neural Gas (NG) neighborhood range. Default: n_prototypes / 2.

  • lambda_final (float) – Final NG neighborhood range. Default: 0.01.

  • lambda_decay (float, optional) – Per-step multiplicative decay for lambda. Default: computed from max_iter.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6: sync the slow weights every sync_period steps) and ‘slow_step_size’ (float, default 0.5: interpolation factor). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32, while the forward and backward passes run in the lower precision. On supported GPUs this can give up to roughly 2x speed and about half the memory use. float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
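The lambda_init, lambda_final and lambda_decay parameters describe a shrinking Neural Gas neighborhood. This sketch assumes a geometric schedule chosen so that lambda reaches lambda_final after max_iter steps, together with the standard Neural Gas rank weights exp(-rank / lambda). Both the schedule and the helper names lambda_schedule and ng_rank_weights are assumptions, not prosemble's documented internals.

```python
import jax.numpy as jnp


def lambda_schedule(lambda_init, lambda_final, max_iter):
    """Geometric decay from lambda_init to lambda_final over max_iter steps (assumed)."""
    decay = (lambda_final / lambda_init) ** (1.0 / max_iter)
    steps = jnp.arange(max_iter + 1)
    return lambda_init * decay ** steps, decay


def ng_rank_weights(distances, lam):
    """Neural Gas weights h = exp(-rank / lambda); rank 0 is the closest prototype."""
    # argsort applied twice turns each row of distances into per-prototype ranks.
    ranks = jnp.argsort(jnp.argsort(distances, axis=-1), axis=-1)
    return jnp.exp(-ranks / lam)


n_prototypes, max_iter = 4, 100
lams, decay = lambda_schedule(n_prototypes / 2, 0.01, max_iter)
print(float(lams[0]), float(lams[-1]))  # 2.0, then approximately 0.01

d = jnp.array([[0.3, 0.1, 0.9, 0.5]])
print(ng_rank_weights(d, lam=1.0))  # weights exp(-1), 1, exp(-3), exp(-2)
```

With a large lambda, all prototypes share the update. As lambda shrinks, only the closest prototype keeps a significant weight.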

omegas_

Learned per-prototype projection matrices.

Type:

array of shape (n_prototypes, n_features, latent_dim)

See also

SVQOCC

Base SVQ-OCC model.

decision_function(X)

Compute scores using per-prototype Omega-projected distances.

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number that n_prototypes_per_class would produce.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a freshly initialized optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict(X)

Predict target or non-target labels.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels – target_label for target, non_target_label for outliers.

Return type:

array of shape (n_samples,)

predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)

Predict with a reject option for uncertain samples.

Samples with scores between lower and upper are rejected (labeled reject_label) instead of being forced into a class.

Parameters:
  • X (array-like of shape (n_samples, n_features))

  • upper (float) – Scores >= upper are classified as target. Default: 0.5.

  • lower (float, optional) – Scores < lower are classified as non-target. Scores in [lower, upper) are rejected. Default: same as upper (no rejection zone, equivalent to predict).

  • reject_label (int) – Label for rejected samples. Default: -1.

Returns:

labels

Return type:

array of shape (n_samples,)

class prosemble.models.SVQOCC_T(subspace_dim=2, n_prototypes=3, target_label=None, alpha=0.5, cost_function='contrastive', response_type='gaussian', sigma=0.1, gamma_resp=1.0, nu=1.0, lambda_init=None, lambda_final=0.01, lambda_decay=None, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)

Tangent SVQ-OCC with per-prototype tangent subspaces.

Each prototype learns an orthonormal basis \(\Omega_k\) that defines directions of local invariance. Only the distance orthogonal to this subspace is used for classification.
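"Only the distance orthogonal to this subspace" can be made concrete with the standard tangent-distance construction. It projects x - w_k onto the orthogonal complement of span(Ω_k) using I - Ω_k Ω_kᵀ, which is valid because the columns of Ω_k are orthonormal. The sketch assumes omegas_ of shape (n_prototypes, n_features, subspace_dim). tangent_distances is a hypothetical helper, and prosemble may compute it differently.

```python
import jax.numpy as jnp


def tangent_distances(X, prototypes, omegas):
    """Squared distance to each prototype after removing its tangent directions.

    omegas: (n_prototypes, n_features, subspace_dim) with orthonormal columns.
    Returns (n_samples, n_prototypes).
    """
    diff = X[:, None, :] - prototypes[None, :, :]              # (n, K, n_features)
    coeffs = jnp.einsum("nkd,kds->nks", diff, omegas)           # coordinates in span(Omega_k)
    in_subspace = jnp.einsum("nks,kds->nkd", coeffs, omegas)    # Omega_k Omega_k^T diff
    residual = diff - in_subspace                               # orthogonal component
    return jnp.sum(residual ** 2, axis=-1)


# One prototype at the origin whose tangent subspace is the x-axis.
X = jnp.array([[3.0, 4.0]])
W = jnp.zeros((1, 2))
omegas = jnp.array([[[1.0], [0.0]]])  # shape (1, 2, 1)
print(tangent_distances(X, W, omegas))  # only the y-component counts: 16.0
```

Moving a sample along the x-axis leaves its distance unchanged, which is the "local invariance" the description refers to.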

Parameters:
  • subspace_dim (int) – Dimensionality of each tangent subspace. Default: 2.

  • n_prototypes (int) – Number of prototypes for the target class.

  • target_label (int, optional) – Which label is the target (normal) class. Default: auto-detect as the most frequent class.

  • alpha (float) – Balance between representation (R) and classification (C) cost. E = alpha * R + (1 - alpha) * C. Default: 0.5.

  • cost_function (str) – Classification cost variant: ‘contrastive’, ‘brier’, ‘cross_entropy’. Default: ‘contrastive’.

  • response_type (str) – Response probability model: ‘gaussian’, ‘student_t’, ‘uniform’. Default: ‘gaussian’.

  • sigma (float) – Temperature of the sigmoid used as a differentiable approximation of the Heaviside step function. Smaller values give a sharper boundary. Default: 0.1.

  • gamma_resp (float) – Response bandwidth for Gaussian probabilistic assignment. Default: 1.0.

  • nu (float) – Degrees of freedom for Student-t response. Default: 1.0.

  • lambda_init (float, optional) – Initial Neural Gas (NG) neighborhood range. Default: n_prototypes / 2.

  • lambda_final (float) – Final NG neighborhood range. Default: 0.01.

  • lambda_decay (float, optional) – Per-step multiplicative decay for lambda. Default: computed from max_iter.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Distance function (default: squared Euclidean).

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping (no wasted compute after convergence, but slower per iteration).

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training. When set, each epoch iterates over shuffled mini-batches of this size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Supported strings: ‘exponential_decay’, ‘cosine_decay’, ‘warmup_cosine_decay’, ‘warmup_exponential_decay’, ‘warmup_constant’, ‘polynomial’, ‘linear’, ‘piecewise_constant’, ‘sgdr’. Or pass a custom optax.Schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments passed to the learning rate scheduler (e.g. decay_rate, transition_steps). Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Supported strings: ‘stratified_random’ (default), ‘class_mean’, ‘class_conditional_mean’, ‘stratified_noise’, ‘random_normal’, ‘uniform’, ‘zeros’, ‘ones’, ‘fill_value’. Or pass a callable (X, y, n_per_class, key) -> (protos, labels).

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. If None (default), stops after a single non-improving step (epsilon check). Requires use_scan=False for true early stopping.

  • restore_best (bool) – If True, restore the parameters that achieved the lowest loss (or validation loss if validation data is provided). Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Dict maps class label to weight, e.g. {0: 1.0, 1: 2.0, 2: 1.5}. ‘balanced’ auto-computes weights inversely proportional to class frequencies. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps before applying an update. Effective batch size = batch_size * gradient_accumulation_steps. Default: None (no accumulation).

  • ema_decay (float, optional) – Exponential moving average decay for parameters (0 < ema_decay < 1). After training, model parameters are replaced with EMA-smoothed values. Typical values: 0.999, 0.9999. Default: None (no EMA).

  • freeze_params (list of str, optional) – List of parameter group names to freeze (zero gradients). E.g. [‘backbone’] to freeze the backbone and only train prototypes. Default: None (all parameters trainable).

  • lookahead (dict, optional) – Enable the lookahead optimizer wrapper. Dict with keys ‘sync_period’ (int, default 6: sync the slow weights every sync_period steps) and ‘slow_step_size’ (float, default 0.5: interpolation factor). Default: None (no lookahead).

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training: ‘float16’ or ‘bfloat16’. Master weights stay in float32, while the forward and backward passes run in the lower precision. On supported GPUs this can give up to roughly 2x speed and about half the memory use. float16 uses static loss scaling to prevent gradient underflow. Default: None (disabled).
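The sigma parameter controls a smooth stand-in for the Heaviside step. A common choice is a logistic sigmoid with temperature sigma, which is assumed here rather than taken from the prosemble source. The sketch shows why smaller sigma gives a sharper boundary and why the approximation is needed for gradients. soft_heaviside is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def soft_heaviside(z, sigma=0.1):
    """Differentiable approximation of H(z) = 1[z >= 0]: sigmoid(z / sigma)."""
    return jax.nn.sigmoid(z / sigma)


z = jnp.array([-0.5, -0.05, 0.0, 0.05, 0.5])
for s in (0.1, 0.01):
    print(s, soft_heaviside(z, sigma=s))
# Smaller sigma pushes the outputs toward 0 or 1, i.e. a sharper step.

# Unlike the hard step, the gradient near the boundary is non-zero.
print(float(jax.grad(lambda t: soft_heaviside(t, 0.1))(0.0)))  # approximately 2.5 (= 0.25 / sigma)
```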

omegas_

Learned per-prototype orthonormal tangent bases.

Type:

array of shape (n_prototypes, n_features, subspace_dim)

See also

SVQOCC

Base SVQ-OCC model.

decision_function(X)

Compute scores using tangent distances.

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number that n_prototypes_per_class would produce.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores params with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a freshly initialized optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict(X)

Predict target or non-target labels.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels – target_label for target, non_target_label for outliers.

Return type:

array of shape (n_samples,)

predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)

Predict with a reject option for uncertain samples.

Samples with scores between lower and upper are rejected (labeled reject_label) instead of being forced into a class.

Parameters:
  • X (array-like of shape (n_samples, n_features))

  • upper (float) – Scores >= upper are classified as target. Default: 0.5.

  • lower (float, optional) – Scores < lower are classified as non-target. Scores in [lower, upper) are rejected. Default: same as upper (no rejection zone, equivalent to predict).

  • reject_label (int) – Label for rejected samples. Default: -1.

Returns:

labels

Return type:

array of shape (n_samples,)

Fuzzy Clustering

class prosemble.models.FCM(n_clusters, fuzzifier=2.0, max_iter=100, epsilon=1e-05, init_method='random', random_seed=42, distance_fn=None, patience=None, restore_best=False, plot_steps=False, show_confidence=True, show_pca_variance=True, save_plot_path=None, callbacks=None)

JAX implementation of Fuzzy C-Means clustering.

This implementation provides:

  • Full vectorization (no Python loops)

  • JIT compilation for speed

  • Automatic GPU acceleration

  • Immutable state management

  • Numerical stability

Key Differences from NumPy Version:

  1. Vectorization: all operations use matrix operations. Old: triple nested loops in the centroid computation. New: a single matrix multiplication.

  2. Functional style: immutable state using a NamedTuple. Old: in-place updates (self.fit_cent = …). New: functions return new state objects.

  3. JIT compilation: functions are compiled to machine code. Old: interpreted Python loops. New: compiled XLA code.

  4. GPU support: automatic device placement. Old: CPU-only NumPy. New: GPU/TPU via JAX.
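The "single matrix multiplication" point above refers to the centroid update. The sketch below shows one full FCM iteration in JAX, assuming squared Euclidean distance and the textbook update rules (Bezdek, 1981). fcm_step is a hypothetical helper, not prosemble's internal function.

```python
import jax
import jax.numpy as jnp


@jax.jit
def fcm_step(X, U, m=2.0):
    """One FCM iteration: centroids from U, then U from the new centroids.

    X: (n_samples, n_features); U: (n_samples, n_clusters), rows sum to 1.
    """
    Um = U ** m
    # Centroid update as one matrix product: V_j = sum_i u_ij^m x_i / sum_i u_ij^m
    V = (Um.T @ X) / jnp.sum(Um, axis=0)[:, None]
    # Squared Euclidean distances (n_samples, n_clusters); eps avoids division by 0.
    D = jnp.sum((X[:, None, :] - V[None, :, :]) ** 2, axis=-1) + 1e-12
    # With squared distances, u_ij is proportional to D_ij^(-1/(m-1)).
    inv = D ** (-1.0 / (m - 1.0))
    U_new = inv / jnp.sum(inv, axis=1, keepdims=True)
    return V, U_new


X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8], [1, 0.6], [9, 11]])
# Random Dirichlet initialization of U, one row per sample.
U = jax.random.dirichlet(jax.random.PRNGKey(42), jnp.ones(2), shape=(X.shape[0],))
for _ in range(50):
    V, U = fcm_step(X, U)
print(V)              # one centroid near the lower-left points, one near the upper-right
print(U.sum(axis=1))  # each row sums to 1
```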

param fuzzifier:

Fuzzification parameter (\(m\)). Must be > 1. As \(m \to 1^+\), memberships approach hard (crisp) assignments; \(m = 2\) gives standard fuzzy clustering; as \(m \to \infty\), memberships approach maximum fuzziness (equal membership in every cluster).

type fuzzifier:

float, default=2.0

param init_method:

Initialization method for the \(U\) matrix: ‘random’ samples \(U\) from a Dirichlet distribution; ‘kmeans++’ selects centroids with k-means++ seeding and then computes \(U\) from them.

type init_method:

str, default=’random’

param n_clusters:

Number of clusters (must be >= 2).

type n_clusters:

int

param max_iter:

Maximum number of iterations.

type max_iter:

int

param epsilon:

Convergence threshold.

type epsilon:

float

param random_seed:

Random seed for reproducibility.

type random_seed:

int

param distance_fn:

Pairwise distance function. Default: squared Euclidean.

type distance_fn:

callable, optional

param patience:

Epochs with no improvement before early stopping. Default: None.

type patience:

int, optional

param restore_best:

If True, restore centroids from the lowest-objective epoch. Default: False.

type restore_best:

bool

param plot_steps:

Whether to visualize clustering progress. Default: False.

type plot_steps:

bool

param show_confidence:

Whether to show confidence in visualization. Default: True.

type show_confidence:

bool

param show_pca_variance:

Whether to show PCA variance in visualization. Default: True.

type show_pca_variance:

bool

param save_plot_path:

Path to save final plot.

type save_plot_path:

str, optional

param callbacks:

List of Callback objects for monitoring/visualization.

type callbacks:

list, optional

centroids_

Cluster centroids after fitting

Type:

array of shape (n_clusters, n_features)

U_

Fuzzy membership matrix after fitting

Type:

array of shape (n_samples, n_clusters)

objective_

Final objective function value

Type:

float

n_iter_

Number of iterations performed

Type:

int

history_

Training history containing objective values and other metrics

Type:

dict

Examples

>>> import jax.numpy as jnp
>>> from prosemble.models import FCM
>>>
>>> # Generate sample data
>>> X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8], [1, 0.6], [9, 11]])
>>>
>>> # Fit FCM model
>>> model = FCM(n_clusters=2, fuzzifier=2.0, max_iter=100)
>>> model.fit(X)
>>>
>>> # Get results
>>> labels = model.predict(X)
>>> centroids = model.centroids_
>>> membership = model.predict_proba(X)
>>>
>>> print(f"Labels: {labels}")
>>> print(f"Centroids shape: {centroids.shape}")
>>> print(f"Membership shape: {membership.shape}")

References

Bezdek, J. C. (1981). Pattern Recognition with Fuzzy Objective Function Algorithms. Plenum Press, New York.

Dunn, J. C. (1973). A Fuzzy Relative of the ISODATA Process and Its Use in Detecting Compact Well-Separated Clusters. Journal of Cybernetics, 3(3), 32–57.

fit(X, initial_centroids=None, resume=False)

Fit FCM model to data.

Parameters:
  • X (array-like, shape (n_samples, n_features)) – Training data

  • initial_centroids (array-like, shape (n_clusters, n_features), optional) – Pre-computed centroids for warm starting

  • resume (bool, default=False) – If True, resume from the model’s current fitted state

Return type:

Self

predict(X)

Predict cluster labels for X.

Assigns each sample to the cluster with highest membership.

Parameters:

X (Array) – (n_samples, n_features) data

Returns:

labels – Cluster assignments, with values from 0 to n_clusters - 1.

Return type:

array of shape (n_samples,)

Raises:

RuntimeError – If model not fitted yet

predict_proba(X)

Predict fuzzy membership for X.

Parameters:

X (Array) – (n_samples, n_features) data

Returns:

U – Fuzzy membership matrix. Each row sums to 1, with values in [0, 1].

Return type:

array of shape (n_samples, n_clusters)

Raises:

RuntimeError – If model not fitted yet


class prosemble.models.PCM(n_clusters, fuzzifier=2.0, k=1.0, max_iter=100, epsilon=1e-05, init_method='fcm', random_seed=None, distance_fn=None, patience=None, restore_best=False, plot_steps=False, show_confidence=True, show_pca_variance=True, save_plot_path=None, callbacks=None)

JAX-based Possibilistic C-Means clustering with GPU acceleration.

PCM is a clustering algorithm that assigns typicality values to data points, representing the degree to which they belong to each cluster. Unlike FCM, the typicality of a point to one cluster is independent of its typicality to other clusters.
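Because each typicality depends only on the sample's distance to that one cluster and the cluster's scale γ_j, no normalization across clusters is needed. The sketch below follows the classic update of Krishnapuram and Keller. It assumes squared distances and estimates γ_j from FCM memberships scaled by k, and whether prosemble estimates γ exactly this way is an assumption. pcm_gamma and pcm_typicality are hypothetical helpers.

```python
import jax.numpy as jnp


def pcm_gamma(D, U, m=2.0, k=1.0):
    """gamma_j = k * sum_i u_ij^m D_ij / sum_i u_ij^m, with D the squared distances."""
    Um = U ** m
    return k * jnp.sum(Um * D, axis=0) / jnp.sum(Um, axis=0)


def pcm_typicality(D, gamma, m=2.0):
    """t_ij = 1 / (1 + (D_ij / gamma_j)^(1/(m-1))), computed independently per cluster."""
    return 1.0 / (1.0 + (D / gamma[None, :]) ** (1.0 / (m - 1.0)))


# Squared distances of 3 samples to 2 clusters; the last sample is far from both.
D = jnp.array([[0.1, 4.0], [3.9, 0.2], [25.0, 30.0]])
U = jnp.array([[0.9, 0.1], [0.1, 0.9], [0.5, 0.5]])  # e.g. from an FCM run
gamma = pcm_gamma(D, U)
T = pcm_typicality(D, gamma)
print(T)
print(T.sum(axis=1))  # rows need not sum to 1; the outlier's row stays low
```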

Parameters:
  • fuzzifier (float, default=2.0) – Fuzzification parameter (\(m > 1\)). Higher values result in fuzzier clusters.

  • k (float, default=1.0) – Parameter for \(\gamma\) computation. Typical values are in [0.01, 1.0]. Lower values make the algorithm more sensitive to outliers.

  • init_method ({'fcm', 'random'}, default='fcm') – Initialization method: ‘fcm’ initializes from FCM results (recommended); ‘random’ uses random initialization.

  • n_clusters (int) – Number of clusters (must be >= 2).

  • max_iter (int) – Maximum number of iterations.

  • epsilon (float) – Convergence threshold.

  • random_seed (int, optional) – Random seed for reproducibility. Default: None.

  • distance_fn (callable, optional) – Pairwise distance function. Default: squared Euclidean.

  • patience (int, optional) – Epochs with no improvement before early stopping. Default: None.

  • restore_best (bool) – If True, restore centroids from the lowest-objective epoch. Default: False.

  • plot_steps (bool) – Whether to visualize clustering progress. Default: False.

  • show_confidence (bool) – Whether to show confidence in visualization. Default: True.

  • show_pca_variance (bool) – Whether to show PCA variance in visualization. Default: True.

  • save_plot_path (str, optional) – Path to save final plot.

  • callbacks (list, optional) – List of Callback objects for monitoring/visualization.

centroids_

Cluster centroids after fitting.

Type:

ndarray of shape (n_clusters, n_features)

T_

Typicality matrix after fitting.

Type:

ndarray of shape (n_samples, n_clusters)

gamma_

Scale parameters for each cluster.

Type:

ndarray of shape (n_clusters,)

n_iter_

Number of iterations run.

Type:

int

objective_

Final objective function value.

Type:

float

Examples

>>> import jax.numpy as jnp
>>> from prosemble.models import PCM
>>>
>>> # Generate sample data
>>> X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8], [1, 0.6], [9, 11]])
>>>
>>> # Fit PCM model
>>> model = PCM(n_clusters=2, fuzzifier=2.0, k=1.0)
>>> model.fit(X)
>>>
>>> # Get cluster assignments
>>> labels = model.predict(X)
>>>
>>> # Get typicality values
>>> typicalities = model.predict_proba(X)

Notes

  • PCM is less sensitive to outliers than FCM because typicality values are computed independently for each cluster.

  • The parameter \(k\) controls the sensitivity to outliers. Smaller values make the algorithm more sensitive.

  • Initialization from FCM (init_method=’fcm’) is recommended as it provides better starting points than random initialization.

  • All computations are JIT-compiled and can run on GPU if available.

fit(X, initial_centroids=None, resume=False)

Fit PCM clustering model to data.

Parameters:
  • X (array-like, shape (n_samples, n_features)) – Training data

  • initial_centroids (array-like, shape (n_clusters, n_features), optional) – Pre-computed centroids for warm starting

  • resume (bool, default=False) – If True, resume from the model’s current fitted state

Return type:

Self

predict(X)

Predict cluster labels for data.

Parameters:

X (Array | ndarray | bool | number) – Data matrix of shape (n_samples, n_features)

Returns:

labels – Cluster labels.

Return type:

array of shape (n_samples,)

Raises:

ValueError – If model has not been fitted

predict_proba(X)

Predict typicality values for data.

Parameters:

X (Array | ndarray | bool | number) – Data matrix of shape (n_samples, n_features)

Returns:

T – Typicality matrix.

Return type:

array of shape (n_samples, n_clusters)

Raises:

ValueError – If model has not been fitted

class prosemble.models.FPCM(n_clusters, fuzzifier=2.0, eta=2.0, max_iter=100, epsilon=1e-05, init_method='fcm', random_seed=42, distance_fn=None, patience=None, restore_best=False, plot_steps=False, show_confidence=True, show_pca_variance=True, save_plot_path=None, callbacks=None)

Fuzzy Possibilistic C-Means clustering with JAX.

FPCM maintains two matrices: \(U\) (fuzzy membership) and \(T\) (typicality). Each row of \(U\) sums to 1, as in standard FCM, while each column of \(T\) sums to 1, following the original Pal, Pal & Bezdek (1997) formulation.

Algorithm:

  1. Initialize \(U\) and \(T\) (randomly or using FCM)

  2. Update centroids using combined fuzzy and typicality weights

  3. Update \(U\) using FCM rule with fuzzifier \(m\) (row-normalized)

  4. Update \(T\) with column-normalization

  5. Repeat until convergence
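The two normalizations are what set FPCM apart from FCM and PCM. The sketch below shows one iteration, assuming squared Euclidean distance and the closed-form updates of Pal, Pal & Bezdek (1997): \(U\) is normalized across clusters, \(T\) across samples, and centroids are weighted by \(u^m + t^\eta\), matching the objective. fpcm_step is a hypothetical helper, and prosemble may order or stabilize the updates differently.

```python
import jax
import jax.numpy as jnp


@jax.jit
def fpcm_step(X, V, m=2.0, eta=2.0):
    """One FPCM iteration starting from centroids V of shape (n_clusters, n_features)."""
    D = jnp.sum((X[:, None, :] - V[None, :, :]) ** 2, axis=-1) + 1e-12  # (n, c)
    # U: normalize over clusters, so each row sums to 1 (FCM rule).
    a = D ** (-1.0 / (m - 1.0))
    U = a / jnp.sum(a, axis=1, keepdims=True)
    # T: normalize over samples, so each column sums to 1.
    b = D ** (-1.0 / (eta - 1.0))
    T = b / jnp.sum(b, axis=0, keepdims=True)
    # Centroids weighted by u^m + t^eta, as in the objective J.
    W = U ** m + T ** eta
    V_new = (W.T @ X) / jnp.sum(W, axis=0)[:, None]
    return V_new, U, T


X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8]])
V = jnp.array([[0.0, 0.0], [10.0, 10.0]])  # crude initial centroids
for _ in range(30):
    V, U, T = fpcm_step(X, V)
print(U.sum(axis=1))  # all ones, up to rounding
print(T.sum(axis=0))  # all ones, up to rounding
```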

Objective function:

\[J = \sum_i \sum_j \left[u_{ij}^m + t_{ij}^\eta\right] \|x_i - v_j\|^2\]

Reference:

Pal, N. R., Pal, K., & Bezdek, J. C. (1997). A mixed c-means clustering model. FUZZ-IEEE.

Parameters:
  • fuzzifier (float, default=2.0) – Fuzziness parameter for \(U\) matrix (must be > 1.0).

  • eta (float, default=2.0) – Fuzziness parameter for \(T\) matrix (must be > 1.0).

  • init_method ({'random', 'fcm'}, default='fcm') – Method for initializing \(U\) and \(T\) matrices.

  • n_clusters (int) – Number of clusters (must be >= 2).

  • max_iter (int) – Maximum number of iterations.

  • epsilon (float) – Convergence threshold.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Pairwise distance function. Default: squared Euclidean.

  • patience (int, optional) – Epochs with no improvement before early stopping. Default: None.

  • restore_best (bool) – If True, restore centroids from the lowest-objective epoch. Default: False.

  • plot_steps (bool) – Whether to visualize clustering progress. Default: False.

  • show_confidence (bool) – Whether to show confidence in visualization. Default: True.

  • show_pca_variance (bool) – Whether to show PCA variance in visualization. Default: True.

  • save_plot_path (str, optional) – Path to save final plot.

  • callbacks (list, optional) – List of Callback objects for monitoring/visualization.

centroids_

Final cluster centroids

Type:

array, shape (n_clusters, n_features)

U_

Final fuzzy membership matrix

Type:

array, shape (n_samples, n_clusters)

T_

Final possibilistic typicality matrix

Type:

array, shape (n_samples, n_clusters)

n_iter_

Number of iterations until convergence

Type:

int

objective_

Final objective function value

Type:

float

objective_history_

Objective value at each iteration

Type:

array

Examples

>>> import jax.numpy as jnp
>>> from prosemble.models import FPCM
>>> X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8]])
>>> model = FPCM(n_clusters=2, fuzzifier=2.0, eta=2.0, random_seed=42)
>>> model.fit(X)
>>> labels = model.predict(X)
>>> U = model.predict_proba(X)
>>> T = model.get_typicality(X)

fit(X, initial_centroids=None, resume=False)[source]

Fit FPCM model to data.

Parameters:
  • X (Array | ndarray | bool | number) – Input data, shape (n_samples, n_features)

  • initial_centroids – Optional initial centroids for warm starting

  • resume – If True, resume from fitted state

Returns:

Fitted model

Return type:

self

Raises:

ValueError – If n_samples < n_clusters

predict(X)[source]

Predict cluster labels for new data.

Parameters:

X (Array | ndarray | bool | number) – Input data, shape (n_samples, n_features)

Returns:

Cluster labels, shape (n_samples,)

Return type:

labels

Raises:

ValueError – If model has not been fitted

predict_proba(X)[source]

Predict fuzzy membership probabilities (U matrix).

Parameters:

X (Array | ndarray | bool | number) – Input data, shape (n_samples, n_features)

Returns:

Fuzzy membership matrix, shape (n_samples, n_clusters)

Return type:

U

Raises:

ValueError – If model has not been fitted

get_typicality(X)[source]

Compute typicality values (T matrix).

Parameters:

X (Array | ndarray | bool | number) – Input data, shape (n_samples, n_features)

Returns:

Typicality matrix, shape (n_samples, n_clusters)

Return type:

T

Raises:

ValueError – If model has not been fitted

class prosemble.models.PFCM(n_clusters, fuzzifier=2.0, eta=2.0, a=1.0, b=1.0, k=1.0, max_iter=100, epsilon=1e-05, init_method='fcm', random_seed=42, distance_fn=None, patience=None, restore_best=False, plot_steps=False, show_confidence=True, show_pca_variance=True, save_plot_path=None, callbacks=None)[source]

JAX implementation of Possibilistic Fuzzy C-Means clustering.

PFCM combines fuzzy membership and typicality for robust clustering.
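
The standard PFCM alternating updates can be sketched in NumPy as below, under stated assumptions rather than as prosemble's code. Distances are squared Euclidean. \(U\) uses the FCM update. \(\gamma_j\) is taken as `k` times the \(u^m\)-weighted mean distance to centroid \(j\) (an assumed form, recomputed every step here for simplicity). Typicalities are \(t_{ij} = 1 / (1 + (b\, d_{ij}^2 / \gamma_j)^{1/(\eta - 1)})\), and centroids are weighted by \(a\, u_{ij}^m + b\, t_{ij}^\eta\). The helper `pfcm_step` is hypothetical.

```python
import numpy as np


def pfcm_step(X, V, m=2.0, eta=2.0, a=1.0, b=1.0, k=1.0, eps=1e-12):
    """One PFCM update (hypothetical helper). Returns U, T, gamma, centroids."""
    D = np.maximum(((X[:, None, :] - V[None, :, :]) ** 2).sum(-1), eps)
    # U: FCM memberships, each row sums to 1.
    inv = D ** (-1.0 / (m - 1.0))
    U = inv / inv.sum(axis=1, keepdims=True)
    Um = U ** m
    # gamma_j: k times the u^m-weighted mean squared distance (assumed form).
    gamma = k * (Um * D).sum(axis=0) / Um.sum(axis=0)
    # T: possibilistic typicality in (0, 1]; rows need not sum to 1.
    T = 1.0 / (1.0 + (b * D / gamma) ** (1.0 / (eta - 1.0)))
    # Centroids: weighted means with weights a*u^m + b*t^eta.
    W = a * Um + b * T ** eta
    V_new = (W.T @ X) / W.sum(axis=0)[:, None]
    return U, T, gamma, V_new


X = np.array([[1.0, 2.0], [1.5, 1.8], [5.0, 8.0], [8.0, 8.0]])
V = X[[0, 2]].copy()
for _ in range(20):
    U, T, gamma, V = pfcm_step(X, V)
```

Raising `a` relative to `b` pulls the centroids toward the FCM solution; raising `b` makes them follow the typicalities, which are less sensitive to outliers.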

Parameters:
  • fuzzifier (float, default=2.0) – Fuzzification parameter for membership (\(m\)). Must be > 1.

  • eta (float, default=2.0) – Fuzzification parameter for typicality (\(\eta\)). Must be > 1.

  • a (float, default=1.0) – Weight for fuzzy membership term. Must be >= 0.

  • b (float, default=1.0) – Weight for typicality term. Must be >= 0.

  • k (float, default=1.0) – Scaling parameter for \(\gamma\) computation. Must be > 0.

  • init_method (str, default='fcm') – Initialization method: ‘fcm’ or ‘random’.

  • n_clusters (int) – Number of clusters (must be >= 2).

  • max_iter (int) – Maximum number of iterations.

  • epsilon (float) – Convergence threshold.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Pairwise distance function. Default: squared Euclidean.

  • patience (int, optional) – Epochs with no improvement before early stopping. Default: None.

  • restore_best (bool) – If True, restore centroids from the lowest-objective epoch. Default: False.

  • plot_steps (bool) – Whether to visualize clustering progress. Default: False.

  • show_confidence (bool) – Whether to show confidence in visualization. Default: True.

  • show_pca_variance (bool) – Whether to show PCA variance in visualization. Default: True.

  • save_plot_path (str, optional) – Path to save final plot.

  • callbacks (list, optional) – List of Callback objects for monitoring/visualization.

centroids_

Final cluster centroids

Type:

array, shape (n_clusters, n_features)

U_

Final fuzzy membership matrix

Type:

array, shape (n_samples, n_clusters)

T_

Final typicality matrix

Type:

array, shape (n_samples, n_clusters)

gamma_

Final scale parameters

Type:

array, shape (n_clusters,)

objective_

Final objective value

Type:

float

n_iter_

Number of iterations performed

Type:

int

fit(X, initial_centroids=None, resume=False)[source]

Fit PFCM model to data.

Parameters:
  • X (array-like, shape (n_samples, n_features)) – Training data

  • initial_centroids (array-like, shape (n_clusters, n_features), optional) – Pre-computed centroids for warm starting

  • resume (bool, default=False) – If True, resume from the model’s current fitted state

Return type:

Self

predict(X)[source]

Predict cluster labels.

Parameters:

X (Array)

Return type:

Array

predict_proba(X)[source]

Predict fuzzy membership.

Parameters:

X (Array)

Return type:

Array

predict_typicality(X)[source]

Predict typicality values.

Parameters:

X (Array)

Return type:

Array

class prosemble.models.AFCM(n_clusters, fuzzifier=2.0, a=1.0, b=1.0, k=1.0, max_iter=100, epsilon=1e-05, init_method='fcm', random_seed=42, distance_fn=None, patience=None, restore_best=False, plot_steps=False, show_confidence=True, show_pca_variance=True, save_plot_path=None, callbacks=None)[source]

Adaptive Fuzzy C-Means clustering with JAX.

AFCM is an adaptive variant of FCM that combines fuzzy membership and possibilistic typicality, weighting the two terms by \(a\) and \(b\).

Key features:

  • Centroids use \(a \cdot U^m + b \cdot T\) (\(T\) enters with power 1, not \(m\)).

  • \(\gamma\) is computed with the Euclidean distance (not squared).

  • \(T\) is updated with an exponential rule involving \(b\).

  • \(U\) uses the standard FCM update.

Algorithm:

  1. Initialize \(U\) using FCM

  2. Compute \(\gamma\) parameters using Euclidean distance

  3. Update \(T\) using exponential update

  4. Update \(U\) using standard FCM rule

  5. Update centroids using combined fuzzy-possibilistic weights

  6. Repeat until convergence

Objective function:

\[J = \sum_i \sum_j \left[d_{ij}^2 \cdot (a \cdot u_{ij}^m + b \cdot t_{ij})\right] + \sum_j \left[\gamma_j \cdot \sum_i (t_{ij} \log t_{ij} - t_{ij})\right]\]
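
Setting \(\partial J / \partial t_{ij} = b\, d_{ij}^2 + \gamma_j \log t_{ij}\) to zero gives the exponential typicality update \(t_{ij} = \exp(-b\, d_{ij}^2 / \gamma_j)\), and setting the gradient with respect to \(v_j\) to zero gives centroids weighted by \(a\, u_{ij}^m + b\, t_{ij}\). The NumPy sketch below applies one iteration in the order of the algorithm steps. The exact \(\gamma\) formula (`k` times the \(u^m\)-weighted mean of the non-squared Euclidean distance) is an assumption based on the key features listed above, and `afcm_step` and `fcm_memberships` are hypothetical helpers, not prosemble's API.

```python
import numpy as np


def fcm_memberships(D2, m):
    """Standard FCM membership update; each row sums to 1."""
    inv = D2 ** (-1.0 / (m - 1.0))
    return inv / inv.sum(axis=1, keepdims=True)


def afcm_step(X, V, U, m=2.0, a=1.0, b=1.0, k=1.0, eps=1e-12):
    """One AFCM iteration (hypothetical helper). Returns U, T, gamma, centroids."""
    D2 = np.maximum(((X[:, None, :] - V[None, :, :]) ** 2).sum(-1), eps)
    Um = U ** m
    # Step 2: gamma_j from the Euclidean (not squared) distance; assumed form.
    gamma = k * (Um * np.sqrt(D2)).sum(axis=0) / Um.sum(axis=0)
    # Step 3: T minimizes J for fixed U, V and gamma.
    T = np.exp(-b * D2 / gamma)
    # Step 4: standard FCM membership update.
    U = fcm_memberships(D2, m)
    # Step 5: centroids weighted by a*u^m + b*t (T to the power 1).
    W = a * U ** m + b * T
    V_new = (W.T @ X) / W.sum(axis=0)[:, None]
    return U, T, gamma, V_new


X = np.array([[1.0, 2.0], [1.5, 1.8], [5.0, 8.0], [8.0, 8.0]])
V = X[[0, 2]].copy()
U = fcm_memberships(np.maximum(((X[:, None] - V[None]) ** 2).sum(-1), 1e-12), 2.0)
for _ in range(20):
    U, T, gamma, V = afcm_step(X, V, U)
```
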
Parameters:
  • fuzzifier (float, default=2.0) – Fuzziness parameter (must be > 1.0).

  • a (float, default=1.0) – Weight for fuzzy membership term (must be > 0).

  • b (float, default=1.0) – Weight for typicality term (must be > 0).

  • k (float, default=1.0) – Scaling parameter for \(\gamma\) (must be > 0).

  • init_method ({'fcm'}, default='fcm') – Initialization method.

  • n_clusters (int) – Number of clusters (must be >= 2).

  • max_iter (int) – Maximum number of iterations.

  • epsilon (float) – Convergence threshold.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Pairwise distance function. Default: squared Euclidean.

  • patience (int, optional) – Epochs with no improvement before early stopping. Default: None.

  • restore_best (bool) – If True, restore centroids from the lowest-objective epoch. Default: False.

  • plot_steps (bool) – Whether to visualize clustering progress. Default: False.

  • show_confidence (bool) – Whether to show confidence in visualization. Default: True.

  • show_pca_variance (bool) – Whether to show PCA variance in visualization. Default: True.

  • save_plot_path (str, optional) – Path to save final plot.

  • callbacks (list, optional) – List of Callback objects for monitoring/visualization.

centroids_

Final cluster centroids

Type:

array, shape (n_clusters, n_features)

U_

Final fuzzy membership matrix

Type:

array, shape (n_samples, n_clusters)

T_

Final typicality matrix

Type:

array, shape (n_samples, n_clusters)

gamma_

Final scale parameters

Type:

array, shape (n_clusters,)

n_iter_

Number of iterations until convergence

Type:

int

objective_

Final objective function value

Type:

float

objective_history_

Objective values at each iteration

Type:

array

Examples

>>> import jax.numpy as jnp
>>> from prosemble.models import AFCM
>>> X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8]])
>>> model = AFCM(n_clusters=2, fuzzifier=2.0, a=1.0, b=1.0, k=1.0, random_seed=42)
>>> model.fit(X)
>>> labels = model.predict(X)

fit(X, initial_centroids=None, resume=False)[source]

Fit AFCM model to data.

Parameters:

X (Array | ndarray | bool | number)

Return type:

Self

predict(X)[source]

Predict cluster labels for new data.

Parameters:

X (Array | ndarray | bool | number)

Return type:

Array | ndarray | bool | number

predict_proba(X)[source]

Predict fuzzy membership probabilities.

Parameters:

X (Array | ndarray | bool | number)

Return type:

Array | ndarray | bool | number

class prosemble.models.HCM(n_clusters, max_iter=100, epsilon=1e-05, init_method='random', random_seed=42, distance_fn=None, patience=None, restore_best=False, plot_steps=False, show_confidence=True, show_pca_variance=True, save_plot_path=None, callbacks=None)[source]

Hard C-Means (K-Means) clustering with JAX.

HCM assigns each data point to exactly one cluster (hard assignment) based on the nearest centroid. This is the classic K-Means algorithm.

Algorithm:

  1. Initialize centroids randomly or from data

  2. Assign each point to nearest centroid

  3. Update centroids as mean of assigned points

  4. Repeat until convergence
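
The assign-then-average loop above is Lloyd's algorithm. A minimal NumPy sketch of one iteration, assuming squared Euclidean distance; `hcm_step` is a hypothetical helper, and keeping the old centroid for an empty cluster is an illustrative choice, not necessarily what prosemble does.

```python
import numpy as np


def hcm_step(X, V):
    """One hard c-means (Lloyd) iteration: assign, then average."""
    D = ((X[:, None, :] - V[None, :, :]) ** 2).sum(axis=-1)
    labels = D.argmin(axis=1)  # l_i: index of the nearest centroid
    V_new = V.copy()
    for j in range(V.shape[0]):
        members = X[labels == j]
        if len(members) > 0:  # an empty cluster keeps its old centroid
            V_new[j] = members.mean(axis=0)
    # Objective J for this assignment, measured against the input centroids.
    objective = D[np.arange(len(X)), labels].sum()
    return labels, V_new, objective


X = np.array([[1.0, 2.0], [1.5, 1.8], [5.0, 8.0], [8.0, 8.0]])
V = X[[0, 2]].copy()
for _ in range(10):
    labels, V, J = hcm_step(X, V)
# labels -> [0, 0, 1, 1]; V -> [[1.25, 1.9], [6.5, 8.0]]
```
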

Objective function:

\[J = \sum_i \|x_i - v_{l_i}\|^2\]
where \(l_i\) is the index of the centroid assigned to \(x_i\), i.e. its nearest centroid.

Parameters:
  • init_method ({'random', 'kmeans++'}, default='random') – Method for initializing centroids.

  • n_clusters (int) – Number of clusters (must be >= 2).

  • max_iter (int) – Maximum number of iterations.

  • epsilon (float) – Convergence threshold.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Pairwise distance function. Default: squared Euclidean.

  • patience (int, optional) – Epochs with no improvement before early stopping. Default: None.

  • restore_best (bool) – If True, restore centroids from the lowest-objective epoch. Default: False.

  • plot_steps (bool) – Whether to visualize clustering progress. Default: False.

  • show_confidence (bool) – Whether to show confidence in visualization. Default: True.

  • show_pca_variance (bool) – Whether to show PCA variance in visualization. Default: True.

  • save_plot_path (str, optional) – Path to save final plot.

  • callbacks (list, optional) – List of Callback objects for monitoring/visualization.

centroids_

Final cluster centroids

Type:

array, shape (n_clusters, n_features)

labels_

Hard cluster assignments for training data

Type:

array, shape (n_samples,)

n_iter_

Number of iterations until convergence

Type:

int

objective_

Final objective function value

Type:

float

objective_history_

Objective value at each iteration

Type:

array

Examples

>>> import jax.numpy as jnp
>>> from prosemble.models import HCM
>>> X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8]])
>>> model = HCM(n_clusters=2, random_seed=42)
>>> model.fit(X)
>>> labels = model.predict(X)

fit(X, initial_centroids=None, resume=False)[source]

Fit HCM model to data.

Parameters:
  • X (array-like, shape (n_samples, n_features)) – Training data

  • initial_centroids (array-like, shape (n_clusters, n_features), optional) – Pre-computed centroids for warm starting

  • resume (bool, default=False) – If True, resume from the model’s current fitted state

Return type:

Self

predict(X)[source]

Predict cluster labels for new data.

Parameters:

X (Array | ndarray | bool | number) – Input data, shape (n_samples, n_features)

Returns:

Cluster labels, shape (n_samples,)

Return type:

labels

Raises:

ValueError – If model has not been fitted

class prosemble.models.IPCM(n_clusters, fuzzifier=2.0, tipifier=2.0, k=1.0, max_iter=100, epsilon=1e-05, init_method='fcm', random_seed=42, distance_fn=None, patience=None, restore_best=False, plot_steps=False, show_confidence=True, show_pca_variance=True, save_plot_path=None, callbacks=None)[source]

Improved Possibilistic C-Means clustering with JAX.

IPCM uses a two-phase approach to improve clustering performance:

  • Phase 0: initialize \(\gamma\) using fuzzy membership only.

  • Phase 1: refine \(\gamma\) using both membership and typicality.

Key differences from PCM:

  • Uses the product of \(U^{m_f}\) and \(T^{m_p}\) in the centroid computation.

  • Uses a modified \(U\) update that depends on \(T\).

  • Computes \(\gamma\) in two phases.

Algorithm (Phase 0):

  1. Initialize \(U\) using FCM, \(T = 0\)

  2. Compute \(\gamma\) parameters from fuzzy membership

  3. Update typicality matrix \(T\)

  4. Update membership matrix \(U\)

  5. Update centroids using combined \(U\) and \(T\) weights

  6. Repeat until convergence

Algorithm (Phase 1):

  1. Recompute \(\gamma\) using both \(U\) and \(T\)

  2. Continue iterations with the new \(\gamma\)

Objective function:

\[J = \sum_i \sum_j u_{ij}^{m_f} \cdot t_{ij}^{m_p} \cdot d_{ij}^2 + \sum_j \gamma_j \sum_i (1 - t_{ij})^{m_p} \cdot u_{ij}^{m_f}\]
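
For fixed \(U\), \(V\) and \(\gamma\), setting \(\partial J / \partial t_{ij}\) to zero cancels the \(u_{ij}^{m_f}\) factor and gives \(t_{ij} = 1 / (1 + (d_{ij}^2 / \gamma_j)^{1/(m_p - 1)})\); setting the gradient with respect to \(v_j\) to zero gives centroids weighted by the product \(u_{ij}^{m_f} t_{ij}^{m_p}\). A minimal NumPy sketch of these two updates, assuming squared Euclidean distance. \(\gamma\) is passed in because its phase-specific formulas are not spelled out here, and `ipcm_update` is a hypothetical helper.

```python
import numpy as np


def ipcm_update(X, V, U, gamma, m_f=2.0, m_p=2.0, eps=1e-12):
    """Typicality and centroid updates for fixed U and gamma (hypothetical)."""
    D2 = np.maximum(((X[:, None, :] - V[None, :, :]) ** 2).sum(-1), eps)
    # T: stationary point of J in t_ij; equals 0.5 where d_ij^2 == gamma_j.
    T = 1.0 / (1.0 + (D2 / gamma) ** (1.0 / (m_p - 1.0)))
    # Centroids: weighted means with product weights u^{m_f} * t^{m_p}.
    W = U ** m_f * T ** m_p
    V_new = (W.T @ X) / W.sum(axis=0)[:, None]
    return T, V_new


# Two points at equal distance from a single centroid, with d^2 == gamma.
X = np.array([[0.0, 0.0], [2.0, 0.0]])
V = np.array([[1.0, 0.0]])
U = np.ones((2, 1))
T, V_new = ipcm_update(X, V, U, gamma=np.array([1.0]))
# T -> [[0.5], [0.5]]; V_new -> [[1.0, 0.0]]
```

\(\gamma_j\) acts as the squared distance at which a point's typicality drops to one half.
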
Parameters:
  • fuzzifier (float, default=2.0) – Fuzziness parameter for \(U\) matrix (\(m_f\), must be > 1.0).

  • tipifier (float, default=2.0) – Possibilistic parameter for \(T\) matrix (\(m_p\), must be > 1.0).

  • k (float, default=1.0) – Scaling parameter for \(\gamma\) in phase 1 (must be > 0).

  • init_method ({'fcm'}, default='fcm') – Method for initializing \(U\) matrix (must use FCM).

  • n_clusters (int) – Number of clusters (must be >= 2).

  • max_iter (int) – Maximum number of iterations.

  • epsilon (float) – Convergence threshold.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Pairwise distance function. Default: squared Euclidean.

  • patience (int, optional) – Epochs with no improvement before early stopping. Default: None.

  • restore_best (bool) – If True, restore centroids from the lowest-objective epoch. Default: False.

  • plot_steps (bool) – Whether to visualize clustering progress. Default: False.

  • show_confidence (bool) – Whether to show confidence in visualization. Default: True.

  • show_pca_variance (bool) – Whether to show PCA variance in visualization. Default: True.

  • save_plot_path (str, optional) – Path to save final plot.

  • callbacks (list, optional) – List of Callback objects for monitoring/visualization.

centroids_

Final cluster centroids

Type:

array, shape (n_clusters, n_features)

U_

Final fuzzy membership matrix

Type:

array, shape (n_samples, n_clusters)

T_

Final typicality matrix

Type:

array, shape (n_samples, n_clusters)

gamma_

Final scale parameters

Type:

array, shape (n_clusters,)

n_iter_

Total number of iterations across both phases

Type:

int

objective_

Final objective function value

Type:

float

objective_history_

Objective value at each iteration

Type:

array

Examples

>>> import jax.numpy as jnp
>>> from prosemble.models import IPCM
>>> X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8]])
>>> model = IPCM(n_clusters=2, fuzzifier=2.0, tipifier=2.0, k=1.0, random_seed=42)
>>> model.fit(X)
>>> labels = model.predict(X)

fit(X, initial_centroids=None, resume=False)[source]

Fit IPCM model to data.

Parameters:

X (Array | ndarray | bool | number)

Return type:

Self

predict(X)[source]

Predict cluster labels for new data.

Parameters:

X (Array | ndarray | bool | number) – Input data, shape (n_samples, n_features)

Returns:

Cluster labels, shape (n_samples,)

Return type:

labels

Raises:

ValueError – If model has not been fitted

predict_proba(X)[source]

Predict fuzzy membership probabilities (U matrix).

Parameters:

X (Array | ndarray | bool | number) – Input data, shape (n_samples, n_features)

Returns:

Fuzzy membership matrix, shape (n_samples, n_clusters)

Return type:

U

Raises:

ValueError – If model has not been fitted

class prosemble.models.IPCM2(n_clusters, fuzzifier=2.0, tipifier=2.0, max_iter=100, epsilon=1e-05, init_method='fcm', random_seed=42, distance_fn=None, patience=None, restore_best=False, plot_steps=False, show_confidence=True, show_pca_variance=True, save_plot_path=None, callbacks=None)[source]

Improved Possibilistic C-Means 2 clustering with JAX.

IPCM2 is a variant of IPCM with the following key differences:

  • Uses an exponential \(T\) update: \(t_{ij} = \exp(-d_{ij}^2 / \gamma_j)\).

  • Centroids use \(U^{m_f} \cdot T\) (\(T\) is not raised to a power).

  • Uses a modified \(U\) update with exponential distance.

  • Uses a different objective function (see below).

Algorithm (Phase 0):

  1. Initialize \(U\) using FCM, \(T = 0\)

  2. Compute \(\gamma\) parameters from fuzzy membership

  3. Update \(T\) using exponential update

  4. Update \(U\) with modified distance

  5. Update centroids using combined \(U\) and \(T\) weights

  6. Repeat until convergence

Algorithm (Phase 1):

  1. Recompute \(\gamma\) using both \(U\) and \(T\)

  2. Continue iterations with the new \(\gamma\)

Objective function:

\[J = \sum_i \sum_j u_{ij}^{m_f} \cdot t_{ij} \cdot d_{ij}^2 + \sum_j \gamma_j \sum_i (t_{ij} \log t_{ij} - t_{ij} + 1) \cdot u_{ij}^{m_f}\]
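
Minimizing this objective in \(t_{ij}\) for fixed \(U\), \(V\) and \(\gamma\) gives \(u_{ij}^{m_f}(d_{ij}^2 + \gamma_j \log t_{ij}) = 0\), i.e. the exponential update \(t_{ij} = \exp(-d_{ij}^2 / \gamma_j)\) listed above, and the centroid weights \(u_{ij}^{m_f} t_{ij}\) follow from the first term. A minimal NumPy sketch, assuming squared Euclidean distance with \(\gamma\) passed in; `ipcm2_update` is a hypothetical helper.

```python
import numpy as np


def ipcm2_update(X, V, U, gamma, m_f=2.0):
    """Exponential typicality and centroid updates (hypothetical helper)."""
    D2 = ((X[:, None, :] - V[None, :, :]) ** 2).sum(-1)
    T = np.exp(-D2 / gamma)  # t_ij = exp(-d_ij^2 / gamma_j), in (0, 1]
    W = U ** m_f * T  # T enters without a power
    V_new = (W.T @ X) / W.sum(axis=0)[:, None]
    return T, V_new


X = np.array([[0.0, 0.0], [2.0, 0.0], [1.0, 0.0]])
V = np.array([[1.0, 0.0]])
U = np.ones((3, 1))
T, V_new = ipcm2_update(X, V, U, gamma=np.array([1.0]))
# T -> [[exp(-1)], [exp(-1)], [1.0]]; V_new -> [[1.0, 0.0]]
```
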
Parameters:
  • fuzzifier (float, default=2.0) – Fuzziness parameter for \(U\) matrix (\(m_f\), must be > 1.0).

  • tipifier (float, default=2.0) – Possibilistic parameter for \(T\) matrix (\(m_p\), must be > 1.0).

  • init_method ({'fcm'}, default='fcm') – Method for initializing \(U\) matrix.

  • n_clusters (int) – Number of clusters (must be >= 2).

  • max_iter (int) – Maximum number of iterations.

  • epsilon (float) – Convergence threshold.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Pairwise distance function. Default: squared Euclidean.

  • patience (int, optional) – Epochs with no improvement before early stopping. Default: None.

  • restore_best (bool) – If True, restore centroids from the lowest-objective epoch. Default: False.

  • plot_steps (bool) – Whether to visualize clustering progress. Default: False.

  • show_confidence (bool) – Whether to show confidence in visualization. Default: True.

  • show_pca_variance (bool) – Whether to show PCA variance in visualization. Default: True.

  • save_plot_path (str, optional) – Path to save final plot.

  • callbacks (list, optional) – List of Callback objects for monitoring/visualization.

centroids_

Final cluster centroids

Type:

array, shape (n_clusters, n_features)

U_

Final fuzzy membership matrix

Type:

array, shape (n_samples, n_clusters)

T_

Final typicality matrix

Type:

array, shape (n_samples, n_clusters)

gamma_

Final scale parameters

Type:

array, shape (n_clusters,)

n_iter_

Total number of iterations

Type:

int

objective_

Final objective function value

Type:

float

objective_history_

Objective values at each iteration

Type:

array

Examples

>>> import jax.numpy as jnp
>>> from prosemble.models import IPCM2
>>> X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8]])
>>> model = IPCM2(n_clusters=2, random_seed=42)
>>> model.fit(X)
>>> labels = model.predict(X)

fit(X, initial_centroids=None, resume=False)[source]

Fit IPCM2 model to data.

Parameters:

X (Array | ndarray | bool | number)

Return type:

Self

predict(X)[source]

Predict cluster labels for new data.

Parameters:

X (Array | ndarray | bool | number)

Return type:

Array | ndarray | bool | number

predict_proba(X)[source]

Predict fuzzy membership probabilities.

Parameters:

X (Array | ndarray | bool | number)

Return type:

Array | ndarray | bool | number

Kernel Clustering

class prosemble.models.KFCM(n_clusters, fuzzifier=2.0, sigma=1.0, max_iter=100, epsilon=1e-05, init_method='random', random_seed=42, distance_fn=None, patience=None, restore_best=False, plot_steps=False, show_confidence=True, show_pca_variance=True, save_plot_path=None, callbacks=None)[source]

Kernel Fuzzy C-Means clustering with JAX.

KFCM uses a Gaussian kernel to implicitly map data into a high-dimensional feature space and performs fuzzy clustering there, which lets it handle data that are not linearly separable in the input space.

Kernel:

\[K(x, y) = \exp\left(-\frac{\|x - y\|^2}{\sigma^2}\right)\]

Kernel distance in feature space:

\[\|\varphi(x) - \varphi(y)\|^2 = 2(1 - K(x, y))\]

Algorithm:

  1. Initialize \(U\) randomly

  2. Update centroids (kernel-weighted)

  3. Update \(U\) using kernel distance

  4. Repeat until convergence

Objective function:

\[J = 2 \sum_i \sum_j u_{ij}^m (1 - K(x_i, v_j))\]
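
A minimal NumPy sketch of one KFCM iteration using the kernel and kernel distance defined above. The fixed-point centroid update \(v_j = \sum_i u_{ij}^m K(x_i, v_j)\, x_i / \sum_i u_{ij}^m K(x_i, v_j)\) comes from setting the gradient of \(J\) with respect to \(v_j\) to zero; that prosemble uses exactly this form is an assumption. `gaussian_kernel` and `kfcm_step` are hypothetical helpers.

```python
import numpy as np


def gaussian_kernel(X, V, sigma):
    """K(x, v) = exp(-||x - v||^2 / sigma^2), shape (n_samples, n_clusters)."""
    return np.exp(-((X[:, None, :] - V[None, :, :]) ** 2).sum(-1) / sigma ** 2)


def kfcm_step(X, V, m=2.0, sigma=1.0, eps=1e-12):
    """One KFCM iteration (hypothetical helper). Returns U, centroids, J."""
    K = gaussian_kernel(X, V, sigma)
    d_K = np.maximum(2.0 * (1.0 - K), eps)  # ||phi(x) - phi(v)||^2, in [0, 2]
    # U: FCM update with kernel distance; each row sums to 1.
    inv = d_K ** (-1.0 / (m - 1.0))
    U = inv / inv.sum(axis=1, keepdims=True)
    # Centroids: fixed-point update weighted by u^m * K(x_i, v_j).
    W = U ** m * K
    V_new = (W.T @ X) / W.sum(axis=0)[:, None]
    J = (U ** m * d_K).sum()  # 2 * sum u^m (1 - K), up to the eps guard
    return U, V_new, J


X = np.array([[1.0, 2.0], [1.5, 1.8], [5.0, 8.0], [8.0, 8.0]])
V = X[[0, 2]].copy()
for _ in range(20):
    U, V, J = kfcm_step(X, V, sigma=3.0)
```

Because \(K(x, v)\) decays quickly with distance, points far from a centroid barely influence it, and `sigma` sets how far "far" is.
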
Parameters:
  • fuzzifier (float, default=2.0) – Fuzziness parameter (must be > 1.0).

  • sigma (float, default=1.0) – Kernel bandwidth parameter (must be > 0).

  • init_method ({'random'}, default='random') – Initialization method.

  • n_clusters (int) – Number of clusters (must be >= 2).

  • max_iter (int) – Maximum number of iterations.

  • epsilon (float) – Convergence threshold.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Pairwise distance function. Default: squared Euclidean.

  • patience (int, optional) – Epochs with no improvement before early stopping. Default: None.

  • restore_best (bool) – If True, restore centroids from the lowest-objective epoch. Default: False.

  • plot_steps (bool) – Whether to visualize clustering progress. Default: False.

  • show_confidence (bool) – Whether to show confidence in visualization. Default: True.

  • show_pca_variance (bool) – Whether to show PCA variance in visualization. Default: True.

  • save_plot_path (str, optional) – Path to save final plot.

  • callbacks (list, optional) – List of Callback objects for monitoring/visualization.

centroids_

Final cluster centroids

Type:

array, shape (n_clusters, n_features)

U_

Final fuzzy membership matrix

Type:

array, shape (n_samples, n_clusters)

n_iter_

Number of iterations until convergence

Type:

int

objective_

Final objective function value

Type:

float

objective_history_

Objective values at each iteration

Type:

array

Examples

>>> import jax.numpy as jnp
>>> from prosemble.models import KFCM
>>> X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8]])
>>> model = KFCM(n_clusters=2, sigma=1.0, random_seed=42)
>>> model.fit(X)
>>> labels = model.predict(X)

fit(X, initial_centroids=None, resume=False)[source]

Fit KFCM model to data.

Parameters:

X (Array | ndarray | bool | number)

Return type:

Self

predict(X)[source]

Predict cluster labels for new data.

Parameters:

X (Array | ndarray | bool | number)

Return type:

Array | ndarray | bool | number

predict_proba(X)[source]

Predict fuzzy membership probabilities.

Parameters:

X (Array | ndarray | bool | number)

Return type:

Array | ndarray | bool | number

class prosemble.models.KPCM(n_clusters, fuzzifier=2.0, k=1.0, sigma=1.0, max_iter=100, epsilon=1e-05, init_method='kfcm', random_seed=42, distance_fn=None, patience=None, restore_best=False, plot_steps=False, show_confidence=True, show_pca_variance=True, save_plot_path=None, callbacks=None)[source]

Kernel Possibilistic C-Means clustering with JAX.

KPCM extends PCM to kernel space with a Gaussian kernel, so it can handle non-linearly separable data while keeping PCM's possibilistic (typicality-based) memberships.

Kernel:

\[K(x, y) = \exp\left(-\frac{\|x - y\|^2}{\sigma^2}\right)\]

Kernel distance:

\[d_K(x, v) = 2(1 - K(x, v))\]

Algorithm:

  1. Initialize using KFCM

  2. Compute \(\gamma\) parameters

  3. Update typicality matrix \(T\)

  4. Update centroids (kernel-weighted)

  5. Repeat until convergence

Objective function:

\[J = \sum_i \sum_j t_{ij}^m \cdot d_K(x_i, v_j) + \sum_j \gamma_j \sum_i (1 - t_{ij})^m\]
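
For fixed centroids and \(\gamma\), minimizing this objective in \(t_{ij}\) gives the PCM-style update \(t_{ij} = 1 / (1 + (d_K(x_i, v_j) / \gamma_j)^{1/(m - 1)})\), and the centroid fixed point is weighted by \(t_{ij}^m K(x_i, v_j)\). A minimal NumPy sketch follows. It assumes \(\gamma_j\) is `k` times the \(u^m\)-weighted mean kernel distance from the KFCM initialization, which this page does not confirm; `kernel_distance`, `kpcm_gamma` and `kpcm_step` are hypothetical helpers.

```python
import numpy as np


def kernel_distance(X, V, sigma):
    """Return d_K(x, v) = 2 (1 - K(x, v)) and K for a Gaussian kernel."""
    D2 = ((X[:, None, :] - V[None, :, :]) ** 2).sum(-1)
    K = np.exp(-D2 / sigma ** 2)
    return 2.0 * (1.0 - K), K


def kpcm_gamma(U, d_K, m=2.0, k=1.0):
    """Assumed form: gamma_j = k * sum_i u_ij^m d_K / sum_i u_ij^m."""
    Um = U ** m
    return k * (Um * d_K).sum(axis=0) / Um.sum(axis=0)


def kpcm_step(X, V, gamma, m=2.0, sigma=1.0):
    """Typicality and centroid updates for fixed gamma (hypothetical helper)."""
    d_K, K = kernel_distance(X, V, sigma)
    T = 1.0 / (1.0 + (d_K / gamma) ** (1.0 / (m - 1.0)))  # in (0, 1]
    W = T ** m * K
    V_new = (W.T @ X) / W.sum(axis=0)[:, None]
    return T, V_new


X = np.array([[0.0, 0.0], [0.5, 0.0], [3.0, 0.0]])
V = np.array([[0.0, 0.0]])
d_K, _ = kernel_distance(X, V, sigma=1.0)
gamma = kpcm_gamma(np.ones((3, 1)), d_K)
T, V_new = kpcm_step(X, V, gamma)
```

A point sitting exactly on a centroid gets typicality 1, and typicality falls off monotonically with kernel distance.
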
Parameters:
  • fuzzifier (float, default=2.0) – Fuzziness parameter (must be > 1.0).

  • k (float, default=1.0) – Scaling parameter for \(\gamma\) (must be > 0).

  • sigma (float, default=1.0) – Kernel bandwidth parameter (must be > 0).

  • init_method ({'kfcm'}, default='kfcm') – Initialization method.

  • n_clusters (int) – Number of clusters (must be >= 2).

  • max_iter (int) – Maximum number of iterations.

  • epsilon (float) – Convergence threshold.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Pairwise distance function. Default: squared Euclidean.

  • patience (int, optional) – Epochs with no improvement before early stopping. Default: None.

  • restore_best (bool) – If True, restore centroids from the lowest-objective epoch. Default: False.

  • plot_steps (bool) – Whether to visualize clustering progress. Default: False.

  • show_confidence (bool) – Whether to show confidence in visualization. Default: True.

  • show_pca_variance (bool) – Whether to show PCA variance in visualization. Default: True.

  • save_plot_path (str, optional) – Path to save final plot.

  • callbacks (list, optional) – List of Callback objects for monitoring/visualization.

centroids_

Final cluster centroids

Type:

array, shape (n_clusters, n_features)

T_

Final typicality matrix

Type:

array, shape (n_samples, n_clusters)

gamma_

Final scale parameters

Type:

array, shape (n_clusters,)

n_iter_

Number of iterations until convergence

Type:

int

objective_

Final objective function value

Type:

float

objective_history_

Objective values at each iteration

Type:

array

Examples

>>> import jax.numpy as jnp
>>> from prosemble.models import KPCM
>>> X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8]])
>>> model = KPCM(n_clusters=2, sigma=1.0, random_seed=42)
>>> model.fit(X)
>>> labels = model.predict(X)

fit(X, initial_centroids=None, resume=False)[source]

Fit KPCM model to data.

Parameters:

X (Array | ndarray | bool | number)

Return type:

Self

predict(X)[source]

Predict cluster labels for new data.

Parameters:

X (Array | ndarray | bool | number)

Return type:

Array | ndarray | bool | number

predict_proba(X)[source]

Predict typicality values.

Parameters:

X (Array | ndarray | bool | number)

Return type:

Array | ndarray | bool | number

class prosemble.models.KFPCM(n_clusters, fuzzifier=2.0, eta=2.0, sigma=1.0, max_iter=100, epsilon=1e-05, init_method='kfcm', random_seed=42, distance_fn=None, patience=None, restore_best=False, plot_steps=False, show_confidence=True, show_pca_variance=True, save_plot_path=None, callbacks=None)[source]

Kernel Fuzzy Possibilistic C-Means with JAX.

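KFPCM maintains two matrices, \(U\) and \(T\), in kernel space. \(U\) has a row-sum-to-1 constraint (each sample's memberships sum to 1 across clusters, as in standard FCM), while \(T\) has a column-sum-to-1 constraint (each cluster's typicalities sum to 1 across samples), following the original Pal, Pal & Bezdek (1997) FPCM formulation.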

Parameters:
  • fuzzifier (float, default=2.0) – Fuzziness parameter for \(U\) matrix (must be > 1.0).

  • eta (float, default=2.0) – Fuzziness parameter for \(T\) matrix (must be > 1.0).

  • sigma (float, default=1.0) – Kernel bandwidth parameter (must be > 0).

  • init_method ({'kfcm', 'random'}, default='kfcm') – Initialization method.

  • n_clusters (int) – Number of clusters (must be >= 2).

  • max_iter (int) – Maximum number of iterations.

  • epsilon (float) – Convergence threshold.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Pairwise distance function. Default: squared Euclidean.

  • patience (int, optional) – Epochs with no improvement before early stopping. Default: None.

  • restore_best (bool) – If True, restore centroids from the lowest-objective epoch. Default: False.

  • plot_steps (bool) – Whether to visualize clustering progress. Default: False.

  • show_confidence (bool) – Whether to show confidence in visualization. Default: True.

  • show_pca_variance (bool) – Whether to show PCA variance in visualization. Default: True.

  • save_plot_path (str, optional) – Path to save final plot.

  • callbacks (list, optional) – List of Callback objects for monitoring/visualization.
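
A minimal NumPy sketch of one KFPCM iteration. It assumes KFPCM mirrors FPCM with the Gaussian kernel distance \(d_K = 2(1 - K)\) in place of the squared Euclidean distance: \(U\) is normalized over clusters (rows), \(T\) over samples (columns), and centroids use the kernel fixed point with weights \((u_{ij}^m + t_{ij}^\eta) K(x_i, v_j)\). The helper `kfpcm_step` is hypothetical and is not prosemble's source.

```python
import numpy as np


def kfpcm_step(X, V, m=2.0, eta=2.0, sigma=1.0, eps=1e-12):
    """One KFPCM iteration (hypothetical helper). Returns U, T, centroids."""
    D2 = ((X[:, None, :] - V[None, :, :]) ** 2).sum(-1)
    K = np.exp(-D2 / sigma ** 2)
    d_K = np.maximum(2.0 * (1.0 - K), eps)  # kernel distance, guarded at 0
    # U: row-sum-to-1 (each sample over clusters).
    inv_u = d_K ** (-1.0 / (m - 1.0))
    U = inv_u / inv_u.sum(axis=1, keepdims=True)
    # T: column-sum-to-1 (each cluster over samples).
    inv_t = d_K ** (-1.0 / (eta - 1.0))
    T = inv_t / inv_t.sum(axis=0, keepdims=True)
    # Centroids: kernel fixed point with weights (u^m + t^eta) * K.
    W = (U ** m + T ** eta) * K
    V_new = (W.T @ X) / W.sum(axis=0)[:, None]
    return U, T, V_new


X = np.array([[1.0, 2.0], [1.5, 1.8], [5.0, 8.0], [8.0, 8.0]])
V = X[[0, 2]].copy()
for _ in range(20):
    U, T, V = kfpcm_step(X, V, sigma=3.0)
```
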

fit(X, initial_centroids=None, resume=False)[source]

Fit KFPCM model to data.

Parameters:

X (Array | ndarray | bool | number)

Return type:

Self

predict(X)[source]

Predict cluster labels for new data.

Parameters:

X (Array | ndarray | bool | number)

Return type:

Array | ndarray | bool | number

predict_proba(X)[source]

Predict fuzzy membership probabilities (U matrix).

Parameters:

X (Array | ndarray | bool | number)

Return type:

Array | ndarray | bool | number

get_typicality(X)[source]

Compute typicality values (T matrix).

Parameters:

X (Array | ndarray | bool | number)

Return type:

Array | ndarray | bool | number

class prosemble.models.KPFCM(n_clusters, fuzzifier=2.0, eta=2.0, a=1.0, b=1.0, k=1.0, sigma=1.0, max_iter=100, epsilon=1e-05, init_method='kfcm', random_seed=42, distance_fn=None, patience=None, restore_best=False, plot_steps=False, show_confidence=True, show_pca_variance=True, save_plot_path=None, callbacks=None)[source]

Kernel Possibilistic Fuzzy C-Means with JAX.

KPFCM combines fuzzy membership (\(U\)) and typicality (\(T\)) in kernel space with weights \(a\) and \(b\).
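
The roles of \(a\) and \(b\) are easiest to see in the centroid step. A minimal NumPy sketch, assuming the same Gaussian-kernel fixed-point update as KFCM with PFCM-style weights \((a\, u_{ij}^m + b\, t_{ij}^\eta) K(x_i, v_j)\); `kpfcm_centroids` is a hypothetical helper. Setting `b=0` reduces it to the kernel fuzzy centroid, and setting `a=0` uses typicality alone.

```python
import numpy as np


def kpfcm_centroids(X, V, U, T, m=2.0, eta=2.0, a=1.0, b=1.0, sigma=1.0):
    """Kernel centroid update with weights (a*u^m + b*t^eta) * K (hypothetical)."""
    K = np.exp(-((X[:, None, :] - V[None, :, :]) ** 2).sum(-1) / sigma ** 2)
    W = (a * U ** m + b * T ** eta) * K
    return (W.T @ X) / W.sum(axis=0)[:, None]


X = np.array([[0.0, 0.0], [2.0, 0.0]])
V = np.array([[1.0, 0.0]])
# Illustrative weights, not taken from a fitted model.
U = np.array([[0.9], [0.1]])  # memberships favour the first point
T = np.array([[0.1], [0.9]])  # typicalities favour the second point
v_fuzzy = kpfcm_centroids(X, V, U, T, a=1.0, b=0.0)
v_typical = kpfcm_centroids(X, V, U, T, a=0.0, b=1.0)
v_both = kpfcm_centroids(X, V, U, T, a=1.0, b=1.0)
```
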

Parameters:
  • fuzzifier (float, default=2.0) – Fuzzification parameter for membership (must be > 1.0).

  • eta (float, default=2.0) – Fuzzification parameter for typicality (must be > 1.0).

  • a (float, default=1.0) – Weight for fuzzy membership term (must be > 0).

  • b (float, default=1.0) – Weight for typicality term (must be > 0).

  • k (float, default=1.0) – Scaling parameter for \(\gamma\) (must be > 0).

  • sigma (float, default=1.0) – Kernel bandwidth parameter (must be > 0).

  • init_method ({'kfcm'}, default='kfcm') – Initialization method.

  • n_clusters (int) – Number of clusters (must be >= 2).

  • max_iter (int) – Maximum number of iterations.

  • epsilon (float) – Convergence threshold.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Pairwise distance function. Default: squared Euclidean.

  • patience (int, optional) – Epochs with no improvement before early stopping. Default: None.

  • restore_best (bool) – If True, restore centroids from the lowest-objective epoch. Default: False.

  • plot_steps (bool) – Whether to visualize clustering progress. Default: False.

  • show_confidence (bool) – Whether to show confidence in visualization. Default: True.

  • show_pca_variance (bool) – Whether to show PCA variance in visualization. Default: True.

  • save_plot_path (str, optional) – Path to save final plot.

  • callbacks (list, optional) – List of Callback objects for monitoring/visualization.

fit(X, initial_centroids=None, resume=False)[source]

Fit KPFCM model to data.

Parameters:

X (Array | ndarray | bool | number)

Return type:

Self

predict(X)[source]

Predict cluster labels for new data.

Parameters:

X (Array | ndarray | bool | number)

Return type:

Array | ndarray | bool | number

predict_proba(X)[source]

Predict fuzzy membership probabilities (U matrix).

Parameters:

X (Array | ndarray | bool | number)

Return type:

Array | ndarray | bool | number

get_typicality(X)[source]

Compute typicality values (T matrix).

Parameters:

X (Array | ndarray | bool | number)

Return type:

Array | ndarray | bool | number

class prosemble.models.KAFCM(n_clusters, fuzzifier=2.0, a=1.0, b=1.0, k=1.0, sigma=1.0, max_iter=100, epsilon=1e-05, init_method='kfcm', random_seed=42, distance_fn=None, patience=None, restore_best=False, plot_steps=False, show_confidence=True, show_pca_variance=True, save_plot_path=None, callbacks=None)[source]

Kernel Adaptive Fuzzy C-Means clustering with JAX.

KAFCM extends AFCM to kernel space, combining fuzzy and possibilistic approaches with kernel-based distance measures.

Kernel distance: \(d_K(x, v) = 2(1 - K(x, v))\)

Algorithm:

  1. Initialize \(U\) using KFCM

  2. Compute \(\gamma\) parameters using kernel distance

  3. Update \(T\) using exponential kernel update

  4. Update \(U\) using standard KFCM rule

  5. Update centroids (kernel-weighted with combined weights)

  6. Repeat until convergence
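
A minimal NumPy sketch of one KAFCM iteration following the steps above. It assumes the kernel analogue of AFCM: \(\gamma_j\) as `k` times the \(u^m\)-weighted mean kernel distance, \(t_{ij} = \exp(-b\, d_K(x_i, v_j) / \gamma_j)\), the KFCM membership update, and a kernel fixed-point centroid with weights \((a\, u_{ij}^m + b\, t_{ij}) K(x_i, v_j)\). These formulas are assumptions rather than prosemble's source, and `kafcm_step` is a hypothetical helper.

```python
import numpy as np


def kafcm_step(X, V, U, m=2.0, a=1.0, b=1.0, k=1.0, sigma=1.0, eps=1e-12):
    """One KAFCM iteration (hypothetical helper). Returns U, T, gamma, centroids."""
    K = np.exp(-((X[:, None, :] - V[None, :, :]) ** 2).sum(-1) / sigma ** 2)
    d_K = np.maximum(2.0 * (1.0 - K), eps)
    Um = U ** m
    # Step 2: gamma from the kernel distance (assumed form).
    gamma = k * (Um * d_K).sum(axis=0) / Um.sum(axis=0)
    # Step 3: exponential typicality update in kernel space (assumed form).
    T = np.exp(-b * d_K / gamma)
    # Step 4: KFCM membership update; each row sums to 1.
    inv = d_K ** (-1.0 / (m - 1.0))
    U = inv / inv.sum(axis=1, keepdims=True)
    # Step 5: kernel-weighted centroids with combined weights.
    W = (a * U ** m + b * T) * K
    V_new = (W.T @ X) / W.sum(axis=0)[:, None]
    return U, T, gamma, V_new


X = np.array([[1.0, 2.0], [1.5, 1.8], [5.0, 8.0], [8.0, 8.0]])
V = X[[0, 2]].copy()
U = np.full((4, 2), 0.5)  # uniform start instead of a KFCM initialization
for _ in range(20):
    U, T, gamma, V = kafcm_step(X, V, U, sigma=3.0)
```
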

Parameters:
  • fuzzifier (float, default=2.0) – Fuzziness parameter (must be > 1.0).

  • a (float, default=1.0) – Weight for fuzzy membership (must be > 0).

  • b (float, default=1.0) – Weight for typicality (must be > 0).

  • k (float, default=1.0) – Scaling parameter for \(\gamma\) (must be > 0).

  • sigma (float, default=1.0) – Kernel bandwidth parameter (must be > 0).

  • init_method ({'kfcm'}, default='kfcm') – Initialization method.

  • n_clusters (int) – Number of clusters (must be >= 2).

  • max_iter (int) – Maximum number of iterations.

  • epsilon (float) – Convergence threshold.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Pairwise distance function. Default: squared Euclidean.

  • patience (int, optional) – Epochs with no improvement before early stopping. Default: None.

  • restore_best (bool) – If True, restore centroids from the lowest-objective epoch. Default: False.

  • plot_steps (bool) – Whether to visualize clustering progress. Default: False.

  • show_confidence (bool) – Whether to show confidence in visualization. Default: True.

  • show_pca_variance (bool) – Whether to show PCA variance in visualization. Default: True.

  • save_plot_path (str, optional) – Path to save final plot.

  • callbacks (list, optional) – List of Callback objects for monitoring/visualization.

Examples

>>> import jax.numpy as jnp
>>> from prosemble.models import KAFCM
>>> X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8]])
>>> model = KAFCM(n_clusters=2, sigma=1.0, random_seed=42)
>>> model.fit(X)
>>> labels = model.predict(X)

fit(X, initial_centroids=None, resume=False)[source]

Fit the model to X.

Parameters:

X (Array | ndarray | bool | number)

Return type:

Self

predict(X)[source]

Predict cluster labels for new data.

Parameters:

X (Array | ndarray | bool | number)

Return type:

Array | ndarray | bool | number

predict_proba(X)[source]

Predict cluster membership probabilities for new data.

Parameters:

X (Array | ndarray | bool | number)

Return type:

Array | ndarray | bool | number

class prosemble.models.KIPCM(n_clusters, fuzzifier=2.0, tipifier=2.0, k=1.0, sigma=1.0, max_iter=100, epsilon=1e-05, init_method='kfcm', random_seed=42, distance_fn=None, patience=None, restore_best=False, plot_steps=False, show_confidence=True, show_pca_variance=True, save_plot_path=None, callbacks=None)[source]

Kernel Improved Possibilistic C-Means with JAX.

KIPCM uses a two-phase approach in kernel space with product-based centroids.

Parameters:
  • fuzzifier (float, default=2.0) – Fuzziness parameter for \(U\) matrix (must be > 1.0).

  • tipifier (float, default=2.0) – Possibilistic parameter for \(T\) matrix (must be > 1.0).

  • k (float, default=1.0) – Scaling parameter for \(\gamma\) in phase 1 (must be > 0).

  • sigma (float, default=1.0) – Kernel bandwidth parameter (must be > 0).

  • init_method ({'kfcm'}, default='kfcm') – Initialization method.

  • n_clusters (int) – Number of clusters (must be >= 2).

  • max_iter (int) – Maximum number of iterations.

  • epsilon (float) – Convergence threshold.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Pairwise distance function. Default: squared Euclidean.

  • patience (int, optional) – Epochs with no improvement before early stopping. Default: None.

  • restore_best (bool) – If True, restore centroids from the lowest-objective epoch. Default: False.

  • plot_steps (bool) – Whether to visualize clustering progress. Default: False.

  • show_confidence (bool) – Whether to show confidence in visualization. Default: True.

  • show_pca_variance (bool) – Whether to show PCA variance in visualization. Default: True.

  • save_plot_path (str, optional) – Path to save final plot.

  • callbacks (list, optional) – List of Callback objects for monitoring/visualization.

fit(X, initial_centroids=None, resume=False)[source]

Parameters:

X (Array | ndarray | bool | number)

Return type:

Self

predict(X)[source]

Predict cluster labels for new data.

Parameters:

X (Array | ndarray | bool | number)

Return type:

Array | ndarray | bool | number

predict_proba(X)[source]

Parameters:

X (Array | ndarray | bool | number)

Return type:

Array | ndarray | bool | number

class prosemble.models.KIPCM2(n_clusters, fuzzifier=2.0, tipifier=2.0, sigma=1.0, max_iter=100, epsilon=1e-05, init_method='kfcm', random_seed=42, distance_fn=None, patience=None, restore_best=False, plot_steps=False, show_confidence=True, show_pca_variance=True, save_plot_path=None, callbacks=None)[source]

Kernel Improved Possibilistic C-Means 2 with JAX.

KIPCM2 uses an exponential \(T\) update and a modified objective in kernel space.

Parameters:
  • fuzzifier (float, default=2.0) – Fuzziness parameter for \(U\) matrix (must be > 1.0).

  • tipifier (float, default=2.0) – Possibilistic parameter for \(T\) matrix (must be > 1.0).

  • sigma (float, default=1.0) – Kernel bandwidth parameter (must be > 0).

  • init_method ({'kfcm'}, default='kfcm') – Initialization method.

  • n_clusters (int) – Number of clusters (must be >= 2).

  • max_iter (int) – Maximum number of iterations.

  • epsilon (float) – Convergence threshold.

  • random_seed (int) – Random seed for reproducibility.

  • distance_fn (callable, optional) – Pairwise distance function. Default: squared Euclidean.

  • patience (int, optional) – Epochs with no improvement before early stopping. Default: None.

  • restore_best (bool) – If True, restore centroids from the lowest-objective epoch. Default: False.

  • plot_steps (bool) – Whether to visualize clustering progress. Default: False.

  • show_confidence (bool) – Whether to show confidence in visualization. Default: True.

  • show_pca_variance (bool) – Whether to show PCA variance in visualization. Default: True.

  • save_plot_path (str, optional) – Path to save final plot.

  • callbacks (list, optional) – List of Callback objects for monitoring/visualization.

fit(X, initial_centroids=None, resume=False)[source]

Parameters:

X (Array | ndarray | bool | number)

Return type:

Self

predict(X)[source]

Predict cluster labels for new data.

Parameters:

X (Array | ndarray | bool | number)

Return type:

Array | ndarray | bool | number

predict_proba(X)[source]

Parameters:

X (Array | ndarray | bool | number)

Return type:

Array | ndarray | bool | number

One-Class Differentiating Kernel

class prosemble.models.OCDKGLVQ(sigma_init='median', sigma_min=0.001, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

One-Class Differentiating Kernel GLVQ.

Combines OC-GLVQ with Gaussian kernel distance and learnable per-prototype bandwidths \(\sigma_k\).

\[d_\kappa^2(x, w_k) = 2\left(1 - \exp\left( -\frac{\|x - w_k\|^2}{2\sigma_k^2} \right)\right)\]

Kernel distances are bounded in \([0, 2]\), so the thresholds \(\theta_k\) are initialized on the kernel distance scale.
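The per-prototype kernel distance above can be computed for a batch of samples as shown below. This NumPy sketch only evaluates the formula. dk_distance is a hypothetical helper name, not part of prosemble, and the prototypes and bandwidths are made-up values.

```python
import numpy as np

def dk_distance(X, W, sigmas):
    """d^2(x, w_k) = 2 * (1 - exp(-||x - w_k||^2 / (2 * sigma_k^2))).

    X: (n_samples, n_features), W: (n_prototypes, n_features),
    sigmas: (n_prototypes,), one bandwidth per prototype.
    """
    sq = ((X[:, None, :] - W[None, :, :]) ** 2).sum(axis=-1)   # (n, k)
    return 2.0 * (1.0 - np.exp(-sq / (2.0 * sigmas[None, :] ** 2)))

W = np.array([[0.0, 0.0], [3.0, 3.0]])       # hypothetical prototypes
sigmas = np.array([0.5, 2.0])                # narrow vs. wide bandwidth
X = np.array([[0.0, 0.0], [1.0, 0.0], [50.0, 50.0]])
D = dk_distance(X, W, sigmas)
print(D.round(3))
```

Every entry lies in [0, 2]. A sample sitting on a prototype has distance 0, and a very distant sample saturates near 2 no matter how far away it is. A larger \(\sigma_k\) makes prototype k's distance grow more slowly.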

Parameters:
  • sigma_init (str or float) – Initialization strategy for per-prototype bandwidths. ‘median’ (default): per-prototype median distance from prototype to target class members. ‘mean’: per-prototype mean distance. float: fixed value for all prototypes.

  • sigma_min (float) – Lower bound for sigma to prevent bandwidth collapse. Default: 1e-3.

  • n_prototypes (int) – Number of prototypes for the target class.

  • target_label (int, optional) – Target (normal) class label. Default: auto-detect.

  • beta (float) – Sigmoid steepness. Default: 10.0.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training.

  • batch_size (int, optional) – Mini-batch size. If None, use full-batch training.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments for the learning rate scheduler.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes.

  • patience (int, optional) – Epochs with no improvement before stopping.

  • restore_best (bool) – If True, restore best parameters after training.

  • class_weight (dict or 'balanced', optional) – Weights for each class.

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps.

  • ema_decay (float, optional) – Exponential moving average decay for parameters.

  • freeze_params (list of str, optional) – Parameter group names to freeze.

  • lookahead (dict, optional) – Lookahead optimizer wrapper configuration.

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training.

thetas_

Learned per-prototype visibility thresholds in kernel distance scale.

Type:

array of shape (n_prototypes,)

sigmas_

Learned per-prototype kernel bandwidths.

Type:

array of shape (n_prototypes,)
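The 'median' and 'mean' options for sigma_init can be read as in the sketch below. It is an illustration under stated assumptions. It uses plain (non-squared) Euclidean distances between each prototype and the target-class samples, and it applies sigma_min as a lower clamp. The library may compute these distances differently. init_sigmas is a hypothetical helper name.

```python
import numpy as np

def init_sigmas(X_target, W, strategy="median", sigma_min=1e-3):
    if isinstance(strategy, str):
        # ASSUMPTION: Euclidean distance from every prototype to every target sample.
        dists = np.linalg.norm(X_target[:, None, :] - W[None, :, :], axis=-1)  # (n, k)
        if strategy == "median":
            sigmas = np.median(dists, axis=0)
        elif strategy == "mean":
            sigmas = dists.mean(axis=0)
        else:
            raise ValueError(f"unknown sigma_init: {strategy!r}")
    else:
        sigmas = np.full(W.shape[0], float(strategy))  # fixed value for all prototypes
    return np.maximum(sigmas, sigma_min)               # guard against bandwidth collapse

X_target = np.array([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]])
W = np.array([[0.0, 0.0], [3.0, 4.0]])
print(init_sigmas(X_target, W, "median"))
print(init_sigmas(X_target, W, "mean"))
```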

See also

OCGLVQ

Base class with Euclidean distance.

DKGLVQ

Supervised variant with kernel distance.

decision_function(X)[source]

Compute target-likeness scores using kernel distance.

Scores near 1.0 indicate target class, near 0.0 indicate outlier. The decision boundary is at score = 0.5 (where \(d_\kappa = \theta\)).

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

scores

Return type:

array of shape (n_samples,)
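The following sketch shows one score with the properties described above: near 1 for the target class, near 0 for outliers, and exactly 0.5 where the distance equals the threshold. This is an assumption, not the library's formula. It takes the nearest prototype, forms a relative margin \(\mu = (d - \theta)/(d + \theta)\) analogous to the GLVQ \(\mu\), and passes \(-\beta\mu\) through a sigmoid. one_class_scores is a hypothetical name.

```python
import numpy as np

def one_class_scores(D, thetas, beta=10.0):
    """D: (n_samples, n_prototypes) kernel distances; thetas: (n_prototypes,)."""
    k = np.argmin(D, axis=1)                      # nearest prototype per sample
    d = D[np.arange(D.shape[0]), k]
    theta = thetas[k]
    mu = (d - theta) / (d + theta)                # < 0 inside, > 0 outside the threshold
    return 1.0 / (1.0 + np.exp(beta * mu))        # sigmoid(-beta * mu)

D = np.array([[0.5, 1.8], [0.0, 1.9], [2.0, 1.9]])
thetas = np.array([0.5, 0.4])
scores = one_class_scores(D, thetas)
print(scores.round(4))
```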

property kernel_bandwidths

Return the learned per-prototype bandwidths.

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number that n_prototypes_per_class produces.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores the parameters with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict(X)

Predict target or non-target labels.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)

Predict with a reject option for uncertain samples.

Samples with scores in [lower, upper) are rejected.

Parameters:
  • X (array-like of shape (n_samples, n_features))

  • upper (float) – Scores >= upper are classified as target. Default: 0.5.

  • lower (float, optional) – Scores < lower are classified as non-target. Default: same as upper (no rejection zone).

  • reject_label (int) – Label for rejected samples. Default: -1.

Returns:

labels

Return type:

array of shape (n_samples,)
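Given scores from decision_function, the thresholding rule described above can be sketched as follows. The target_label and non_target_label arguments are hypothetical placeholders for whatever labels the fitted model emits. Only upper, lower and reject_label come from the documented signature.

```python
import numpy as np

def predict_with_reject_sketch(scores, upper=0.5, lower=None,
                               target_label=1, non_target_label=0, reject_label=-1):
    if lower is None:
        lower = upper                      # [upper, upper) is empty: no rejection zone
    if lower > upper:
        raise ValueError("lower must not exceed upper")
    labels = np.full(scores.shape, reject_label)
    labels[scores >= upper] = target_label
    labels[scores < lower] = non_target_label
    return labels                          # scores in [lower, upper) keep reject_label

scores = np.array([0.9, 0.55, 0.45, 0.1])
print(predict_with_reject_sketch(scores))
print(predict_with_reject_sketch(scores, upper=0.6, lower=0.3))
```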

class prosemble.models.OCDKGRLVQ(sigma_init='median', sigma_min=0.001, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

One-Class Differentiating Kernel GRLVQ.

Combines OC-GLVQ with per-feature relevance weighting and a Gaussian kernel distance whose bandwidth is adapted per prototype.

\[d_\kappa^2(x, w_k) = 2\left(1 - \exp\left( -\frac{\sum_j \lambda_j (x_j - w_{kj})^2}{2\sigma_k^2} \right)\right)\]

where \(\lambda = \text{softmax}(\text{relevances})\) are learned per-feature weights.
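The relevance-weighted distance above differs from the OCDKGLVQ distance only in the softmax-normalized feature weights. The sketch below assumes that the raw relevance logits and per-prototype bandwidths are given as arrays. softmax and relevance_kernel_distance are hypothetical helpers written for illustration.

```python
import numpy as np

def softmax(z):
    z = z - z.max()                        # shift for numerical stability
    e = np.exp(z)
    return e / e.sum()

def relevance_kernel_distance(X, W, sigmas, relevance_logits):
    lam = softmax(relevance_logits)        # (n_features,), sums to 1
    sq = (lam * (X[:, None, :] - W[None, :, :]) ** 2).sum(axis=-1)   # (n, k)
    return 2.0 * (1.0 - np.exp(-sq / (2.0 * sigmas[None, :] ** 2)))

X = np.array([[0.0, 0.0], [0.0, 10.0], [10.0, 0.0]])
W = np.array([[0.0, 0.0]])
sigmas = np.array([1.0])
d = relevance_kernel_distance(X, W, sigmas, np.array([5.0, -5.0]))
print(d.round(4))
```

Here the second feature has almost no relevance, so a sample that differs from the prototype only in that feature is treated as nearly identical to it.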

Parameters:
  • sigma_init (str or float) – Initialization strategy for per-prototype bandwidths. ‘median’ (default): per-prototype median distance from prototype to target class members. ‘mean’: per-prototype mean distance. float: fixed value for all prototypes.

  • sigma_min (float) – Lower bound for sigma to prevent bandwidth collapse. Default: 1e-3.

  • n_prototypes (int) – Number of prototypes for the target class.

  • target_label (int, optional) – Target (normal) class label. Default: auto-detect.

  • beta (float) – Sigmoid steepness. Default: 10.0.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training.

  • batch_size (int, optional) – Mini-batch size. If None, use full-batch training.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments for the learning rate scheduler.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes.

  • patience (int, optional) – Epochs with no improvement before stopping.

  • restore_best (bool) – If True, restore best parameters after training.

  • class_weight (dict or 'balanced', optional) – Weights for each class.

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps.

  • ema_decay (float, optional) – Exponential moving average decay for parameters.

  • freeze_params (list of str, optional) – Parameter group names to freeze.

  • lookahead (dict, optional) – Lookahead optimizer wrapper configuration.

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training.

thetas_

Learned per-prototype visibility thresholds in kernel distance scale.

Type:

array of shape (n_prototypes,)

sigmas_

Learned per-prototype kernel bandwidths.

Type:

array of shape (n_prototypes,)

relevances_

Learned per-feature relevance weights (raw logits).

Type:

array of shape (n_features,)

See also

OCGLVQ

Base class with Euclidean distance.

DKGRLVQ

Supervised variant with kernel distance and relevances.

decision_function(X)[source]

Compute target-likeness scores using relevance-weighted kernel distance.

Scores near 1.0 indicate target class, near 0.0 indicate outlier.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

scores

Return type:

array of shape (n_samples,)

property kernel_bandwidths

Return the learned per-prototype bandwidths.

property relevance_profile

Return the learned relevance weights (normalized via softmax).

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number that n_prototypes_per_class produces.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores the parameters with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict(X)

Predict target or non-target labels.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)

Predict with a reject option for uncertain samples.

Samples with scores in [lower, upper) are rejected.

Parameters:
  • X (array-like of shape (n_samples, n_features))

  • upper (float) – Scores >= upper are classified as target. Default: 0.5.

  • lower (float, optional) – Scores < lower are classified as non-target. Default: same as upper (no rejection zone).

  • reject_label (int) – Label for rejected samples. Default: -1.

Returns:

labels

Return type:

array of shape (n_samples,)

class prosemble.models.OCDKGMLVQ(latent_dim=None, omega_hat_scale=0.1, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

One-Class Differentiating Kernel GMLVQ.

Combines OC-GLVQ with an exponential kernel distance and a learnable global transformation matrix \(\hat\Omega\) of shape \(d \times\) latent_dim.

\[d_\kappa^2(x, w) = \exp(x^T \hat\Lambda x) + \exp(w^T \hat\Lambda w) - 2 \exp(x^T \hat\Lambda w)\]

where \(\hat\Lambda = \hat\Omega \hat\Omega^T\).

Note: \(\kappa(v, v) \neq 1\) for the exponential kernel, so distances are not bounded in \([0, 2]\).
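The exponential-kernel distance can be evaluated directly from \(\hat\Omega\). The sketch below is a plain transcription of the formula above. exp_kernel_distance is a hypothetical name. It also shows that \(\kappa(x, x) = \exp(x^T \hat\Lambda x)\) is generally not 1. Every term is the exponential of a quadratic form in \(\hat\Omega\), so values grow quickly with the scale of \(\hat\Omega\), which is consistent with the small default omega_hat_scale.

```python
import numpy as np

def exp_kernel_distance(X, W, omega_hat):
    """d^2(x, w) = exp(x^T L x) + exp(w^T L w) - 2 exp(x^T L w), L = omega_hat @ omega_hat.T."""
    lam = omega_hat @ omega_hat.T                         # (d, d), positive semidefinite
    kxx = np.exp(np.einsum("nd,de,ne->n", X, lam, X))      # kappa(x, x) per sample
    kww = np.exp(np.einsum("kd,de,ke->k", W, lam, W))      # kappa(w, w) per prototype
    kxw = np.exp(X @ lam @ W.T)                            # kappa(x, w), shape (n, k)
    return kxx[:, None] + kww[None, :] - 2.0 * kxw

rng = np.random.default_rng(0)
omega_hat = 0.1 * rng.normal(size=(3, 2))    # 3 features, latent_dim = 2
X = rng.normal(size=(5, 3))
W = X[:2].copy()                             # prototypes placed on two samples
D = exp_kernel_distance(X, W, omega_hat)
print(D.round(4))
```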

Parameters:
  • latent_dim (int, optional) – Dimensionality of the transformation. If None, uses input dim.

  • omega_hat_scale (float) – Scale factor for omega_hat initialization. Default: 0.1. Smaller values prevent exp overflow at initialization.

  • n_prototypes (int) – Number of prototypes for the target class.

  • target_label (int, optional) – Target (normal) class label. Default: auto-detect.

  • beta (float) – Sigmoid steepness. Default: 10.0.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training.

  • batch_size (int, optional) – Mini-batch size. If None, use full-batch training.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments for the learning rate scheduler.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes.

  • patience (int, optional) – Epochs with no improvement before stopping.

  • restore_best (bool) – If True, restore best parameters after training.

  • class_weight (dict or 'balanced', optional) – Weights for each class.

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps.

  • ema_decay (float, optional) – Exponential moving average decay for parameters.

  • freeze_params (list of str, optional) – Parameter group names to freeze.

  • lookahead (dict, optional) – Lookahead optimizer wrapper configuration.

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training.

thetas_

Learned per-prototype visibility thresholds in kernel distance scale.

Type:

array of shape (n_prototypes,)

omega_hat_

Learned transformation matrix.

Type:

array of shape (n_features, latent_dim)

See also

OCGLVQ

Base class with Euclidean distance.

DKGMLVQ

Supervised variant with exponential kernel distance.

decision_function(X)[source]

Compute target-likeness scores using exponential kernel distance.

Scores near 1.0 indicate target class, near 0.0 indicate outlier.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

scores

Return type:

array of shape (n_samples,)

property omega_hat_matrix

Return the learned \(\hat\Omega\) matrix.

property lambda_hat_matrix

Return \(\hat\Lambda = \hat\Omega \hat\Omega^T\).

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number that n_prototypes_per_class produces.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores the parameters with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

predict(X)

Predict target or non-target labels.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)

Predict with a reject option for uncertain samples.

Samples with scores in [lower, upper) are rejected.

Parameters:
  • X (array-like of shape (n_samples, n_features))

  • upper (float) – Scores >= upper are classified as target. Default: 0.5.

  • lower (float, optional) – Scores < lower are classified as non-target. Default: same as upper (no rejection zone).

  • reject_label (int) – Label for rejected samples. Default: -1.

Returns:

labels

Return type:

array of shape (n_samples,)

One-Class Differentiating Kernel (Neural Gas)

class prosemble.models.OCDKGLVQ_NG(gamma_init=None, gamma_final=0.01, gamma_decay=None, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)[source]

One-Class Differentiating Kernel GLVQ with Neural Gas cooperation.

Combines OC-DKGLVQ (Gaussian kernel distance with learnable bandwidths) with a Neural Gas (NG) rank-weighted loss in which all prototypes participate.

\[d_\kappa^2(x, w_k) = 2\left(1 - \exp\left( -\frac{\|x - w_k\|^2}{2\sigma_k^2} \right)\right)\]

Parameters:
  • sigma_init (str or float) – Initialization strategy for per-prototype bandwidths. ‘median’ (default): per-prototype median distance from prototype to target class members. ‘mean’: per-prototype mean distance. float: fixed value for all prototypes.

  • sigma_min (float) – Lower bound for sigma to prevent bandwidth collapse. Default: 1e-3.

  • gamma_init (float, optional) – Initial neighborhood range. Default: n_prototypes / 2.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step multiplicative decay for gamma. Default: computed from max_iter so gamma reaches gamma_final.

  • n_prototypes (int) – Number of prototypes for the target class.

  • target_label (int, optional) – Target (normal) class label. Default: auto-detect.

  • beta (float) – Sigmoid steepness. Default: 10.0.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training.

  • batch_size (int, optional) – Mini-batch size. If None, use full-batch training.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments for the learning rate scheduler.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes.

  • patience (int, optional) – Epochs with no improvement before stopping.

  • restore_best (bool) – If True, restore best parameters after training.

  • class_weight (dict or 'balanced', optional) – Weights for each class.

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps.

  • ema_decay (float, optional) – Exponential moving average decay for parameters.

  • freeze_params (list of str, optional) – Parameter group names to freeze.

  • lookahead (dict, optional) – Lookahead optimizer wrapper configuration.

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training.

thetas_

Learned per-prototype visibility thresholds in kernel distance scale.

Type:

array of shape (n_prototypes,)

sigmas_

Learned per-prototype kernel bandwidths.

Type:

array of shape (n_prototypes,)

gamma_

Final gamma value after training.

Type:

float
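The Neural Gas cooperation and the default gamma schedule can be sketched as below. This sketch makes two assumptions. First, the neighborhood weights follow the classic Neural Gas form \(h_k = \exp(-r_k / \gamma)\), where \(r_k\) is the rank of prototype k by distance (0 for the closest), normalized per sample. Second, the default gamma_decay is the constant factor that moves gamma from gamma_init to gamma_final in max_iter multiplicative steps. ng_rank_weights and gamma_schedule are hypothetical helper names, not prosemble functions.

```python
import numpy as np

def ng_rank_weights(D, gamma):
    # Rank prototypes per sample by distance: 0 = closest.
    ranks = np.argsort(np.argsort(D, axis=1), axis=1)
    h = np.exp(-ranks / gamma)                   # classic Neural Gas neighborhood
    return h / h.sum(axis=1, keepdims=True)      # normalize per sample

def gamma_schedule(n_prototypes, max_iter, gamma_init=None, gamma_final=0.01):
    if gamma_init is None:
        gamma_init = n_prototypes / 2.0          # documented default
    # ASSUMPTION: constant factor so that gamma_init * decay**max_iter == gamma_final.
    decay = (gamma_final / gamma_init) ** (1.0 / max_iter)
    return gamma_init * decay ** np.arange(max_iter + 1), decay

D = np.array([[0.2, 1.5, 0.9]])                  # kernel distances to 3 prototypes
gammas, decay = gamma_schedule(n_prototypes=3, max_iter=100)
early = ng_rank_weights(D, gammas[0])            # all prototypes share the update
late = ng_rank_weights(D, gammas[-1])            # close to winner-take-all
# A rank-weighted loss would combine per-prototype losses as (weights * losses).sum(axis=1).
```

Early in training, every prototype receives a share of each sample's loss. As gamma shrinks, the weighting approaches the winner-take-all behaviour of OCDKGLVQ.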

See also

OCDKGLVQ

Base class without NG cooperation.

OCGLVQ_NG

NG variant with Euclidean distance.

decision_function(X)

Compute target-likeness scores using kernel distance.

Scores near 1.0 indicate target class, near 0.0 indicate outlier. The decision boundary is at score = 0.5 (where \(d_\kappa = \theta\)).

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

scores

Return type:

array of shape (n_samples,)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number that n_prototypes_per_class produces.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores the parameters with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

property kernel_bandwidths

Return the learned per-prototype bandwidths.

predict(X)

Predict target or non-target labels.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)

Predict with a reject option for uncertain samples.

Samples with scores in [lower, upper) are rejected.

Parameters:
  • X (array-like of shape (n_samples, n_features))

  • upper (float) – Scores >= upper are classified as target. Default: 0.5.

  • lower (float, optional) – Scores < lower are classified as non-target. Default: same as upper (no rejection zone).

  • reject_label (int) – Label for rejected samples. Default: -1.

Returns:

labels

Return type:

array of shape (n_samples,)

class prosemble.models.OCDKGRLVQ_NG(gamma_init=None, gamma_final=0.01, gamma_decay=None, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)[source]

One-Class Differentiating Kernel GRLVQ with Neural Gas cooperation.

Combines OC-DKGRLVQ (per-feature relevance weighting plus Gaussian kernel distance) with a Neural Gas (NG) rank-weighted loss.

\[d_\kappa^2(x, w_k) = 2\left(1 - \exp\left( -\frac{\sum_j \lambda_j (x_j - w_{kj})^2}{2\sigma_k^2} \right)\right)\]

where \(\lambda = \text{softmax}(\text{relevances})\).

Parameters:
  • sigma_init (str or float) – Initialization strategy for per-prototype bandwidths. ‘median’ (default): per-prototype median distance from prototype to target class members. ‘mean’: per-prototype mean distance. float: fixed value for all prototypes.

  • sigma_min (float) – Lower bound for sigma to prevent bandwidth collapse. Default: 1e-3.

  • gamma_init (float, optional) – Initial neighborhood range. Default: n_prototypes / 2.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step multiplicative decay for gamma. Default: computed from max_iter so gamma reaches gamma_final.

  • n_prototypes (int) – Number of prototypes for the target class.

  • target_label (int, optional) – Target (normal) class label. Default: auto-detect.

  • beta (float) – Sigmoid steepness. Default: 10.0.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training.

  • batch_size (int, optional) – Mini-batch size. If None, use full-batch training.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments for the learning rate scheduler.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes.

  • patience (int, optional) – Epochs with no improvement before stopping.

  • restore_best (bool) – If True, restore best parameters after training.

  • class_weight (dict or 'balanced', optional) – Weights for each class.

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps.

  • ema_decay (float, optional) – Exponential moving average decay for parameters.

  • freeze_params (list of str, optional) – Parameter group names to freeze.

  • lookahead (dict, optional) – Lookahead optimizer wrapper configuration.

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training.

thetas_

Learned per-prototype visibility thresholds in kernel distance scale.

Type:

array of shape (n_prototypes,)

sigmas_

Learned per-prototype kernel bandwidths.

Type:

array of shape (n_prototypes,)

relevances_

Learned per-feature relevance weights (raw logits).

Type:

array of shape (n_features,)

gamma_

Final gamma value after training.

Type:

float

See also

OCDKGRLVQ

Base class without NG cooperation.

OCGRLVQ_NG

NG variant with Euclidean distance and relevances.

decision_function(X)

Compute target-likeness scores using relevance-weighted kernel distance.

Scores near 1.0 indicate the target class; scores near 0.0 indicate an outlier.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

scores

Return type:

array of shape (n_samples,)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number of prototypes the model would otherwise create.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When combined with restore_best=True, the model restores the parameters that achieved the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer state. Cannot be combined with initial_prototypes.

Return type:

self

property kernel_bandwidths

Return the learned per-prototype bandwidths.

predict(X)

Predict target or non-target labels.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)

Predict with a reject option for uncertain samples.

Samples with scores in [lower, upper) are rejected.

Parameters:
  • X (array-like of shape (n_samples, n_features))

  • upper (float) – Scores >= upper are classified as target. Default: 0.5.

  • lower (float, optional) – Scores < lower are classified as non-target. Default: same as upper (no rejection zone).

  • reject_label (int) – Label for rejected samples. Default: -1.

Returns:

labels

Return type:

array of shape (n_samples,)
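The threshold rule described for predict_with_reject can be sketched in NumPy as follows. This is an illustrative re-implementation, not the library's code: the helper name reject_rule and the target/non-target label values 1 and 0 are assumptions, since the labels returned by predict are not specified here.

```python
import numpy as np


def reject_rule(scores, upper=0.5, lower=None, target=1, non_target=0, reject_label=-1):
    """Hypothetical helper: map target-likeness scores to labels with a reject zone."""
    scores = np.asarray(scores, dtype=float)
    if lower is None:
        lower = upper  # lower defaults to upper, so [lower, upper) is empty
    labels = np.full(scores.shape, reject_label, dtype=int)
    labels[scores >= upper] = target      # confident target samples
    labels[scores < lower] = non_target   # confident outliers
    return labels                         # scores in [lower, upper) keep reject_label


print(reject_rule([0.9, 0.55, 0.4, 0.1], upper=0.6, lower=0.3).tolist())  # [1, -1, -1, 0]
print(reject_rule([0.9, 0.55, 0.4, 0.1]).tolist())                        # [1, 1, 0, 0]
```

Widening the gap between lower and upper trades coverage for confidence: more samples are rejected, but the ones that are labeled sit farther from the decision threshold.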

property relevance_profile

Return the learned relevance weights (normalized via softmax).

class prosemble.models.OCDKGMLVQ_NG(gamma_init=None, gamma_final=0.01, gamma_decay=None, n_prototypes=3, target_label=None, beta=10.0, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)

One-Class Differentiating Kernel GMLVQ with Neural Gas cooperation.

Combines OC-DKGMLVQ (exponential kernel distance with learnable \(\hat\Omega\) matrix) with NG rank-weighted loss.

\[d_\kappa^2(x, w) = \exp(x^T \hat\Lambda x) + \exp(w^T \hat\Lambda w) - 2 \exp(x^T \hat\Lambda w)\]

where \(\hat\Lambda = \hat\Omega \hat\Omega^T\).

Parameters:
  • latent_dim (int, optional) – Dimensionality of the transformation. If None, uses input dim.

  • omega_hat_scale (float) – Scale factor for omega_hat initialization. Default: 0.1.

  • gamma_init (float, optional) – Initial neighborhood range. Default: n_prototypes / 2.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step multiplicative decay for gamma. Default: computed from max_iter so gamma reaches gamma_final.

  • n_prototypes (int) – Number of prototypes for the target class.

  • target_label (int, optional) – Target (normal) class label. Default: auto-detect.

  • beta (float) – Sigmoid steepness. Default: 10.0.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training.

  • batch_size (int, optional) – Mini-batch size. If None, use full-batch training.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments for the learning rate scheduler.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes.

  • patience (int, optional) – Epochs with no improvement before stopping.

  • restore_best (bool) – If True, restore best parameters after training.

  • class_weight (dict or 'balanced', optional) – Weights for each class.

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps.

  • ema_decay (float, optional) – Exponential moving average decay for parameters.

  • freeze_params (list of str, optional) – Parameter group names to freeze.

  • lookahead (dict, optional) – Lookahead optimizer wrapper configuration.

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training.
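The gamma defaults above can be illustrated with a small schedule sketch. It assumes the common geometric form gamma_t = gamma_init * decay**t with one decay step per iteration, so that choosing decay = (gamma_final / gamma_init) ** (1 / max_iter) lands exactly on gamma_final after max_iter steps. The library's exact formula, and whether a "step" is an epoch or a mini-batch update, are not stated here; ng_gamma_schedule is a hypothetical name.

```python
def ng_gamma_schedule(n_prototypes=3, gamma_final=0.01, max_iter=100, gamma_init=None):
    """Hypothetical sketch of a geometric Neural Gas neighborhood schedule."""
    if gamma_init is None:
        gamma_init = n_prototypes / 2  # documented default for gamma_init
    # Multiplicative factor chosen so that gamma_init * decay**max_iter == gamma_final.
    decay = (gamma_final / gamma_init) ** (1.0 / max_iter)
    gammas = [gamma_init * decay**t for t in range(max_iter + 1)]
    return decay, gammas


decay, gammas = ng_gamma_schedule()
print(gammas[0], round(gammas[-1], 6))  # 1.5 0.01
```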

thetas_

Learned per-prototype visibility thresholds on the kernel distance scale.

Type:

array of shape (n_prototypes,)

omega_hat_

Learned transformation matrix.

Type:

array of shape (n_features, latent_dim)

gamma_

Final gamma value after training.

Type:

float

See also

OCDKGMLVQ

Base class without NG cooperation.

OCGMLVQ_NG

NG variant with Euclidean distance and Omega projection.

decision_function(X)

Compute target-likeness scores using exponential kernel distance.

Scores near 1.0 indicate the target class; scores near 0.0 indicate an outlier.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

scores

Return type:

array of shape (n_samples,)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number of prototypes the model would otherwise create.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When combined with restore_best=True, the model restores the parameters that achieved the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer state. Cannot be combined with initial_prototypes.

Return type:

self

property lambda_hat_matrix

Return \(\hat\Lambda = \hat\Omega \hat\Omega^T\).

property omega_hat_matrix

Return the learned \(\hat\Omega\) matrix.

predict(X)

Predict target or non-target labels.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

labels

Return type:

array of shape (n_samples,)

predict_with_reject(X, upper=0.5, lower=None, reject_label=-1)

Predict with a reject option for uncertain samples.

Samples with scores in [lower, upper) are rejected.

Parameters:
  • X (array-like of shape (n_samples, n_features))

  • upper (float) – Scores >= upper are classified as target. Default: 0.5.

  • lower (float, optional) – Scores < lower are classified as non-target. Default: same as upper (no rejection zone).

  • reject_label (int) – Label for rejected samples. Default: -1.

Returns:

labels

Return type:

array of shape (n_samples,)

Differentiating Kernel Models

Supervised

class prosemble.models.DKGLVQ(sigma_init='median', sigma_min=0.001, beta=10.0, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)

Differentiating Kernel GLVQ.

GLVQ with Gaussian kernel distance in feature space. Each prototype \(w_k\) has a learnable bandwidth \(\sigma_k\) that is adapted via gradient descent alongside the prototypes.

\[d_\kappa^2(x, w_k) = 2\left(1 - \exp\left( -\frac{\|x - w_k\|^2}{2\sigma_k^2} \right)\right)\]
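A NumPy sketch of this distance for a batch of samples and prototypes follows. Clipping sigma at sigma_min mirrors the documented lower bound, but how the library actually enforces that bound (clipping, reparameterization, or otherwise) is an assumption, as is the function name.

```python
import numpy as np


def gaussian_kernel_distance(X, W, sigmas, sigma_min=1e-3):
    """Squared Gaussian-kernel distances, shape (n_samples, n_prototypes)."""
    X = np.asarray(X, dtype=float)
    W = np.asarray(W, dtype=float)
    # Assumed enforcement of the sigma_min lower bound by clipping.
    sigmas = np.maximum(np.asarray(sigmas, dtype=float), sigma_min)
    # ||x - w_k||^2 for every sample/prototype pair via broadcasting.
    sq = ((X[:, None, :] - W[None, :, :]) ** 2).sum(axis=-1)
    # Column k uses its own bandwidth sigma_k.
    return 2.0 * (1.0 - np.exp(-sq / (2.0 * sigmas[None, :] ** 2)))


X = np.array([[0.0, 0.0], [3.0, 4.0]])
W = np.array([[0.0, 0.0], [3.0, 4.0]])
D = gaussian_kernel_distance(X, W, sigmas=[1.0, 5.0])
print(D.round(4))
```

Because the Gaussian kernel satisfies kappa(v, v) = 1, the distance is bounded. It is 0 when x = w_k and approaches 2 as ||x - w_k|| grows relative to sigma_k, so a larger bandwidth makes a prototype "see" farther.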
Parameters:
  • sigma_init (str or float) – Initialization strategy for per-prototype bandwidths. ‘median’ (default): per-class median distance from prototype to class members. ‘mean’: per-class mean distance. float: fixed value for all prototypes.

  • sigma_min (float) – Lower bound for sigma to prevent bandwidth collapse. Default: 1e-3.

  • beta (float) – Transfer function steepness parameter.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training.

  • batch_size (int, optional) – Mini-batch size. If None, use full-batch training.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments for the learning rate scheduler.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes.

  • patience (int, optional) – Epochs with no improvement before stopping.

  • restore_best (bool) – If True, restore best parameters after training.

  • class_weight (dict or 'balanced', optional) – Weights for each class.

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps.

  • ema_decay (float, optional) – Exponential moving average decay for parameters.

  • freeze_params (list of str, optional) – Parameter group names to freeze.

  • lookahead (dict, optional) – Lookahead optimizer wrapper configuration.

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training.
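The sigma_init strategies can be sketched as below. The sketch assumes plain Euclidean (not squared) distances and that each prototype is compared only with the training samples of its own class; both are assumptions about the library's exact implementation, and init_sigmas is a hypothetical name.

```python
import numpy as np


def init_sigmas(X, y, W, w_labels, strategy="median"):
    """Hypothetical per-prototype bandwidth initialization."""
    X, y = np.asarray(X, dtype=float), np.asarray(y)
    W, w_labels = np.asarray(W, dtype=float), np.asarray(w_labels)
    if isinstance(strategy, (int, float)):
        return np.full(len(W), float(strategy))  # fixed value for all prototypes
    reduce = {"median": np.median, "mean": np.mean}[strategy]
    sigmas = np.empty(len(W))
    for k, (w, c) in enumerate(zip(W, w_labels)):
        members = X[y == c]                          # samples of the prototype's class
        dists = np.linalg.norm(members - w, axis=1)  # Euclidean distances to w_k
        sigmas[k] = reduce(dists)
    return sigmas


X = np.array([[0, 0], [0, 1], [0, 3], [10, 0], [10, 2]])
y = np.array([0, 0, 0, 1, 1])
W = np.array([[0.0, 0.0], [10.0, 0.0]])
print(init_sigmas(X, y, W, [0, 1]).tolist())          # [1.0, 1.0]
print(init_sigmas(X, y, W, [0, 1], "mean").tolist())  # [1.3333333333333333, 1.0]
```

If every class member coincided with its prototype, the median would be 0; a floor such as sigma_min is what keeps the bandwidth usable in that degenerate case.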

See also

SupervisedPrototypeModel

Full list of base parameters.

property kernel_bandwidths

Return the learned per-prototype bandwidths.

predict(X)

Predict using learned kernel distance.

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number that n_prototypes_per_class would produce.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When combined with restore_best=True, the model restores the parameters that achieved the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer state. Cannot be combined with initial_prototypes.

Return type:

self

predict_proba(X)

Predict class probabilities via softmin of distances.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

proba

Return type:

array of shape (n_samples, n_classes)
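One plausible reading of "softmin of distances" is sketched below: take each class's closest-prototype distance and apply a softmax to the negated distances. The per-class minimum and the unit temperature are assumptions (summing per-prototype softmin weights by class is another common choice), and softmin_proba is a hypothetical helper, not part of prosemble.

```python
import numpy as np


def softmin_proba(D, proto_labels, classes):
    """Hypothetical class probabilities from a (n_samples, n_prototypes) distance matrix."""
    D = np.asarray(D, dtype=float)
    proto_labels = np.asarray(proto_labels)
    # Distance to the closest prototype of each class (assumed aggregation).
    per_class = np.stack([D[:, proto_labels == c].min(axis=1) for c in classes], axis=1)
    # softmin(d) = softmax(-d); subtracting the row minimum avoids overflow.
    logits = -(per_class - per_class.min(axis=1, keepdims=True))
    expd = np.exp(logits)
    return expd / expd.sum(axis=1, keepdims=True)


D = np.array([[0.1, 0.9, 1.5], [1.2, 0.2, 0.3]])
P = softmin_proba(D, proto_labels=[0, 0, 1], classes=[0, 1])
print(P.round(3))
```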

class prosemble.models.DKGRLVQ(sigma_init='median', sigma_min=0.001, beta=10.0, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)

Differentiating Kernel GRLVQ.

Combines GRLVQ per-feature relevance weighting with Gaussian kernel distance and per-prototype bandwidth adaptation.

\[d_\kappa^2(x, w_k) = 2\left(1 - \exp\left( -\frac{\sum_j \lambda_j (x_j - w_{kj})^2}{2\sigma_k^2} \right)\right)\]

Parameters:
  • sigma_init (str or float) – Initialization strategy for per-prototype bandwidths. ‘median’ (default): per-class median distance from prototype to class members. ‘mean’: per-class mean distance. float: fixed value for all prototypes.

  • sigma_min (float) – Lower bound for sigma to prevent bandwidth collapse. Default: 1e-3.

  • beta (float) – Transfer function steepness parameter.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training.

  • batch_size (int, optional) – Mini-batch size. If None, use full-batch training.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments for the learning rate scheduler.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes.

  • patience (int, optional) – Epochs with no improvement before stopping.

  • restore_best (bool) – If True, restore best parameters after training.

  • class_weight (dict or 'balanced', optional) – Weights for each class.

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps.

  • ema_decay (float, optional) – Exponential moving average decay for parameters.

  • freeze_params (list of str, optional) – Parameter group names to freeze.

  • lookahead (dict, optional) – Lookahead optimizer wrapper configuration.

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training.

See also

SupervisedPrototypeModel

Full list of base parameters.

property relevance_profile

Return the learned relevance weights (normalized via softmax).

property kernel_bandwidths

Return the learned per-prototype bandwidths.

predict(X)

Predict using learned kernel distance with relevance weighting.

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number that n_prototypes_per_class would produce.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When combined with restore_best=True, the model restores the parameters that achieved the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer state. Cannot be combined with initial_prototypes.

Return type:

self

predict_proba(X)

Predict class probabilities via softmin of distances.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

proba

Return type:

array of shape (n_samples, n_classes)

class prosemble.models.DKGMLVQ(latent_dim=None, omega_hat_scale=0.1, beta=10.0, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None, **kwargs)

Differentiating Kernel GMLVQ with Exponential Kernel.

Learns a global transformation matrix \(\hat\Omega\) of shape (n_features, latent_dim) and computes distances via the exponential kernel:

\[\kappa_{\exp}(x, w) = \exp(x^T \hat\Lambda w), \quad \hat\Lambda = \hat\Omega \hat\Omega^T\]
\[d_\kappa^2(x, w) = \exp(x^T \hat\Lambda x) + \exp(w^T \hat\Lambda w) - 2 \exp(x^T \hat\Lambda w)\]

Note: unlike the Gaussian kernel, the exponential kernel does not satisfy \(\kappa(v, v) = 1\) in general, so the full three-term distance formula is used.
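The three-term formula can be checked numerically with a short NumPy sketch. It builds \(\hat\Lambda\) from a random \(\hat\Omega\) scaled by 0.1 (echoing the default omega_hat_scale); the random initialization and the function name are assumptions for illustration only.

```python
import numpy as np


def exp_kernel_distance(X, W, omega_hat):
    """Squared exponential-kernel distances, shape (n_samples, n_prototypes)."""
    X = np.asarray(X, dtype=float)
    W = np.asarray(W, dtype=float)
    lam = omega_hat @ omega_hat.T                       # Lambda_hat = Omega_hat Omega_hat^T
    k_xx = np.exp(np.einsum("id,de,ie->i", X, lam, X))  # kappa(x, x) = exp(x^T Lambda x)
    k_ww = np.exp(np.einsum("kd,de,ke->k", W, lam, W))  # kappa(w, w) = exp(w^T Lambda w)
    k_xw = np.exp(X @ lam @ W.T)                        # kappa(x, w) = exp(x^T Lambda w)
    return k_xx[:, None] + k_ww[None, :] - 2.0 * k_xw


rng = np.random.default_rng(0)
omega_hat = 0.1 * rng.standard_normal((3, 2))  # (n_features, latent_dim)
X = rng.standard_normal((4, 3))
W = X[:2].copy()                               # prototypes placed on two samples
D = exp_kernel_distance(X, W, omega_hat)
# kappa(v, v) for a sample: generally not 1, unlike the Gaussian kernel.
print(D.shape, np.exp(X[0] @ omega_hat @ omega_hat.T @ X[0]))
```

Dropping the kappa(x, x) and kappa(w, w) terms (as the Gaussian shortcut 2(1 - kappa) would) would give wrong, possibly negative, distances here; with all three terms the distance is zero on the diagonal and non-negative elsewhere.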

Parameters:
  • latent_dim (int, optional) – Dimensionality of the transformation. If None, uses input dim.

  • omega_hat_scale (float) – Scale factor for omega_hat initialization. Default: 0.1. Smaller values keep the exponents small and so reduce the risk of exp overflow at initialization.

  • beta (float) – Transfer function steepness.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training.

  • batch_size (int, optional) – Mini-batch size. If None, use full-batch training.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments for the learning rate scheduler.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes.

  • patience (int, optional) – Epochs with no improvement before stopping.

  • restore_best (bool) – If True, restore best parameters after training.

  • class_weight (dict or 'balanced', optional) – Weights for each class.

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps.

  • ema_decay (float, optional) – Exponential moving average decay for parameters.

  • freeze_params (list of str, optional) – Parameter group names to freeze.

  • lookahead (dict, optional) – Lookahead optimizer wrapper configuration.

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training.

See also

SupervisedPrototypeModel

Full list of base parameters.

property omega_hat_matrix

Return the learned \(\hat\Omega\) matrix.

property lambda_hat_matrix

Return \(\hat\Lambda = \hat\Omega \hat\Omega^T\).

predict(X)

Predict using learned exponential kernel distance.

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number that n_prototypes_per_class would produce.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When combined with restore_best=True, the model restores the parameters that achieved the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer state. Cannot be combined with initial_prototypes.

Return type:

self

predict_proba(X)

Predict class probabilities via softmin of distances.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

proba

Return type:

array of shape (n_samples, n_classes)

Supervised (Neural Gas)

class prosemble.models.DKGLVQ_NG(sigma_init='median', sigma_min=0.001, beta=10.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, lr_ratio=0.5, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)

Differentiating Kernel GLVQ with Neural Gas cooperation.

Combines Gaussian kernel distance with per-prototype bandwidth adaptation and Neural Gas rank-weighted loss. All same-class prototypes participate, weighted by \(\exp(-\text{rank} / \gamma)\).

\[d_\kappa^2(x, w_k) = 2\left(1 - \exp\left( -\frac{\|x - w_k\|^2}{2\sigma_k^2} \right)\right)\]

When \(\gamma \to 0\), DKGLVQ-NG recovers standard DKGLVQ.
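The rank weighting \(\exp(-\text{rank} / \gamma)\) and its \(\gamma \to 0\) limit can be illustrated for a single sample's same-class prototype distances. Ranks start at 0 for the closest prototype; normalizing the weights to sum to 1 is an assumption of this sketch, and ng_rank_weights is a hypothetical name.

```python
import numpy as np


def ng_rank_weights(d_same_class, gamma):
    """Hypothetical NG weights for one sample's same-class prototypes."""
    d = np.asarray(d_same_class, dtype=float)
    ranks = np.argsort(np.argsort(d))  # 0 = closest prototype, 1 = next closest, ...
    h = np.exp(-ranks / gamma)         # exp(-rank / gamma)
    return h / h.sum()                 # normalization (assumed)


d = [0.7, 0.2, 0.5]
print(ng_rank_weights(d, gamma=1.5).round(3))   # all prototypes contribute
print(ng_rank_weights(d, gamma=0.01).round(3))  # essentially one-hot on the closest
```

As gamma shrinks, only the closest same-class prototype keeps a non-negligible weight, which is why the NG variant reduces to the base model in that limit.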

Parameters:
  • sigma_init (str or float) – Initialization strategy for per-prototype bandwidths. ‘median’ (default): per-class median distance from prototype to class members. ‘mean’: per-class mean distance. float: fixed value for all prototypes.

  • sigma_min (float) – Lower bound for sigma to prevent bandwidth collapse. Default: 1e-3.

  • beta (float) – Transfer function steepness. Default: 10.0.

  • gamma_init (float, optional) – Initial neighborhood range. Default: max prototypes per class / 2.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step multiplicative decay for gamma. Default: computed from max_iter so gamma reaches gamma_final.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training.

  • batch_size (int, optional) – Mini-batch size. If None, use full-batch training.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments for the learning rate scheduler.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes.

  • patience (int, optional) – Epochs with no improvement before stopping.

  • restore_best (bool) – If True, restore best parameters after training.

  • class_weight (dict or 'balanced', optional) – Weights for each class.

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps.

  • ema_decay (float, optional) – Exponential moving average decay for parameters.

  • freeze_params (list of str, optional) – Parameter group names to freeze.

  • lookahead (dict, optional) – Lookahead optimizer wrapper configuration.

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training.

sigmas_

Learned per-prototype kernel bandwidths.

Type:

array of shape (n_prototypes,)

gamma_

Final gamma value after training.

Type:

float

See also

DKGLVQ

Base variant without NG cooperation.

SRNG

NG variant with Euclidean relevance-weighted distance.

predict(X)

Predict using learned kernel distance.

property kernel_bandwidths

Return the learned per-prototype bandwidths.

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number that n_prototypes_per_class would produce.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When combined with restore_best=True, the model restores the parameters that achieved the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer state. Cannot be combined with initial_prototypes.

Return type:

self

predict_proba(X)

Predict class probabilities via softmin of distances.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

proba

Return type:

array of shape (n_samples, n_classes)

class prosemble.models.DKGRLVQ_NG(sigma_init='median', sigma_min=0.001, beta=10.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, lr_ratio=0.5, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)

Differentiating Kernel GRLVQ with Neural Gas cooperation.

Combines relevance-weighted Gaussian kernel distance with per-prototype bandwidth adaptation and Neural Gas rank-weighted loss.

\[d_\kappa^2(x, w_k) = 2\left(1 - \exp\left( -\frac{\sum_j \lambda_j (x_j - w_{kj})^2}{2\sigma_k^2} \right)\right)\]

where \(\lambda = \text{softmax}(\text{relevances})\).
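A NumPy sketch of the relevance-weighted kernel distance, with \(\lambda\) obtained from the raw relevance logits by a softmax as stated above. The function names are hypothetical and this is not the library's implementation.

```python
import numpy as np


def softmax(z):
    z = np.asarray(z, dtype=float)
    e = np.exp(z - z.max())  # shift for numerical stability
    return e / e.sum()


def relevance_kernel_distance(X, W, sigmas, relevances):
    """Squared relevance-weighted Gaussian-kernel distances, shape (n_samples, n_prototypes)."""
    lam = softmax(relevances)  # lambda_j >= 0 and sum_j lambda_j = 1
    X = np.asarray(X, dtype=float)
    W = np.asarray(W, dtype=float)
    sigmas = np.asarray(sigmas, dtype=float)
    # sum_j lambda_j (x_j - w_kj)^2 for every sample/prototype pair.
    weighted = (((X[:, None, :] - W[None, :, :]) ** 2) * lam).sum(axis=-1)
    return 2.0 * (1.0 - np.exp(-weighted / (2.0 * sigmas[None, :] ** 2)))


# Equal logits weight both features by 0.5.
print(relevance_kernel_distance([[1.0, 0.0]], [[0.0, 0.0]], [1.0], [0.0, 0.0]))
# A strongly negative logit effectively ignores feature 1.
print(relevance_kernel_distance([[0.0, 5.0]], [[0.0, 0.0]], [1.0], [10.0, -10.0]))
```

Because the softmax keeps the weights positive and summing to 1, relevances_ can be optimized as unconstrained logits while relevance_profile exposes the normalized weights.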

When \(\gamma \to 0\), DKGRLVQ-NG recovers standard DKGRLVQ.

Parameters:
  • sigma_init (str or float) – Initialization strategy for per-prototype bandwidths. ‘median’ (default): per-class median distance. ‘mean’: per-class mean distance. float: fixed value for all prototypes.

  • sigma_min (float) – Lower bound for sigma. Default: 1e-3.

  • beta (float) – Transfer function steepness. Default: 10.0.

  • gamma_init (float, optional) – Initial neighborhood range. Default: max prototypes per class / 2.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step multiplicative decay for gamma.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training.

  • batch_size (int, optional) – Mini-batch size. If None, use full-batch training.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments for the learning rate scheduler.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes.

  • patience (int, optional) – Epochs with no improvement before stopping.

  • restore_best (bool) – If True, restore best parameters after training.

  • class_weight (dict or 'balanced', optional) – Weights for each class.

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps.

  • ema_decay (float, optional) – Exponential moving average decay for parameters.

  • freeze_params (list of str, optional) – Parameter group names to freeze.

  • lookahead (dict, optional) – Lookahead optimizer wrapper configuration.

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training.

sigmas_

Learned per-prototype kernel bandwidths.

Type:

array of shape (n_prototypes,)

relevances_

Learned per-feature relevance weights (raw logits).

Type:

array of shape (n_features,)

gamma_

Final gamma value after training.

Type:

float

See also

DKGRLVQ

Base variant without NG cooperation.

SRNG

NG variant with Euclidean relevance-weighted distance.

predict(X)

Predict using learned relevance-weighted kernel distance.

property kernel_bandwidths

Return the learned per-prototype bandwidths.

property relevance_profile

Return the learned relevance weights (normalized via softmax).

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number that n_prototypes_per_class would produce.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When combined with restore_best=True, the model restores the parameters that achieved the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer state. Cannot be combined with initial_prototypes.

Return type:

self

predict_proba(X)

Predict class probabilities via softmin of distances.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

proba

Return type:

array of shape (n_samples, n_classes)
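
The softmin step can be sketched in NumPy as follows. This is an illustration, not prosemble's implementation: the function name softmin_proba is hypothetical, and it assumes that each class is scored by the distance to its closest prototype before the softmin is applied.

```python
import numpy as np

def softmin_proba(distances, proto_labels, n_classes):
    """Hypothetical softmin over per-class minimum distances.

    distances: (n_samples, n_prototypes) sample-to-prototype distances.
    proto_labels: (n_prototypes,) integer class label of each prototype.
    """
    # Assumption: each class is scored by its closest prototype.
    class_d = np.stack(
        [distances[:, proto_labels == c].min(axis=1) for c in range(n_classes)],
        axis=1,
    )
    # Softmin = softmax of negated distances; shift by the row max for stability.
    logits = -class_d
    logits = logits - logits.max(axis=1, keepdims=True)
    weights = np.exp(logits)
    return weights / weights.sum(axis=1, keepdims=True)

D = np.array([[0.1, 2.0, 1.5],
              [3.0, 0.2, 0.4]])
labels = np.array([0, 1, 1])
proba = softmin_proba(D, labels, n_classes=2)
print(proba.argmax(axis=1))  # [0 1]
```

Each row sums to 1, and the most probable class is the class of the nearest prototype.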

class prosemble.models.DKGMLVQ_NG(latent_dim=None, omega_hat_scale=0.1, beta=10.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, lr_ratio=0.5, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=True, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)

Differentiating Kernel GMLVQ with Neural Gas cooperation.

Combines exponential kernel distance with a learnable transformation matrix \(\hat\Omega\) and Neural Gas rank-weighted loss.

\[d_\kappa^2(x, w) = \exp(x^T \hat\Lambda x) + \exp(w^T \hat\Lambda w) - 2 \exp(x^T \hat\Lambda w)\]

where \(\hat\Lambda = \hat\Omega \hat\Omega^T\).

When \(\gamma \to 0\), DKGMLVQ-NG recovers standard DKGMLVQ.
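
A minimal NumPy sketch of the kernel distance formula above, for illustration only. The function name is hypothetical, omega_hat follows the documented omega_hat_ shape (n_features, latent_dim), and the random initialization is illustrative. Writing \(u = \hat\Omega^T x\) and \(v = \hat\Omega^T w\), the distance is \(e^{\|u\|^2} + e^{\|v\|^2} - 2e^{u^T v}\), which is non-negative because \(\|u\|^2 + \|v\|^2 \ge 2u^T v\), and vanishes when \(x = w\).

```python
import numpy as np

def dk_sq_distance(x, w, omega_hat):
    """d_kappa^2(x, w) with Lambda_hat = Omega_hat Omega_hat^T."""
    lam = omega_hat @ omega_hat.T          # (n_features, n_features), positive semidefinite
    k_xx = np.exp(x @ lam @ x)
    k_ww = np.exp(w @ lam @ w)
    k_xw = np.exp(x @ lam @ w)
    return k_xx + k_ww - 2.0 * k_xw

rng = np.random.default_rng(0)
omega_hat = 0.1 * rng.normal(size=(4, 2))   # n_features=4, latent_dim=2 (illustrative init)
x, w = rng.normal(size=4), rng.normal(size=4)
print(dk_sq_distance(x, x, omega_hat))      # 0.0
print(dk_sq_distance(x, w, omega_hat))      # a small positive value
```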

Parameters:
  • latent_dim (int, optional) – Dimensionality of the transformation. If None, uses input dim.

  • omega_hat_scale (float) – Scale factor for omega_hat initialization. Default: 0.1.

  • beta (float) – Transfer function steepness. Default: 10.0.

  • gamma_init (float, optional) – Initial neighborhood range. Default: max prototypes per class / 2.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step multiplicative decay for gamma.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training.

  • batch_size (int, optional) – Mini-batch size. If None, use full-batch training.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments for the learning rate scheduler.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes.

  • patience (int, optional) – Epochs with no improvement before stopping.

  • restore_best (bool) – If True, restore best parameters after training.

  • class_weight (dict or 'balanced', optional) – Weights for each class.

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps.

  • ema_decay (float, optional) – Exponential moving average decay for parameters.

  • freeze_params (list of str, optional) – Parameter group names to freeze.

  • lookahead (dict, optional) – Lookahead optimizer wrapper configuration.

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training.

omega_hat_

Learned transformation matrix.

Type:

array of shape (n_features, latent_dim)

gamma_

Final gamma value after training.

Type:

float

See also

DKGMLVQ

Base variant without NG cooperation.

SMNG

NG variant with Euclidean Omega-projected distance.

predict(X)

Predict using learned exponential kernel distance.

property omega_hat_matrix

Return the learned \(\hat\Omega\) matrix.

property lambda_hat_matrix

Return \(\hat\Lambda = \hat\Omega \hat\Omega^T\).

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided together with restore_best=True, the model restores the parameters that achieved the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state: the current prototypes (and other fitted parameters) serve as the starting point, and the optimizer state is freshly initialized. Cannot be combined with initial_prototypes.

Return type:

self

predict_proba(X)

Predict class probabilities via softmin of distances.

Parameters:

X (array-like of shape (n_samples, n_features))

Returns:

proba

Return type:

array of shape (n_samples, n_classes)

Unsupervised

class prosemble.models.DKNeuralGas(n_prototypes, kernel_sigma=1.0, lr_init=0.5, lr_final=0.01, lambda_init=None, lambda_final=0.01, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, callbacks=None, use_scan=True, patience=None, restore_best=False)

Differentiating Kernel Neural Gas.

Neural Gas with Gaussian kernel distance for ranking. The kernel bandwidth \(\sigma\) (kernel_sigma) is a fixed hyperparameter, not learned. The soft-competitive, rank-based update rule operates in the original data space; only the distance metric changes.
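
To see what the kernel does and does not change, assume the common Gaussian kernel \(k(x, w) = \exp(-\|x - w\|^2 / (2\sigma^2))\). prosemble's exact normalization is not stated here, so treat this form as an assumption. The induced squared distance is then \(2 - 2k(x, w)\), a strictly increasing function of the Euclidean distance. With a single fixed bandwidth, the NumPy sketch below (hypothetical names) confirms that prototype ranks match the Euclidean ranks, while the distance values themselves are bounded by 2.

```python
import numpy as np

def gaussian_kernel_sq_distance(X, W, sigma):
    """Kernel-induced squared distance for k(x, w) = exp(-||x - w||^2 / (2 sigma^2)) (assumed form)."""
    sq_euclid = ((X[:, None, :] - W[None, :, :]) ** 2).sum(axis=-1)   # (n_samples, n_prototypes)
    k = np.exp(-sq_euclid / (2.0 * sigma ** 2))
    return 2.0 - 2.0 * k, sq_euclid   # uses k(x, x) = k(w, w) = 1

rng = np.random.default_rng(1)
X = rng.normal(size=(5, 3))
W = rng.normal(size=(4, 3))
d_kernel, d_euclid = gaussian_kernel_sq_distance(X, W, sigma=1.0)
# Neural Gas ranks: position of each prototype when sorted by distance (0 = closest).
ranks_kernel = d_kernel.argsort(axis=1).argsort(axis=1)
ranks_euclid = d_euclid.argsort(axis=1).argsort(axis=1)
print(np.array_equal(ranks_kernel, ranks_euclid))   # True
```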

Parameters:
  • kernel_sigma (float) – Gaussian kernel bandwidth. Default: 1.0.

  • n_prototypes (int) – Number of prototypes/nodes.

  • lr_init (float) – Initial learning rate.

  • lr_final (float) – Final learning rate.

  • lambda_init (float, optional) – Initial neighborhood range.

  • lambda_final (float) – Final neighborhood range.

  • max_iter (int) – Maximum training iterations.

  • epsilon (float) – Convergence threshold.

  • random_seed (int) – Random seed.

  • callbacks (list, optional) – Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training.

  • patience (int, optional) – Epochs with no improvement before early stopping.

  • restore_best (bool) – If True, restore best parameters after training.

See also

NeuralGas

Base class with Euclidean distance.

fit(X)

Fit Neural Gas.

predict(X)

Assign each sample to closest prototype (BMU).

transform(X)

Return distance matrix to all prototypes.

class prosemble.models.DKKohonenSOM(grid_height=10, grid_width=10, kernel_sigma=1.0, sigma_init=None, sigma_final=0.5, lr_init=0.5, lr_final=0.01, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, callbacks=None, use_scan=True, patience=None, restore_best=False)

Differentiating Kernel Kohonen SOM.

Standard Kohonen SOM with Gaussian kernel distance for BMU selection. The kernel bandwidth \(\sigma\) (kernel_sigma) is a fixed hyperparameter, not learned, and is distinct from the grid neighborhood radius (sigma_init, sigma_final). The grid-based neighborhood and competitive update rule operate in the original data space; only the data-space distance metric changes.

Parameters:
  • kernel_sigma (float) – Gaussian kernel bandwidth for data-space distance. Default: 1.0.

  • grid_height (int) – Height of the 2D grid.

  • grid_width (int) – Width of the 2D grid.

  • sigma_init (float, optional) – Initial grid neighborhood radius.

  • sigma_final (float) – Final grid neighborhood radius.

  • lr_init (float) – Initial learning rate.

  • lr_final (float) – Final learning rate.

  • max_iter (int) – Maximum training iterations.

  • epsilon (float) – Convergence threshold.

  • random_seed (int) – Random seed.

  • callbacks (list, optional) – Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training.

  • patience (int, optional) – Epochs with no improvement before early stopping.

  • restore_best (bool) – If True, restore best parameters after training.

See also

KohonenSOM

Base class with Euclidean distance.

bmu_map(X)

Return BMU grid coordinates using kernel distance.

Parameters:

X (array of shape (n, d))

Returns:

coords

Return type:

array of shape (n, 2) — (row, col) for each sample

fit(X)

Fit KohonenSOM.

predict(X)

Assign each sample to closest prototype (BMU).

class prosemble.models.DKHeskesSOM(grid_height=10, grid_width=10, kernel_sigma=1.0, sigma_init=None, sigma_final=0.5, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, callbacks=None, use_scan=True, patience=None, restore_best=False)

Differentiating Kernel Heskes SOM.

Heskes SOM with Gaussian kernel distance. The kernel bandwidth \(\sigma\) (kernel_sigma) is a fixed hyperparameter, not learned, and is distinct from the grid neighborhood radius (sigma_init, sigma_final). The Heskes BMU criterion and batch update operate in the original data space; only the data-space distance metric changes.

The Heskes BMU criterion uses the kernel distance, where \(h(k, c)\) is the grid neighborhood weight between units \(k\) and \(c\):

\[c^*(x) = \arg\min_c \sum_k h(k, c) \cdot d_\kappa^2(x, w_k)\]
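
The criterion can be sketched in NumPy as below. The Gaussian grid neighborhood \(h(k, c) = \exp(-\|r_k - r_c\|^2 / (2\sigma_{\text{grid}}^2))\) over unit coordinates \(r\), the Gaussian form of the data-space kernel distance, and all function names are assumptions for illustration, not prosemble's code. As a sanity check, shrinking the grid radius turns h into the identity matrix, and the criterion falls back to the ordinary nearest-prototype BMU.

```python
import numpy as np

def grid_coords(height, width):
    """(row, col) coordinates of every unit in row-major order."""
    rows, cols = np.meshgrid(np.arange(height), np.arange(width), indexing="ij")
    return np.stack([rows.ravel(), cols.ravel()], axis=1).astype(float)

def gaussian_kernel_sq_distance(X, W, kernel_sigma):
    """Assumed form: d_kappa^2 = 2 - 2 exp(-||x - w||^2 / (2 kernel_sigma^2))."""
    sq = ((X[:, None, :] - W[None, :, :]) ** 2).sum(axis=-1)
    return 2.0 - 2.0 * np.exp(-sq / (2.0 * kernel_sigma ** 2))

def heskes_bmu(d_sq, coords, sigma_grid):
    """argmin_c sum_k h(k, c) * d_sq[i, k] for every sample i."""
    grid_sq = ((coords[:, None, :] - coords[None, :, :]) ** 2).sum(axis=-1)
    h = np.exp(-grid_sq / (2.0 * sigma_grid ** 2))    # h[k, c]
    local_cost = d_sq @ h                              # [i, c] = sum_k d_sq[i, k] h[k, c]
    return local_cost.argmin(axis=1)

rng = np.random.default_rng(2)
coords = grid_coords(3, 3)
W = rng.normal(size=(9, 2))
X = rng.normal(size=(6, 2))
d_kernel = gaussian_kernel_sq_distance(X, W, kernel_sigma=1.0)
bmu = heskes_bmu(d_kernel, coords, sigma_grid=1.0)
print(coords[bmu])   # (row, col) of each sample's Heskes BMU
```
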
Parameters:
  • kernel_sigma (float) – Gaussian kernel bandwidth for data-space distance. Default: 1.0.

  • grid_height (int) – Height of the 2D grid.

  • grid_width (int) – Width of the 2D grid.

  • sigma_init (float, optional) – Initial grid neighborhood radius.

  • sigma_final (float) – Final grid neighborhood radius.

  • max_iter (int) – Maximum training iterations.

  • epsilon (float) – Convergence threshold.

  • random_seed (int) – Random seed.

  • callbacks (list, optional) – Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training.

  • patience (int, optional) – Epochs with no improvement before early stopping.

  • restore_best (bool) – If True, restore best parameters after training.

See also

HeskesSOM

Base class with Euclidean distance.

bmu_map(X)

Return BMU grid coordinates using Heskes criterion with kernel distance.

Parameters:

X (array of shape (n, d))

Returns:

coords

Return type:

array of shape (n, 2) — (row, col) for each sample

fit(X)

Fit HeskesSOM.

predict(X)

Assign each sample to closest prototype (BMU).

Topology-Preserving Models

class prosemble.models.NeuralGas(n_prototypes, lr_init=0.5, lr_final=0.01, lambda_init=None, lambda_final=0.01, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, callbacks=None, use_scan=True, patience=None, restore_best=False)

Neural Gas.

Updates every prototype, weighted by the rank of its distance to the input:

\[h(\text{rank}, \lambda) = \exp(-\text{rank} / \lambda)\]
\[w_k \leftarrow w_k + \varepsilon \cdot h(\text{rank}_k, \lambda) \cdot (x - w_k)\]

Both \(\varepsilon\) (from lr_init to lr_final) and \(\lambda\) (from lambda_init to lambda_final) decay exponentially during training.
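
One online pass of the update rule above, sketched in NumPy. The schedule \(p(t) = p_{\text{init}} (p_{\text{final}} / p_{\text{init}})^{t / t_{\max}}\) is the classic exponential form and is an assumption here, as are the function names; ranks start at 0 for the closest prototype.

```python
import numpy as np

def exp_decay(p_init, p_final, t, t_max):
    """Exponential interpolation from p_init (t = 0) to p_final (t = t_max)."""
    return p_init * (p_final / p_init) ** (t / t_max)

def neural_gas_step(W, x, lr, lam):
    """Move every prototype toward x, weighted by h(rank, lambda) = exp(-rank / lambda)."""
    d = ((W - x) ** 2).sum(axis=1)
    ranks = d.argsort().argsort()          # 0 for the closest prototype
    h = np.exp(-ranks / lam)
    return W + lr * h[:, None] * (x - W)

rng = np.random.default_rng(3)
X = rng.normal(size=(200, 2))
n_prototypes, t_max = 5, len(X)
W = X[rng.choice(len(X), n_prototypes, replace=False)].copy()
lam_init = n_prototypes / 2                 # documented default for lambda_init
for t, x in enumerate(X):
    lr = exp_decay(0.5, 0.01, t, t_max)     # lr_init, lr_final defaults
    lam = exp_decay(lam_init, 0.01, t, t_max)
    W = neural_gas_step(W, x, lr, lam)
print(W.shape)   # (5, 2)
```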

Parameters:
  • lr_init (float) – Initial learning rate.

  • lr_final (float) – Final learning rate.

  • lambda_init (float, optional) – Initial neighborhood range. Default: n_prototypes / 2.

  • lambda_final (float) – Final neighborhood range.

  • n_prototypes (int) – Number of prototypes/nodes.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Initial learning rate.

  • epsilon (float) – Convergence threshold.

  • random_seed (int) – Random seed.

  • distance_fn (callable, optional) – Distance function.

  • callbacks (list, optional) – Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping.

  • patience (int, optional) – Epochs with no improvement before early stopping. Default: None.

  • restore_best (bool) – If True, restore parameters from the lowest-loss epoch. Default: False.

fit(X)

Fit Neural Gas.

predict(X)

Assign each sample to closest prototype (BMU).

transform(X)

Return distance matrix to all prototypes.

class prosemble.models.GrowingNeuralGas(max_nodes=100, lr_winner=0.1, lr_neighbor=0.01, max_age=50, insert_interval=100, error_decay=0.995, n_prototypes=2, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, callbacks=None, use_scan=True, patience=None, restore_best=False)

Growing Neural Gas.

Starts with 2 nodes and grows by inserting new nodes near the units with the highest accumulated error. Each edge between nodes carries an age; edges older than max_age are removed.
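
A compact NumPy sketch of textbook (Fritzke-style) Growing Neural Gas, to make the roles of max_age, insert_interval, and error_decay concrete. It is not a description of prosemble's implementation: halving the errors at insertion, skipping the removal of isolated nodes, the small max_nodes, and all names are assumptions.

```python
import numpy as np

def gng(X, max_nodes=10, lr_winner=0.1, lr_neighbor=0.01, max_age=50,
        insert_interval=100, error_decay=0.995, seed=0):
    rng = np.random.default_rng(seed)
    W = [x.copy() for x in X[rng.choice(len(X), 2, replace=False)]]   # start with 2 nodes
    err = [0.0, 0.0]
    edges = {}                                   # (i, j) with i < j -> age
    for step, x in enumerate(X, start=1):
        d = np.array([((w - x) ** 2).sum() for w in W])
        s1, s2 = (int(i) for i in d.argsort()[:2])
        # Age the winner's edges, accumulate its error, adapt it and its neighbors.
        for e in edges:
            if s1 in e:
                edges[e] += 1
        err[s1] += d[s1]
        W[s1] += lr_winner * (x - W[s1])
        for (i, j) in edges:
            if s1 in (i, j):
                n = j if i == s1 else i
                W[n] += lr_neighbor * (x - W[n])
        # Connect (or refresh) the two closest nodes, then drop edges older than max_age.
        edges[(min(s1, s2), max(s1, s2))] = 0
        edges = {e: a for e, a in edges.items() if a <= max_age}
        # Periodically insert a node halfway between the worst node and its worst neighbor.
        if step % insert_interval == 0 and len(W) < max_nodes:
            q = int(np.argmax(err))
            nbrs = [j if i == q else i for (i, j) in edges if q in (i, j)]
            if nbrs:
                f = max(nbrs, key=lambda k: err[k])
                r = len(W)
                W.append(0.5 * (W[q] + W[f]))
                edges.pop((min(q, f), max(q, f)))
                edges[(min(q, r), max(q, r))] = 0
                edges[(min(f, r), max(f, r))] = 0
                err[q] *= 0.5
                err[f] *= 0.5
                err.append(err[q])
        # Decay the accumulated error of every node.
        err = [e * error_decay for e in err]
    return np.array(W), edges

rng = np.random.default_rng(4)
X = rng.uniform(size=(2000, 2))
W, edges = gng(X, max_nodes=10, insert_interval=100)
print(len(W))   # number of nodes after training (at most max_nodes)
```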

Parameters:
  • max_nodes (int) – Maximum number of nodes.

  • lr_winner (float) – Learning rate for the winning node.

  • lr_neighbor (float) – Learning rate for neighbors of the winner.

  • max_age (int) – Maximum age before an edge is removed.

  • insert_interval (int) – Number of steps between node insertions.

  • error_decay (float) – Multiplicative decay factor applied to the accumulated error of all nodes.

  • n_prototypes (int) – Number of prototypes/nodes.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Initial learning rate.

  • epsilon (float) – Convergence threshold.

  • random_seed (int) – Random seed.

  • distance_fn (callable, optional) – Distance function.

  • callbacks (list, optional) – Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping.

  • patience (int, optional) – Epochs with no improvement before early stopping. Default: None.

  • restore_best (bool) – If True, restore parameters from the lowest-loss epoch. Default: False.

fit(X)

Fit Growing Neural Gas.

predict(X)

Assign each sample to closest prototype (BMU).

class prosemble.models.KohonenSOM(grid_height=10, grid_width=10, sigma_init=None, sigma_final=0.5, lr_init=0.5, lr_final=0.01, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, callbacks=None, use_scan=True, patience=None, restore_best=False)

Standard Kohonen Self-Organizing Map.

Uses squared Euclidean distance for BMU selection, Gaussian neighborhood function, exponential decay for sigma and learning rate, and batch updates.
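
A NumPy sketch of batch Kohonen epochs with squared Euclidean BMUs, a Gaussian grid neighborhood, and exponential decay of sigma and the learning rate. How prosemble combines the batch average with the learning rate is not specified here. The interpolation W + lr * (weighted_mean - W) below is an assumption, chosen so that lr = 1 reduces to the classic batch SOM update. The function name is hypothetical.

```python
import numpy as np

def kohonen_batch_epoch(W, X, coords, sigma, lr):
    d = ((X[:, None, :] - W[None, :, :]) ** 2).sum(axis=-1)     # (n_samples, n_units)
    bmu = d.argmin(axis=1)                                        # closest prototype
    grid_sq = ((coords[:, None, :] - coords[None, :, :]) ** 2).sum(axis=-1)
    h = np.exp(-grid_sq / (2.0 * sigma ** 2))                    # h[k, c]
    H = h[:, bmu]                                                 # (n_units, n_samples): h(k, c_i)
    weighted_mean = (H @ X) / H.sum(axis=1, keepdims=True)
    return W + lr * (weighted_mean - W)

rows, cols = np.meshgrid(np.arange(4), np.arange(4), indexing="ij")
coords = np.stack([rows.ravel(), cols.ravel()], axis=1).astype(float)
rng = np.random.default_rng(5)
X = rng.uniform(size=(300, 2))
W = rng.uniform(size=(16, 2))
sigma_init, sigma_final, n_epochs = 2.0, 0.5, 20     # sigma_init = max(4, 4) / 2
for t in range(n_epochs):
    sigma = sigma_init * (sigma_final / sigma_init) ** (t / n_epochs)
    lr = 0.5 * (0.01 / 0.5) ** (t / n_epochs)         # lr_init=0.5, lr_final=0.01
    W = kohonen_batch_epoch(W, X, coords, sigma, lr)
```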

Parameters:
  • grid_height (int) – Height of the 2D grid.

  • grid_width (int) – Width of the 2D grid.

  • sigma_init (float, optional) – Initial neighborhood radius. Default: max(grid_height, grid_width) / 2.

  • sigma_final (float) – Final neighborhood radius.

  • lr_init (float) – Initial learning rate.

  • lr_final (float) – Final learning rate.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Initial learning rate.

  • epsilon (float) – Convergence threshold.

  • random_seed (int) – Random seed.

  • distance_fn (callable, optional) – Distance function.

  • callbacks (list, optional) – Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping.

  • patience (int, optional) – Epochs with no improvement before early stopping. Default: None.

  • restore_best (bool) – If True, restore parameters from the lowest-loss epoch. Default: False.

fit(X)

Fit KohonenSOM.

predict(X)

Assign each sample to closest prototype (BMU).

class prosemble.models.HeskesSOM(grid_height=10, grid_width=10, sigma_init=None, sigma_final=0.5, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, callbacks=None, use_scan=True, patience=None, restore_best=False)

Heskes Self-Organizing Map.

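Uses a modified BMU definition that takes the neighborhood structure into account, together with a pure batch update (a weighted average of the data). For a fixed neighborhood radius, this guarantees that the Heskes energy function decreases monotonically, never increasing from one iteration to the next.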

Differences from KohonenSOM:

  • The BMU minimizes the neighborhood-weighted sum of distances, not the raw distance to the closest prototype.

  • Prototypes are updated as a neighborhood-weighted average of the data (no learning rate).

  • For a fixed neighborhood radius, the energy is guaranteed to decrease monotonically.
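
The monotonicity argument can be checked numerically with the NumPy sketch below, which alternates the Heskes BMU step and the weighted-average update at a fixed neighborhood radius. The function names are hypothetical. Each step minimizes \(E = \sum_i \min_c \sum_k h(k, c)\,\|x_i - w_k\|^2\) over one block of variables: the assignments first, then the prototypes, for which the weighted mean is the exact minimizer. Therefore E cannot increase.

```python
import numpy as np

def heskes_energy_and_bmu(W, X, h):
    d = ((X[:, None, :] - W[None, :, :]) ** 2).sum(axis=-1)   # (n_samples, n_units)
    local_cost = d @ h                                         # [i, c] = sum_k h(k, c) d(x_i, w_k)
    bmu = local_cost.argmin(axis=1)
    return local_cost.min(axis=1).sum(), bmu

def heskes_update(X, bmu, h):
    H = h[:, bmu]                                              # h(k, c_i), shape (n_units, n_samples)
    return (H @ X) / H.sum(axis=1, keepdims=True)              # weighted average, no learning rate

rows, cols = np.meshgrid(np.arange(3), np.arange(3), indexing="ij")
coords = np.stack([rows.ravel(), cols.ravel()], axis=1).astype(float)
grid_sq = ((coords[:, None, :] - coords[None, :, :]) ** 2).sum(axis=-1)
h = np.exp(-grid_sq / (2.0 * 1.0 ** 2))                        # fixed neighborhood radius sigma = 1
rng = np.random.default_rng(6)
X = rng.normal(size=(200, 2))
W = rng.normal(size=(9, 2))
energies = []
for _ in range(15):
    E, bmu = heskes_energy_and_bmu(W, X, h)
    energies.append(E)
    W = heskes_update(X, bmu, h)
print(all(b <= a + 1e-9 for a, b in zip(energies, energies[1:])))   # True
```

When sigma shrinks between iterations, the energy function itself changes, so the guarantee applies only at each fixed radius.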

Parameters:
  • grid_height (int) – Height of the 2D grid.

  • grid_width (int) – Width of the 2D grid.

  • sigma_init (float, optional) – Initial neighborhood radius. Default: max(grid_height, grid_width) / 2.

  • sigma_final (float) – Final neighborhood radius.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Initial learning rate.

  • epsilon (float) – Convergence threshold.

  • random_seed (int) – Random seed.

  • distance_fn (callable, optional) – Distance function.

  • callbacks (list, optional) – Callback objects.

  • use_scan (bool) – If True (default), use jax.lax.scan for training (faster, JIT-compiled, but runs all max_iter iterations even after convergence). If False, use a Python for-loop with true early stopping.

  • patience (int, optional) – Epochs with no improvement before early stopping. Default: None.

  • restore_best (bool) – If True, restore parameters from the lowest-loss epoch. Default: False.

fit(X)

Fit HeskesSOM.

predict(X)

Assign each sample to closest prototype (BMU).

Riemannian Models

class prosemble.models.RiemannianSRNG(manifold, beta=10.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, lr_ratio=0.5, tau=0.95, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=False, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)

Supervised Riemannian Neural Gas.

Combines three key ideas:

  • GLVQ loss: \((d^+ - d^-) / (d^+ + d^-)\) for margin-based classification

  • Neural Gas cooperation: all same-class prototypes participate in the loss, weighted by rank via \(\exp(-\text{rank} / \gamma)\)

  • Geodesic distance: \(d(x, w)\) computed via the manifold’s intrinsic metric (matrix logarithm + Frobenius norm)

Prototypes live on the manifold and are updated via projected gradient descent: optax computes Euclidean gradients, then manifold.project() maps prototypes back to the manifold after each step.
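
The step-then-project pattern can be illustrated on the rotation group SO(n). The NumPy sketch below takes a plain Euclidean gradient step and then maps the result back to the manifold with the SVD-based polar projection. That projection and the function names are assumptions for illustration, not the behavior of prosemble's manifold.project().

```python
import numpy as np

def project_to_so(M):
    """Closest rotation matrix to M in Frobenius norm (polar decomposition via SVD)."""
    U, _, Vt = np.linalg.svd(M)
    D = np.eye(M.shape[0])
    D[-1, -1] = np.sign(np.linalg.det(U @ Vt))   # flip one axis if needed so det = +1
    return U @ D @ Vt

def projected_gradient_step(R, grad, lr):
    """Euclidean gradient step followed by projection back onto SO(n)."""
    return project_to_so(R - lr * grad)

rng = np.random.default_rng(7)
R = project_to_so(rng.normal(size=(3, 3)))
target = project_to_so(rng.normal(size=(3, 3)))
start = np.linalg.norm(R - target)
for _ in range(50):
    grad = 2.0 * (R - target)                   # gradient of ||R - target||_F^2
    R = projected_gradient_step(R, grad, lr=0.1)
print(np.allclose(R.T @ R, np.eye(3)), np.isclose(np.linalg.det(R), 1.0))   # True True
```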

The neighborhood range \(\gamma\) decays during training from \(\gamma_{\text{init}}\) to \(\gamma_{\text{final}}\). When \(\gamma \to 0\), RiemannianSRNG recovers a Riemannian GLVQ.
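
The rank-weighted GLVQ cost for a single sample can be sketched as follows, in the style of SRNG. Each same-class prototype contributes its GLVQ term against the closest wrong-class prototype, weighted by \(\exp(-\text{rank}/\gamma)\). Normalizing by the sum of the weights, the identity transfer function, and the function name are assumptions. With a very small \(\gamma\), only the rank-0 (closest) correct prototype keeps a nonzero weight, which is the GLVQ limit described above.

```python
import numpy as np

def srng_sample_cost(d_correct, d_wrong_min, gamma):
    """Rank-weighted GLVQ cost for one sample.

    d_correct: distances to all prototypes of the sample's class.
    d_wrong_min: distance to the closest prototype of any other class.
    """
    ranks = d_correct.argsort().argsort()                 # 0 = closest correct prototype
    h = np.exp(-ranks / gamma)
    mu = (d_correct - d_wrong_min) / (d_correct + d_wrong_min)
    return (h * mu).sum() / h.sum()

d_correct = np.array([0.9, 0.3, 1.4])
d_wrong_min = 0.6
glvq_mu = (0.3 - 0.6) / (0.3 + 0.6)                      # GLVQ term of the closest correct prototype
print(np.isclose(srng_sample_cost(d_correct, d_wrong_min, gamma=1e-3), glvq_mu))   # True
print(srng_sample_cost(d_correct, d_wrong_min, gamma=1.5))  # all three prototypes contribute
```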

Parameters:
  • manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance defining the geometry.

  • beta (float) – Transfer function steepness parameter for sigmoid shaping.

  • gamma_init (float, optional) – Initial neighborhood range for NG cooperation. Default: max prototypes per class / 2.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma. Default: computed from max_iter so gamma reaches gamma_final.

  • tau (float) – Injectivity radius safety factor for manifold projection. Default: 0.95.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • optimizer (str or optax optimizer, optional) – Optimizer name (‘adam’, ‘sgd’) or optax GradientTransformation. Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping (default: identity).

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True, use jax.lax.scan for training (faster, JIT-compiled). If False (default), use a Python for-loop with true early stopping.

  • batch_size (int, optional) – Mini-batch size. If None (default), use full-batch training.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule. Default: None.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments for the learning rate scheduler. Default: None.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes. Default: ‘stratified_random’.

  • patience (int, optional) – Number of consecutive epochs with no improvement before stopping. Default: None.

  • restore_best (bool) – If True, restore parameters that achieved the lowest loss. Default: False.

  • class_weight (dict or 'balanced', optional) – Weights for each class. Default: None (uniform).

  • gradient_accumulation_steps (int, optional) – Accumulate gradients over this many steps. Default: None.

  • ema_decay (float, optional) – Exponential moving average decay for parameters. Default: None.

  • freeze_params (list of str, optional) – List of parameter group names to freeze. Default: None.

  • lookahead (dict, optional) – Enable lookahead optimizer wrapper. Default: None.

  • mixed_precision (str or None, optional) – Compute dtype for mixed precision training. Default: None.
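
If gamma is multiplied by gamma_decay once per step, as the gamma_decay description states, then reaching gamma_final after \(T\) = max_iter steps requires \(\gamma_{\text{decay}} = (\gamma_{\text{final}} / \gamma_{\text{init}})^{1/T}\). The arithmetic below assumes exactly max_iter multiplications and the documented default gamma_init for n_prototypes_per_class = 3. prosemble's own step-counting convention may differ by one.

```python
gamma_init = 3 / 2          # default: max prototypes per class / 2, here with 3 per class
gamma_final = 0.01
max_iter = 100
gamma_decay = (gamma_final / gamma_init) ** (1.0 / max_iter)

gamma = gamma_init
for _ in range(max_iter):
    gamma *= gamma_decay    # per-step multiplicative decay
print(gamma_decay, gamma)
```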

predict(X)

Predict class labels using geodesic distance.

Parameters:

X (array-like of shape (n_samples, n_features_flat))

Returns:

labels

Return type:

array of shape (n_samples,)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided together with restore_best=True, the model restores the parameters that achieved the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state: the current prototypes (and other fitted parameters) serve as the starting point, and the optimizer state is freshly initialized. Cannot be combined with initial_prototypes.

Return type:

self

class prosemble.models.RiemannianSMNG(manifold, latent_dim=None, beta=10.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, lr_ratio=0.5, tau=0.95, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=False, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)

Supervised Riemannian Matrix Neural Gas.

Extends RiemannianSRNG with a global metric adaptation matrix \(\Omega\) applied in the tangent space. The distance is:

\[d(x, w_k) = \|\Omega \cdot \text{Log}_{w_k}(x)_{\text{flat}}\|^2\]

where \(\text{Log}_{w_k}(x)\) is the logarithmic map at prototype \(w_k\), flattened to a vector.

The learned relevance matrix \(\Lambda = \Omega^T \Omega\) captures feature correlations in the tangent space.
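
Once tangent vectors have been computed with the manifold's logarithmic map and flattened, the Omega metric reduces to a projected squared norm. The NumPy sketch below uses hypothetical names, substitutes random stand-ins for the log-map outputs, and checks the identity \(\|\Omega v\|^2 = v^T \Lambda v\) with \(\Lambda = \Omega^T \Omega\).

```python
import numpy as np

def omega_sq_distance(V, omega):
    """Squared norm ||Omega v||^2 for each row v of V (flattened tangent vectors)."""
    proj = V @ omega.T                  # (n, latent_dim)
    return (proj ** 2).sum(axis=1)

rng = np.random.default_rng(8)
d_flat, latent_dim = 6, 3
omega = rng.normal(size=(latent_dim, d_flat))
V = rng.normal(size=(5, d_flat))        # stand-ins for Log_{w_k}(x)_flat
lam = omega.T @ omega                   # (d_flat, d_flat) relevance matrix
d_omega = omega_sq_distance(V, omega)
d_lambda = np.einsum("ni,ij,nj->n", V, lam, V)
print(np.allclose(d_omega, d_lambda))   # True
```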

Parameters:
  • manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance.

  • latent_dim (int, optional) – Projection dimensionality for omega. Default: n_features (square).

  • beta (float) – Transfer function steepness.

  • gamma_init (float, optional) – Initial neighborhood range. Default: max prototypes per class / 2.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step decay factor for gamma.

  • tau (float) – Injectivity radius safety factor. Default: 0.95.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • optimizer (str or optax optimizer, optional) – Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping.

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True, use jax.lax.scan. Default: False.

  • batch_size (int, optional) – Mini-batch size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments for the learning rate scheduler.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes.

  • patience (int, optional) – Epochs with no improvement before stopping.

  • restore_best (bool) – Restore best parameters. Default: False.

  • class_weight (dict or 'balanced', optional) – Class weights.

  • gradient_accumulation_steps (int, optional) – Gradient accumulation steps.

  • ema_decay (float, optional) – EMA decay for parameters.

  • freeze_params (list of str, optional) – Parameter groups to freeze.

  • lookahead (dict, optional) – Lookahead optimizer config.

  • mixed_precision (str or None, optional) – Mixed precision dtype.

predict(X)

Predict using tangent-space omega metric.

Parameters:

X (array-like of shape (n_samples, n_features_flat))

Returns:

labels

Return type:

array of shape (n_samples,)

relevance_matrix()

Return learned relevance matrix Lambda = Omega^T Omega.

Return type:

array of shape (d_flat, d_flat)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided together with restore_best=True, the model restores the parameters that achieved the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state: the current prototypes (and other fitted parameters) serve as the starting point, and the optimizer state is freshly initialized. Cannot be combined with initial_prototypes.

Return type:

self

class prosemble.models.RiemannianSLNG(manifold, latent_dim=None, beta=10.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, lr_ratio=0.5, tau=0.95, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=False, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)

Supervised Riemannian Localized Matrix Neural Gas.

Extends RiemannianSRNG with per-prototype metric adaptation. Each prototype \(w_k\) has its own matrix \(\Omega_k\) applied in the tangent space:

\[d(x, w_k) = \|\Omega_k \cdot \text{Log}_{w_k}(x)_{\text{flat}}\|^2\]

Since each \(\Omega_k\) acts only on tangent vectors at \(w_k\), which all lie in the same tangent space \(T_{w_k}M\), this is geometrically well-defined.
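
With a separate \(\Omega_k\) per prototype, the distance computation becomes a batched contraction. In the NumPy sketch below, T[i, k] stands for the flattened tangent vector \(\text{Log}_{w_k}(x_i)\). The log map is not reproduced, so T is filled with random numbers, and all names are hypothetical.

```python
import numpy as np

def local_omega_sq_distances(T, omegas):
    """d[i, k] = ||Omega_k @ T[i, k]||^2.

    T: (n_samples, n_prototypes, d_flat) tangent vectors at each prototype.
    omegas: (n_prototypes, latent_dim, d_flat), one matrix per prototype.
    """
    proj = np.einsum("kld,nkd->nkl", omegas, T)   # Omega_k applied within T_{w_k}M
    return (proj ** 2).sum(axis=-1)

rng = np.random.default_rng(9)
n, K, d_flat, latent_dim = 4, 3, 5, 2
T = rng.normal(size=(n, K, d_flat))
omegas = rng.normal(size=(K, latent_dim, d_flat))
D = local_omega_sq_distances(T, omegas)
print(D.shape)   # (4, 3)
```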

Parameters:
  • manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance.

  • latent_dim (int, optional) – Projection dimensionality for each omega. Default: n_features.

  • beta (float) – Transfer function steepness.

  • gamma_init (float, optional) – Initial neighborhood range. Default: max prototypes per class / 2.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step decay factor for gamma.

  • tau (float) – Injectivity radius safety factor. Default: 0.95.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • optimizer (str or optax optimizer, optional) – Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping.

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True, use jax.lax.scan. Default: False.

  • batch_size (int, optional) – Mini-batch size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments for the learning rate scheduler.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes.

  • patience (int, optional) – Epochs with no improvement before stopping.

  • restore_best (bool) – Restore best parameters. Default: False.

  • class_weight (dict or 'balanced', optional) – Class weights.

  • gradient_accumulation_steps (int, optional) – Gradient accumulation steps.

  • ema_decay (float, optional) – EMA decay for parameters.

  • freeze_params (list of str, optional) – Parameter groups to freeze.

  • lookahead (dict, optional) – Lookahead optimizer config.

  • mixed_precision (str or None, optional) – Mixed precision dtype.

predict(X)

Predict using per-prototype tangent-space omega metric.

Parameters:

X (array-like of shape (n_samples, n_features_flat))

Returns:

labels

Return type:

array of shape (n_samples,)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided together with restore_best=True, the model restores the parameters that achieved the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state: the current prototypes (and other fitted parameters) serve as the starting point, and the optimizer state is freshly initialized. Cannot be combined with initial_prototypes.

Return type:

self

class prosemble.models.RiemannianSTNG(manifold, subspace_dim=None, beta=10.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, lr_ratio=0.5, tau=0.95, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=False, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)

Supervised Riemannian Tangent Neural Gas.

Extends RiemannianSRNG with per-prototype tangent subspace projection. Each prototype \(w_k\) has an orthonormal basis \(\Omega_k\) spanning a subspace of invariant directions; the distance is the squared norm of the component of the tangent vector orthogonal to this subspace:

\[d(x, w_k) = \|(I - \Omega_k \Omega_k^T) \cdot \text{Log}_{w_k}(x)_{\text{flat}}\|^2\]

The subspace bases are re-orthogonalized after each gradient step.
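
A NumPy sketch of the tangent-subspace distance and of re-orthogonalization. Restoring orthonormal columns with a reduced QR decomposition is an assumption (any orthonormalization would serve), and the names are hypothetical. Components of the tangent vector inside span(\(\Omega_k\)) are ignored, so a vector lying entirely in the subspace has distance zero.

```python
import numpy as np

def orthonormalize(B):
    """Re-orthogonalize the columns of B (d_flat x subspace_dim) with a reduced QR decomposition."""
    Q, _ = np.linalg.qr(B)
    return Q

def tangent_sq_distance(v, omega_k):
    """||(I - Omega_k Omega_k^T) v||^2 for a flattened tangent vector v."""
    residual = v - omega_k @ (omega_k.T @ v)
    return residual @ residual

rng = np.random.default_rng(10)
d_flat, subspace_dim = 6, 3                         # subspace_dim default: n_features // 2
omega_k = orthonormalize(rng.normal(size=(d_flat, subspace_dim)))
inside = omega_k @ rng.normal(size=subspace_dim)    # lies in the subspace
print(np.isclose(tangent_sq_distance(inside, omega_k), 0.0))   # True
```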

Parameters:
  • manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance.

  • subspace_dim (int, optional) – Dimensionality of each prototype’s tangent subspace. Default: n_features // 2.

  • beta (float) – Transfer function steepness.

  • gamma_init (float, optional) – Initial neighborhood range. Default: half the maximum number of prototypes per class.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step decay factor for gamma.

  • tau (float) – Injectivity radius safety factor. Default: 0.95.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • optimizer (str or optax optimizer, optional) – Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping.

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True, use jax.lax.scan. Default: False.

  • batch_size (int, optional) – Mini-batch size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments for the learning rate scheduler.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes.

  • patience (int, optional) – Epochs with no improvement before stopping.

  • restore_best (bool) – Restore best parameters. Default: False.

  • class_weight (dict or 'balanced', optional) – Class weights.

  • gradient_accumulation_steps (int, optional) – Gradient accumulation steps.

  • ema_decay (float, optional) – EMA decay for parameters.

  • freeze_params (list of str, optional) – Parameter groups to freeze.

  • lookahead (dict, optional) – Lookahead optimizer config.

  • mixed_precision (str or None, optional) – Mixed precision dtype.

predict(X)[source]

Predict using tangent subspace distance.

Parameters:

X (array-like of shape (n_samples, n_features_flat))

Returns:

labels

Return type:

array of shape (n_samples,)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided together with restore_best=True, the model restores the parameters with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

class prosemble.models.RiemannianNeuralGas(manifold, n_prototypes, lr_init=0.3, lr_final=0.01, lambda_init=None, lambda_final=0.01, tau=0.9, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, distance_fn=None, callbacks=None, patience=None, restore_best=False)[source]

Riemannian Neural Gas.

Learns prototypes on a Riemannian manifold using geodesic distances and exponential/logarithmic maps for updates.
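A single online Neural Gas update on the SPD manifold might look like the sketch below. It is illustrative only and rests on assumptions the reference above does not confirm: the affine-invariant metric for SPD matrices, the textbook rank-based neighborhood h = exp(-rank / λ), per-sample updates, and an exponential interpolation from lr_init to lr_final and from lambda_init to lambda_final. The injectivity-radius safeguard (tau) is omitted because SPD under this metric has an infinite injectivity radius; SO and Grassmannian would need it. All function names are hypothetical.

```python
import numpy as np


def _sym_apply(S, fn):
    """Apply a scalar function to a symmetric matrix via its eigendecomposition."""
    S = 0.5 * (S + S.T)                  # remove round-off asymmetry
    w, U = np.linalg.eigh(S)
    return (U * fn(w)) @ U.T             # U diag(fn(w)) U^T


def _whiten(W, X):
    """Return W^{1/2} and the congruence W^{-1/2} X W^{-1/2}."""
    W_half = _sym_apply(W, np.sqrt)
    W_ihalf = _sym_apply(W, lambda w: 1.0 / np.sqrt(w))
    return W_half, W_ihalf @ X @ W_ihalf


def spd_log(W, X):
    """Log map at W (affine-invariant metric): a symmetric tangent matrix."""
    W_half, M = _whiten(W, X)
    return W_half @ _sym_apply(M, np.log) @ W_half


def spd_exp(W, V):
    """Exp map at W (affine-invariant metric): an SPD matrix."""
    W_half, M = _whiten(W, V)
    return W_half @ _sym_apply(M, np.exp) @ W_half


def spd_dist(W, X):
    """Geodesic distance ||log(W^{-1/2} X W^{-1/2})||_F (not squared)."""
    _, M = _whiten(W, X)
    eig = np.linalg.eigvalsh(0.5 * (M + M.T))
    return float(np.sqrt(np.sum(np.log(eig) ** 2)))


def schedule(start, end, t, t_max):
    """Assumed exponential interpolation from start (t = 0) to end (t = t_max)."""
    return start * (end / start) ** (t / t_max)


def ng_step(prototypes, x, lr, lam):
    """Move every prototype along its geodesic toward x, weighted by rank."""
    d = np.array([spd_dist(W, x) for W in prototypes])
    ranks = np.argsort(np.argsort(d))    # rank 0 = best-matching unit
    h = np.exp(-ranks / lam)             # Neural Gas neighborhood weights
    return np.stack([spd_exp(W, lr * hk * spd_log(W, x))
                     for W, hk in zip(prototypes, h)])


def random_spd(rng, n):
    A = rng.normal(size=(n, n))
    return A @ A.T + n * np.eye(n)


rng = np.random.default_rng(0)
data = np.stack([random_spd(rng, 3) for _ in range(50)])
protos = data[:4].copy()                 # n_prototypes = 4
t_max = 200                              # per-sample steps
for t in range(t_max):
    x = data[rng.integers(len(data))]
    protos = ng_step(protos, x,
                     lr=schedule(0.3, 0.01, t, t_max),
                     lam=schedule(len(protos) / 2, 0.01, t, t_max))
```

With lr=1 the best-matching unit lands exactly on the sample, since Exp_W(Log_W(x)) = x; smaller learning rates move each prototype only part of the way along its geodesic.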

Parameters:
  • manifold (Manifold) – Manifold instance (SO, SPD, or Grassmannian from prosemble.core.manifolds). Must satisfy the Manifold protocol.

  • n_prototypes (int) – Number of prototypes/nodes.

  • lr_init (float) – Initial learning rate. Default: 0.3.

  • lr_final (float) – Final learning rate. Default: 0.01.

  • lambda_init (float, optional) – Initial neighborhood range. Default: n_prototypes / 2.

  • lambda_final (float) – Final neighborhood range. Default: 0.01.

  • tau (float) – Safety factor for injectivity radius bound. Default: 0.9.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Initial learning rate.

  • epsilon (float) – Convergence threshold.

  • random_seed (int) – Random seed.

  • distance_fn (callable, optional) – Distance function.

  • callbacks (list, optional) – Callback objects.

  • patience (int, optional) – Epochs with no improvement before early stopping. Default: None.

  • restore_best (bool) – If True, restore parameters from the lowest-loss epoch. Default: False.

fit(X)[source]

Fit Riemannian Neural Gas.

Parameters:

X (array of shape (n_samples, *point_shape)) – Data points on the manifold.

Return type:

self

predict(X)[source]

Assign each point to its nearest prototype, the best-matching unit (BMU).

Parameters:

X (array of shape (n_samples, *point_shape))

Returns:

labels

Return type:

array of shape (n_samples,)

transform(X)[source]

Compute geodesic distance matrix to all prototypes.

Parameters:

X (array of shape (n_samples, *point_shape))

Returns:

distances – Geodesic distances (not squared).

Return type:

array of shape (n_samples, n_prototypes)

Riemannian Differentiating Kernel

class prosemble.models.RiemannianDKGLVQ(manifold, sigma_init='median', sigma_min=0.001, beta=10.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, tau=0.95, lr_ratio=0.5, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=False, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

Riemannian Differentiating Kernel GLVQ.

Extends RiemannianSRNG by wrapping the geodesic distance in a Gaussian kernel. Each prototype \(w_k\) has a learnable bandwidth \(\sigma_k\) that controls the sensitivity range of the kernel on the manifold.

\[d_\kappa^2(x, w_k) = 2\left(1 - \exp\left( -\frac{d_{\text{geo}}^2(x, w_k)}{2\sigma_k^2} \right)\right)\]

Parameters:
  • manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance defining the geometry.

  • sigma_init (str or float) – Initialization strategy for per-prototype bandwidths. ‘median’ (default): per-class median geodesic distance. ‘mean’: per-class mean geodesic distance. float: fixed value for all prototypes.

  • sigma_min (float) – Lower bound for sigma to prevent bandwidth collapse. Default: 1e-3.

  • beta (float) – Transfer function steepness parameter.

  • gamma_init (float, optional) – Initial neighborhood range for NG cooperation.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma.

  • tau (float) – Injectivity radius safety factor. Default: 0.95.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • optimizer (str or optax optimizer, optional) – Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping.

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True, use jax.lax.scan. Default: False.

  • batch_size (int, optional) – Mini-batch size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments for the learning rate scheduler.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes.

  • patience (int, optional) – Epochs with no improvement before stopping.

  • restore_best (bool) – Restore best parameters. Default: False.

  • class_weight (dict or 'balanced', optional) – Class weights.

  • gradient_accumulation_steps (int, optional) – Gradient accumulation steps.

  • ema_decay (float, optional) – EMA decay for parameters.

  • freeze_params (list of str, optional) – Parameter groups to freeze.

  • lookahead (dict, optional) – Lookahead optimizer config.

  • mixed_precision (str or None, optional) – Mixed precision dtype.
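The kernel-wrapped distance for this model can be sketched in NumPy from precomputed geodesic distances. This is a minimal illustration under assumptions: the geodesic distances are supplied by the caller, and the sigma_min lower bound is enforced by clipping, which may differ from how prosemble enforces it. The function name is hypothetical.

```python
import numpy as np


def gaussian_kernel_distance(d_geo, sigma, sigma_min=1e-3):
    """Kernel-wrapped squared distance 2 * (1 - exp(-d_geo^2 / (2 sigma^2))).

    d_geo: geodesic distances (not squared), shape (n_samples, n_prototypes)
    sigma: per-prototype bandwidths, shape (n_prototypes,)
    """
    sigma = np.maximum(sigma, sigma_min)   # assumed: enforce the lower bound by clipping
    return 2.0 * (1.0 - np.exp(-d_geo**2 / (2.0 * sigma**2)))


d_geo = np.array([[0.0, 0.5, 3.0]])       # one sample, three prototypes
print(gaussian_kernel_distance(d_geo, sigma=np.array([1.0, 1.0, 1.0])))
```

Because \(d_\kappa^2\) is bounded above by 2, distant points saturate instead of dominating the loss. A collapsed bandwidth would push every nonzero distance to that ceiling, which is what sigma_min guards against.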

See also

RiemannianSRNG

Base Riemannian supervised Neural Gas.

property kernel_bandwidths

Return the learned per-prototype bandwidths.

predict(X)[source]

Predict class labels using kernel-wrapped geodesic distance.

Parameters:

X (array-like of shape (n_samples, n_features_flat))

Returns:

labels

Return type:

array of shape (n_samples,)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided together with restore_best=True, the model restores the parameters with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

class prosemble.models.RiemannianDKGRLVQ(manifold, sigma_init='median', sigma_min=0.001, beta=10.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, tau=0.95, lr_ratio=0.5, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=False, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

Riemannian Differentiating Kernel GRLVQ.

Extends RiemannianSRNG with a relevance-weighted Gaussian kernel distance in tangent space. Each prototype \(w_k\) has a learnable bandwidth \(\sigma_k\), and a shared relevance vector \(\lambda\) weights the tangent-space features.

\[d_\kappa^2(x, w_k) = 2\left(1 - \exp\left( -\frac{\sum_j \lambda_j v_j^2}{2\sigma_k^2} \right)\right)\]

where \(v = \text{Log}_{w_k}(x)_{\text{flat}}\).
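The relevance-weighted variant can be sketched by normalizing a raw relevance vector with a softmax (as the relevance_profile property reports) and weighting the squared tangent coordinates before the Gaussian wrap. The sketch assumes the tangent vectors are precomputed and that the raw relevances are unconstrained logits; the function names are hypothetical.

```python
import numpy as np


def softmax(z):
    e = np.exp(z - z.max())   # shift for numerical stability
    return e / e.sum()


def relevance_kernel_distance(v, sigma, relevance_logits):
    """Relevance-weighted Gaussian kernel distance in tangent space.

    v:                flattened tangent vectors, shape (n_samples, n_prototypes, d_flat)
    sigma:            per-prototype bandwidths, shape (n_prototypes,)
    relevance_logits: unnormalized shared relevances, shape (d_flat,)
    """
    lam = softmax(relevance_logits)                  # nonnegative, sums to 1
    weighted_sq = np.einsum("nkd,d->nk", v**2, lam)  # sum_j lambda_j v_j^2
    return 2.0 * (1.0 - np.exp(-weighted_sq / (2.0 * sigma**2)))


rng = np.random.default_rng(0)
v = rng.normal(size=(5, 2, 4))            # 5 samples, 2 prototypes, d_flat = 4
d = relevance_kernel_distance(v, sigma=np.array([0.5, 1.0]),
                              relevance_logits=np.zeros(4))
print(d.shape)                            # (5, 2)
```

With all-zero logits every feature receives weight 1/d_flat, so the model starts from an unweighted (scaled) tangent distance and learns which coordinates matter.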

Parameters:
  • manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance defining the geometry.

  • sigma_init (str or float) – Initialization strategy for per-prototype bandwidths. ‘median’ (default): per-class median geodesic distance. ‘mean’: per-class mean geodesic distance. float: fixed value for all prototypes.

  • sigma_min (float) – Lower bound for sigma to prevent bandwidth collapse. Default: 1e-3.

  • beta (float) – Transfer function steepness parameter.

  • gamma_init (float, optional) – Initial neighborhood range for NG cooperation.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma.

  • tau (float) – Injectivity radius safety factor. Default: 0.95.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • optimizer (str or optax optimizer, optional) – Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping.

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True, use jax.lax.scan. Default: False.

  • batch_size (int, optional) – Mini-batch size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments for the learning rate scheduler.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes.

  • patience (int, optional) – Epochs with no improvement before stopping.

  • restore_best (bool) – Restore best parameters. Default: False.

  • class_weight (dict or 'balanced', optional) – Class weights.

  • gradient_accumulation_steps (int, optional) – Gradient accumulation steps.

  • ema_decay (float, optional) – EMA decay for parameters.

  • freeze_params (list of str, optional) – Parameter groups to freeze.

  • lookahead (dict, optional) – Lookahead optimizer config.

  • mixed_precision (str or None, optional) – Mixed precision dtype.

See also

RiemannianSRNG

Base Riemannian supervised Neural Gas.

RiemannianDKGLVQ

Gaussian kernel variant without relevance weighting.

property kernel_bandwidths

Return the learned per-prototype bandwidths.

property relevance_profile

Return the learned relevance weights (normalized via softmax).

predict(X)[source]

Predict using relevance-weighted kernel distance in tangent space.

Parameters:

X (array-like of shape (n_samples, n_features_flat))

Returns:

labels

Return type:

array of shape (n_samples,)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided together with restore_best=True, the model restores the parameters with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

class prosemble.models.RiemannianDKGMLVQ(manifold, latent_dim=None, omega_hat_scale=0.1, beta=10.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, tau=0.95, lr_ratio=0.5, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=False, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

Riemannian Differentiating Kernel GMLVQ.

Extends RiemannianSRNG with exponential kernel distance applied in tangent space. A global transformation matrix \(\hat\Omega\) (d_flat x latent_dim) is learned such that:

\[\hat\Lambda = \hat\Omega \hat\Omega^T\]
\[d_\kappa^2(x, w_k) = \exp(v^T \hat\Lambda v) - 1\]

where \(v = \text{Log}_{w_k}(x)_{\text{flat}}\). Since the prototype maps to the zero vector in its own tangent space, the exponential kernel simplifies from the full three-term formula to \(\exp(v^T \hat\Lambda v) - 1\).
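The simplification can be checked numerically. The sketch assumes the exponential kernel \(k(a, b) = \exp(a^T \hat\Lambda b)\), for which the three-term expansion \(k(v, v) - 2k(v, 0) + k(0, 0)\) reduces to the formula above, and it assumes omega_hat_scale multiplies a standard-normal initialization. Both are illustrative assumptions, not confirmed details of prosemble.

```python
import numpy as np

rng = np.random.default_rng(0)
d_flat, latent_dim = 6, 4
omega_hat = 0.1 * rng.normal(size=(d_flat, latent_dim))  # assumed init with omega_hat_scale = 0.1
lambda_hat = omega_hat @ omega_hat.T                       # (d_flat, d_flat), positive semidefinite


def k(a, b):
    """Assumed exponential kernel k(a, b) = exp(a^T Lambda_hat b)."""
    return np.exp(a @ lambda_hat @ b)


v = rng.normal(size=d_flat)   # stands in for Log_{w_k}(x), flattened
zero = np.zeros(d_flat)       # Log_{w_k}(w_k): the prototype sits at the origin

three_term = k(v, v) - 2.0 * k(v, zero) + k(zero, zero)
simplified = np.exp(v @ lambda_hat @ v) - 1.0
print(np.isclose(three_term, simplified))   # True
```

Since \(\hat\Lambda\) is positive semidefinite, \(v^T \hat\Lambda v \ge 0\) and the distance is never negative. It does grow exponentially in \(v^T \hat\Lambda v\), which is why a small omega_hat_scale helps at initialization.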

Parameters:
  • manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance defining the geometry.

  • latent_dim (int, optional) – Dimensionality of the transformation. If None, uses d_flat.

  • omega_hat_scale (float) – Scale factor for omega_hat initialization. Default: 0.1. Smaller values reduce the risk of exp overflow at initialization.

  • beta (float) – Transfer function steepness parameter.

  • gamma_init (float, optional) – Initial neighborhood range for NG cooperation.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step multiplicative decay factor for gamma.

  • tau (float) – Injectivity radius safety factor. Default: 0.95.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold on loss change.

  • random_seed (int) – Random seed for reproducibility.

  • optimizer (str or optax optimizer, optional) – Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function for loss shaping.

  • margin (float) – Margin for loss computation.

  • callbacks (list, optional) – List of Callback objects.

  • use_scan (bool) – If True, use jax.lax.scan. Default: False.

  • batch_size (int, optional) – Mini-batch size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule.

  • lr_scheduler_kwargs (dict, optional) – Keyword arguments for the learning rate scheduler.

  • prototypes_initializer (str or callable, optional) – How to initialize prototypes.

  • patience (int, optional) – Epochs with no improvement before stopping.

  • restore_best (bool) – Restore best parameters. Default: False.

  • class_weight (dict or 'balanced', optional) – Class weights.

  • gradient_accumulation_steps (int, optional) – Gradient accumulation steps.

  • ema_decay (float, optional) – EMA decay for parameters.

  • freeze_params (list of str, optional) – Parameter groups to freeze.

  • lookahead (dict, optional) – Lookahead optimizer config.

  • mixed_precision (str or None, optional) – Mixed precision dtype.

See also

RiemannianSRNG

Base Riemannian supervised Neural Gas.

RiemannianDKGLVQ

Gaussian kernel variant.

RiemannianDKGRLVQ

Relevance-weighted kernel variant.

property omega_hat_matrix

Return the learned \(\hat\Omega\) matrix.

property lambda_hat_matrix

Return \(\hat\Lambda = \hat\Omega \hat\Omega^T\).

predict(X)[source]

Predict using exponential kernel distance in tangent space.

Parameters:

X (array-like of shape (n_samples, n_features_flat))

Returns:

labels

Return type:

array of shape (n_samples,)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided together with restore_best=True, the model restores the parameters with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

class prosemble.models.RiemannianDKSMNG(manifold, sigma_init='median', sigma_min=0.001, latent_dim=None, beta=10.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, lr_ratio=0.5, tau=0.95, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=False, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

Riemannian Differentiating Kernel SMNG.

Extends RiemannianSMNG with a Gaussian kernel wrapping the omega-projected tangent distance. Each prototype has a learnable bandwidth \(\sigma_k\):

\[d_\kappa^2(x, w_k) = 2\left(1 - \exp\left( -\frac{\|\Omega \cdot v\|^2}{2\sigma_k^2} \right)\right)\]

where \(v = \text{Log}_{w_k}(x)_{\text{flat}}\).
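With a single global \(\Omega\), the distance can be evaluated for all samples and prototypes at once. The sketch assumes precomputed tangent vectors of shape (n_samples, n_prototypes, d_flat) and an \(\Omega\) of shape (latent_dim, d_flat); both layouts are assumptions, and the function name is hypothetical.

```python
import numpy as np


def omega_kernel_distance(v, omega, sigma):
    """Gaussian kernel on the global-omega projected tangent distance.

    v:     flattened tangent vectors, shape (n_samples, n_prototypes, d_flat)
    omega: shared projection, shape (latent_dim, d_flat)   (assumed layout)
    sigma: per-prototype bandwidths, shape (n_prototypes,)
    """
    projected = np.einsum("ld,nkd->nkl", omega, v)   # Omega . v for every (sample, prototype)
    sq = np.sum(projected**2, axis=-1)              # ||Omega . v||^2
    return 2.0 * (1.0 - np.exp(-sq / (2.0 * sigma**2)))


rng = np.random.default_rng(0)
v = rng.normal(size=(3, 2, 5))            # 3 samples, 2 prototypes, d_flat = 5
omega = rng.normal(size=(2, 5))           # latent_dim = 2
print(omega_kernel_distance(v, omega, sigma=np.array([0.7, 1.3])).shape)   # (3, 2)
```

Choosing latent_dim smaller than d_flat makes \(\Omega\) a low-rank projection, so the kernel only sees the tangent directions the model learns to keep.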

Parameters:
  • manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance.

  • sigma_init (str or float) – Initialization strategy for per-prototype bandwidths. ‘median’ (default): per-class median of omega-projected distances. ‘mean’: per-class mean. float: fixed value for all prototypes.

  • sigma_min (float) – Lower bound for sigma. Default: 1e-3.

  • latent_dim (int, optional) – Projection dimensionality for omega. Default: d_flat.

  • beta (float) – Transfer function steepness.

  • gamma_init (float, optional) – Initial neighborhood range.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step decay factor.

  • tau (float) – Injectivity radius safety factor. Default: 0.95.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold.

  • random_seed (int) – Random seed.

  • optimizer (str or optax optimizer, optional) – Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function.

  • margin (float) – Margin for loss.

  • callbacks (list, optional) – Callback objects.

  • use_scan (bool) – If True, use jax.lax.scan. Default: False.

  • batch_size (int, optional) – Mini-batch size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule.

  • lr_scheduler_kwargs (dict, optional) – LR scheduler kwargs.

  • prototypes_initializer (str or callable, optional) – Prototype initialization.

  • patience (int, optional) – Early stopping patience.

  • restore_best (bool) – Restore best parameters. Default: False.

  • class_weight (dict or 'balanced', optional) – Class weights.

  • gradient_accumulation_steps (int, optional) – Gradient accumulation.

  • ema_decay (float, optional) – EMA decay.

  • freeze_params (list of str, optional) – Frozen parameters.

  • lookahead (dict, optional) – Lookahead config.

  • mixed_precision (str or None, optional) – Mixed precision dtype.

See also

RiemannianSMNG

Base Riemannian matrix Neural Gas.

RiemannianDKGLVQ

Gaussian kernel on geodesic distance (no omega).

property kernel_bandwidths

Return the learned per-prototype bandwidths.

predict(X)[source]

Predict using kernel-wrapped omega-projected tangent distance.

Parameters:

X (array-like of shape (n_samples, n_features_flat))

Returns:

labels

Return type:

array of shape (n_samples,)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided together with restore_best=True, the model restores the parameters with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

class prosemble.models.RiemannianDKSLNG(manifold, sigma_init='median', sigma_min=0.001, latent_dim=None, beta=10.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, lr_ratio=0.5, tau=0.95, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=False, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

Riemannian Differentiating Kernel SLNG.

Extends RiemannianSLNG with a Gaussian kernel wrapping the per-prototype omega-projected tangent distance:

\[d_\kappa^2(x, w_k) = 2\left(1 - \exp\left( -\frac{\|\Omega_k \cdot v\|^2}{2\sigma_k^2} \right)\right)\]

where \(v = \text{Log}_{w_k}(x)_{\text{flat}}\) and each prototype has its own metric matrix \(\Omega_k\) and bandwidth \(\sigma_k\).
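The localized variant differs only in that each prototype applies its own \(\Omega_k\). In the sketch below the stacked matrices have the assumed shape (n_prototypes, latent_dim, d_flat), and the tangent vectors are assumed precomputed; with every \(\Omega_k\) equal it reduces to the global-omega case. Names and layouts are hypothetical.

```python
import numpy as np


def local_omega_kernel_distance(v, omegas, sigma):
    """Gaussian kernel on per-prototype omega-projected tangent distances.

    v:      flattened tangent vectors, shape (n_samples, n_prototypes, d_flat)
    omegas: one matrix per prototype, shape (n_prototypes, latent_dim, d_flat)   (assumed layout)
    sigma:  per-prototype bandwidths, shape (n_prototypes,)
    """
    projected = np.einsum("kld,nkd->nkl", omegas, v)   # prototype k applies its own Omega_k
    sq = np.sum(projected**2, axis=-1)                # ||Omega_k . v||^2
    return 2.0 * (1.0 - np.exp(-sq / (2.0 * sigma**2)))


rng = np.random.default_rng(0)
v = rng.normal(size=(4, 3, 5))            # 4 samples, 3 prototypes, d_flat = 5
omegas = rng.normal(size=(3, 2, 5))       # latent_dim = 2
print(local_omega_kernel_distance(v, omegas, sigma=np.array([0.5, 1.0, 2.0])).shape)   # (4, 3)
```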

Parameters:
  • manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance.

  • sigma_init (str or float) – Initialization strategy for per-prototype bandwidths. ‘median’ (default): per-class median of local-omega distances. ‘mean’: per-class mean. float: fixed value for all prototypes.

  • sigma_min (float) – Lower bound for sigma. Default: 1e-3.

  • latent_dim (int, optional) – Projection dimensionality for local omegas. Default: d_flat.

  • beta (float) – Transfer function steepness.

  • gamma_init (float, optional) – Initial neighborhood range.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step decay factor.

  • tau (float) – Injectivity radius safety factor. Default: 0.95.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold.

  • random_seed (int) – Random seed.

  • optimizer (str or optax optimizer, optional) – Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function.

  • margin (float) – Margin for loss.

  • callbacks (list, optional) – Callback objects.

  • use_scan (bool) – If True, use jax.lax.scan. Default: False.

  • batch_size (int, optional) – Mini-batch size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule.

  • lr_scheduler_kwargs (dict, optional) – LR scheduler kwargs.

  • prototypes_initializer (str or callable, optional) – Prototype initialization.

  • patience (int, optional) – Early stopping patience.

  • restore_best (bool) – Restore best parameters. Default: False.

  • class_weight (dict or 'balanced', optional) – Class weights.

  • gradient_accumulation_steps (int, optional) – Gradient accumulation.

  • ema_decay (float, optional) – EMA decay.

  • freeze_params (list of str, optional) – Frozen parameters.

  • lookahead (dict, optional) – Lookahead config.

  • mixed_precision (str or None, optional) – Mixed precision dtype.

See also

RiemannianSLNG

Base Riemannian localized matrix Neural Gas.

RiemannianDKSMNG

Global omega + kernel variant.

property kernel_bandwidths

Return the learned per-prototype bandwidths.

predict(X)[source]

Predict using kernel-wrapped local-omega tangent distance.

Parameters:

X (array-like of shape (n_samples, n_features_flat))

Returns:

labels

Return type:

array of shape (n_samples,)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions. For warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided together with restore_best=True, the model restores the parameters with the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a fresh optimizer. Cannot be combined with initial_prototypes.

Return type:

self

class prosemble.models.RiemannianDKSTNG(manifold, sigma_init='median', sigma_min=0.001, subspace_dim=None, beta=10.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, lr_ratio=0.5, tau=0.95, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=False, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

Riemannian Differentiating Kernel STNG.

Extends RiemannianSTNG with a Gaussian kernel wrapping the tangent subspace residual distance. Each prototype has an orthonormal subspace basis \(\Omega_k\) and a learnable bandwidth \(\sigma_k\):

\[d_\kappa^2(x, w_k) = 2\left(1 - \exp\left( -\frac{\|(I - \Omega_k \Omega_k^T) \cdot v\|^2}{2\sigma_k^2} \right)\right)\]

where \(v = \text{Log}_{w_k}(x)_{\text{flat}}\).

Parameters:
  • manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance.

  • sigma_init (str or float) – Initialization strategy for per-prototype bandwidths. ‘median’ (default): per-class median of subspace residual distances. ‘mean’: per-class mean. float: fixed value for all prototypes.

  • sigma_min (float) – Lower bound for sigma. Default: 1e-3.

  • subspace_dim (int, optional) – Tangent subspace dimensionality. Default: d_flat - 1.

  • beta (float) – Transfer function steepness.

  • gamma_init (float, optional) – Initial neighborhood range.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step decay factor.

  • tau (float) – Injectivity radius safety factor. Default: 0.95.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold.

  • random_seed (int) – Random seed.

  • optimizer (str or optax optimizer, optional) – Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function.

  • margin (float) – Margin for loss.

  • callbacks (list, optional) – Callback objects.

  • use_scan (bool) – If True, use jax.lax.scan. Default: False.

  • batch_size (int, optional) – Mini-batch size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule.

  • lr_scheduler_kwargs (dict, optional) – LR scheduler kwargs.

  • prototypes_initializer (str or callable, optional) – Prototype initialization.

  • patience (int, optional) – Early stopping patience.

  • restore_best (bool) – Restore best parameters. Default: False.

  • class_weight (dict or 'balanced', optional) – Class weights.

  • gradient_accumulation_steps (int, optional) – Gradient accumulation.

  • ema_decay (float, optional) – EMA decay.

  • freeze_params (list of str, optional) – Frozen parameters.

  • lookahead (dict, optional) – Lookahead config.

  • mixed_precision (str or None, optional) – Mixed precision dtype.
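The kernel-wrapped residual distance can be sketched in the same batched style. The sketch assumes precomputed tangent vectors of shape (n_samples, n_prototypes, d_flat) and bases of shape (n_prototypes, d_flat, subspace_dim) with orthonormal columns; tangent vectors that lie entirely inside a prototype's subspace get distance zero. Names and layouts are hypothetical.

```python
import numpy as np


def residual_kernel_distance(v, bases, sigma):
    """Gaussian kernel on the tangent-subspace residual ||(I - Omega_k Omega_k^T) v||^2.

    v:     flattened tangent vectors, shape (n_samples, n_prototypes, d_flat)
    bases: orthonormal columns per prototype, shape (n_prototypes, d_flat, subspace_dim)
    sigma: per-prototype bandwidths, shape (n_prototypes,)
    """
    coeffs = np.einsum("kds,nkd->nks", bases, v)        # Omega_k^T v
    in_span = np.einsum("kds,nks->nkd", bases, coeffs)  # Omega_k Omega_k^T v
    sq = np.sum((v - in_span) ** 2, axis=-1)            # squared residual norm
    return 2.0 * (1.0 - np.exp(-sq / (2.0 * sigma**2)))


rng = np.random.default_rng(0)
n_samples, n_protos, d_flat, sub_dim = 4, 3, 5, 2
bases = np.stack([np.linalg.qr(rng.normal(size=(d_flat, sub_dim)))[0]
                  for _ in range(n_protos)])
sigma = np.array([0.5, 1.0, 2.0])

coeffs = rng.normal(size=(n_samples, n_protos, sub_dim))
v_inside = np.einsum("kds,nks->nkd", bases, coeffs)   # each v lies in its prototype's subspace
v_random = rng.normal(size=(n_samples, n_protos, d_flat))

d_inside = residual_kernel_distance(v_inside, bases, sigma)
d_random = residual_kernel_distance(v_random, bases, sigma)
```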

See also

RiemannianSTNG

Base Riemannian tangent Neural Gas.

RiemannianDKSMNG

Global omega + kernel variant.

RiemannianDKSLNG

Local omega + kernel variant.

property kernel_bandwidths

Return the learned per-prototype bandwidths.

predict(X)[source]

Predict using kernel-wrapped subspace residual distance.

Parameters:

X (array-like of shape (n_samples, n_features_flat))

Returns:

labels

Return type:

array of shape (n_samples,)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores the parameters that achieved the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a freshly initialized optimizer. Cannot be combined with initial_prototypes.

Return type:

self

class prosemble.models.RiemannianDKRSMNG(manifold, sigma_init='median', sigma_min=0.001, latent_dim=None, beta=10.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, lr_ratio=0.5, tau=0.95, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=False, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

Riemannian Differentiating Kernel Relevance SMNG.

Extends RiemannianSMNG with a relevance-weighted Gaussian kernel in the omega-projected tangent space. Each prototype has a learnable bandwidth \(\sigma_k\), and a shared relevance vector \(\lambda\) weights the projected features:

\[d_\kappa^2(x, w_k) = 2\left(1 - \exp\left( -\frac{\sum_j \lambda_j (\Omega \cdot v)_j^2} {2\sigma_k^2} \right)\right)\]
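The kernel distance above can be evaluated directly for one sample and one prototype. The following is a minimal NumPy sketch of the formula, not prosemble's internal implementation. It assumes that v is the flattened tangent vector of x at w_k (this section does not define how v is obtained), that \(\Omega\) has shape (latent_dim, d_flat), and that the relevances are softmax-normalized as described for relevance_profile. The helper names are hypothetical.

```python
import numpy as np


def softmax(z):
    # Numerically stable softmax: non-negative weights that sum to 1.
    e = np.exp(z - np.max(z))
    return e / e.sum()


def relevance_kernel_dist(v, omega, relevance, sigma):
    """Hypothetical helper computing d_kappa^2(x, w_k) for one pair.

    v         : (d_flat,) tangent vector of x at w_k (assumed given)
    omega     : (latent_dim, d_flat) global projection Omega
    relevance : (latent_dim,) relevance weights lambda
    sigma     : bandwidth sigma_k of prototype k
    """
    u = omega @ v                               # Omega . v
    weighted_sq = np.sum(relevance * u ** 2)    # sum_j lambda_j (Omega v)_j^2
    return 2.0 * (1.0 - np.exp(-weighted_sq / (2.0 * sigma ** 2)))


rng = np.random.default_rng(0)
v = rng.normal(size=6)
omega = rng.normal(size=(4, 6))
lam = softmax(rng.normal(size=4))   # raw relevances -> normalized weights
d = relevance_kernel_dist(v, omega, lam, sigma=1.0)
print(d)  # bounded: 0 <= d <= 2
```

Because the Gaussian kernel saturates, this distance is bounded by 2; a larger bandwidth \(\sigma_k\) makes it grow more slowly with the weighted projected norm.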
Parameters:
  • manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance.

  • sigma_init (str or float) – Initialization strategy for per-prototype bandwidths. ‘median’ (default): per-class median of relevance-weighted distances. ‘mean’: per-class mean. float: fixed value for all prototypes.

  • sigma_min (float) – Lower bound for sigma. Default: 1e-3.

  • latent_dim (int, optional) – Projection dimensionality for omega. Default: d_flat.

  • beta (float) – Transfer function steepness.

  • gamma_init (float, optional) – Initial neighborhood range.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step decay factor.

  • tau (float) – Injectivity radius safety factor. Default: 0.95.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold.

  • random_seed (int) – Random seed.

  • optimizer (str or optax optimizer, optional) – Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function.

  • margin (float) – Margin for loss.

  • callbacks (list, optional) – Callback objects.

  • use_scan (bool) – Default: False.

  • batch_size (int, optional) – Mini-batch size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule.

  • lr_scheduler_kwargs (dict, optional) – LR scheduler kwargs.

  • prototypes_initializer (str or callable, optional) – Prototype initialization.

  • patience (int, optional) – Early stopping patience.

  • restore_best (bool) – Restore best parameters. Default: False.

  • class_weight (dict or 'balanced', optional) – Class weights.

  • gradient_accumulation_steps (int, optional) – Gradient accumulation.

  • ema_decay (float, optional) – EMA decay.

  • freeze_params (list of str, optional) – Frozen parameters.

  • lookahead (dict, optional) – Lookahead config.

  • mixed_precision (str or None, optional) – Mixed precision dtype.

See also

RiemannianSMNG

Base Riemannian matrix Neural Gas.

RiemannianDKSMNG

Gaussian kernel variant (no relevance weighting).

RiemannianDKGRLVQ

Relevance kernel on geodesic distance (no omega).

property kernel_bandwidths

Return the learned per-prototype bandwidths.

property relevance_profile

Return the learned relevance weights (normalized via softmax).

predict(X)[source]

Predict using relevance-weighted kernel in omega-projected space.

Parameters:

X (array-like of shape (n_samples, n_features_flat))

Returns:

labels

Return type:

array of shape (n_samples,)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores the parameters that achieved the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a freshly initialized optimizer. Cannot be combined with initial_prototypes.

Return type:

self

class prosemble.models.RiemannianDKMSMNG(manifold, latent_dim=None, kernel_latent_dim=None, omega_hat_scale=0.1, beta=10.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, lr_ratio=0.5, tau=0.95, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=False, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

Riemannian Differentiating Kernel Matrix SMNG.

Extends RiemannianSMNG with an exponential kernel in the omega-projected tangent space. The global omega \(\Omega\) projects tangent vectors from d_flat to latent_dim, then a learned matrix \(\hat\Lambda = \hat\Omega \hat\Omega^T\) provides further metric adaptation in the projected space:

\[d_\kappa^2(x, w_k) = \exp\left( (\Omega \cdot v)^T \hat\Lambda (\Omega \cdot v) \right) - 1\]
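A minimal NumPy sketch of the exponential kernel above, for one sample and one prototype. It is illustrative only and assumes that v is the flattened tangent vector of x at w_k, that \(\Omega\) has shape (latent_dim, d_flat), and that \(\hat\Omega\) has shape (latent_dim, kernel_latent_dim) so that \(\hat\Omega \hat\Omega^T\) acts on the projected vector. Reading omega_hat_scale as a multiplicative scale on a random initialization is also an assumption; the helper name is hypothetical.

```python
import numpy as np


def matrix_exp_kernel_dist(v, omega, omega_hat):
    """Hypothetical helper: d_kappa^2 = exp(u^T Lambda_hat u) - 1, with u = Omega v.

    v         : (d_flat,) tangent vector of x at w_k (assumed given)
    omega     : (latent_dim, d_flat) global projection Omega
    omega_hat : (latent_dim, kernel_latent_dim) kernel factor Omega_hat
    """
    u = omega @ v
    lambda_hat = omega_hat @ omega_hat.T    # (latent_dim, latent_dim), positive semidefinite
    q = u @ lambda_hat @ u                  # quadratic form, q >= 0
    return np.expm1(q)                      # exp(q) - 1, accurate when q is small


rng = np.random.default_rng(0)
omega = rng.normal(size=(4, 6))
omega_hat = 0.1 * rng.normal(size=(4, 4))   # assumed meaning of omega_hat_scale=0.1
v = rng.normal(size=6)
d = matrix_exp_kernel_dist(v, omega, omega_hat)
```

Since \(\hat\Lambda\) is positive semidefinite, the distance is never negative; unlike the Gaussian variants it is unbounded and grows exponentially in the quadratic form.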
Parameters:
  • manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance.

  • latent_dim (int, optional) – Projection dimensionality for global omega. Default: d_flat.

  • kernel_latent_dim (int, optional) – Dimensionality for the kernel’s omega_hat. Default: same as omega’s output dimension (latent_dim).

  • omega_hat_scale (float) – Scale for omega_hat initialization. Default: 0.1.

  • beta (float) – Transfer function steepness.

  • gamma_init (float, optional) – Initial neighborhood range.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step decay factor.

  • tau (float) – Injectivity radius safety factor. Default: 0.95.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold.

  • random_seed (int) – Random seed.

  • optimizer (str or optax optimizer, optional) – Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function.

  • margin (float) – Margin for loss.

  • callbacks (list, optional) – Callback objects.

  • use_scan (bool) – Default: False.

  • batch_size (int, optional) – Mini-batch size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule.

  • lr_scheduler_kwargs (dict, optional) – LR scheduler kwargs.

  • prototypes_initializer (str or callable, optional) – Prototype initialization.

  • patience (int, optional) – Early stopping patience.

  • restore_best (bool) – Restore best parameters. Default: False.

  • class_weight (dict or 'balanced', optional) – Class weights.

  • gradient_accumulation_steps (int, optional) – Gradient accumulation.

  • ema_decay (float, optional) – EMA decay.

  • freeze_params (list of str, optional) – Frozen parameters.

  • lookahead (dict, optional) – Lookahead config.

  • mixed_precision (str or None, optional) – Mixed precision dtype.

See also

RiemannianSMNG

Base Riemannian matrix Neural Gas.

RiemannianDKSMNG

Gaussian kernel variant (no matrix kernel).

RiemannianDKRSMNG

Relevance kernel variant.

property omega_hat_matrix

Return the learned kernel \(\hat\Omega\) matrix.

property lambda_hat_matrix

Return \(\hat\Lambda = \hat\Omega \hat\Omega^T\).

predict(X)[source]

Predict using exponential kernel in omega-projected space.

Parameters:

X (array-like of shape (n_samples, n_features_flat))

Returns:

labels

Return type:

array of shape (n_samples,)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores the parameters that achieved the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a freshly initialized optimizer. Cannot be combined with initial_prototypes.

Return type:

self

class prosemble.models.RiemannianDKRSLNG(manifold, sigma_init='median', sigma_min=0.001, latent_dim=None, beta=10.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, lr_ratio=0.5, tau=0.95, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=False, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

Riemannian Differentiating Kernel Relevance SLNG.

Extends RiemannianSLNG with a relevance-weighted Gaussian kernel in the per-prototype omega-projected tangent space. Each prototype has its own metric matrix \(\Omega_k\) and bandwidth \(\sigma_k\), plus a shared relevance vector \(\lambda\):

\[d_\kappa^2(x, w_k) = 2\left(1 - \exp\left( -\frac{\sum_j \lambda_j (\Omega_k \cdot v)_j^2} {2\sigma_k^2} \right)\right)\]
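With local metrics, the distance to every prototype uses that prototype's own \(\Omega_k\) and \(\sigma_k\), while \(\lambda\) is shared. The sketch below vectorizes the formula over all K prototypes with numpy.einsum. It is a hypothetical helper, not prosemble's code: it assumes one tangent vector per prototype is already available (row k of V standing for v at w_k), that each \(\Omega_k\) has shape (latent_dim, d_flat), and that sigma_min is enforced by clipping, which the documentation does not specify.

```python
import numpy as np


def local_relevance_kernel_dists(V, omegas, relevance, sigmas, sigma_min=1e-3):
    """Hypothetical helper: d_kappa^2(x, w_k) for all K prototypes at once.

    V         : (K, d_flat) tangent vector of x at each prototype w_k (assumed given)
    omegas    : (K, latent_dim, d_flat) per-prototype matrices Omega_k
    relevance : (latent_dim,) shared relevance weights lambda
    sigmas    : (K,) per-prototype bandwidths sigma_k
    """
    sigmas = np.maximum(sigmas, sigma_min)       # assumed: sigma_min enforced by clipping
    U = np.einsum("kld,kd->kl", omegas, V)       # U[k] = Omega_k @ V[k]
    weighted_sq = (U ** 2) @ relevance           # (K,) sum_j lambda_j U[k, j]^2
    return 2.0 * (1.0 - np.exp(-weighted_sq / (2.0 * sigmas ** 2)))


rng = np.random.default_rng(1)
K, latent_dim, d_flat = 3, 4, 6
V = rng.normal(size=(K, d_flat))
omegas = rng.normal(size=(K, latent_dim, d_flat))
lam = np.full(latent_dim, 1.0 / latent_dim)
sigmas = np.array([0.5, 1.0, 2.0])
dists = local_relevance_kernel_dists(V, omegas, lam, sigmas)
winner = int(np.argmin(dists))   # closest prototype under the kernel distance
```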
Parameters:
  • manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance.

  • sigma_init (str or float) – Initialization strategy for per-prototype bandwidths. ‘median’ (default): per-class median of relevance-weighted distances. ‘mean’: per-class mean. float: fixed value for all prototypes.

  • sigma_min (float) – Lower bound for sigma. Default: 1e-3.

  • latent_dim (int, optional) – Projection dimensionality for local omegas. Default: d_flat.

  • beta (float) – Transfer function steepness.

  • gamma_init (float, optional) – Initial neighborhood range.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step decay factor.

  • tau (float) – Injectivity radius safety factor. Default: 0.95.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold.

  • random_seed (int) – Random seed.

  • optimizer (str or optax optimizer, optional) – Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function.

  • margin (float) – Margin for loss.

  • callbacks (list, optional) – Callback objects.

  • use_scan (bool) – Default: False.

  • batch_size (int, optional) – Mini-batch size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule.

  • lr_scheduler_kwargs (dict, optional) – LR scheduler kwargs.

  • prototypes_initializer (str or callable, optional) – Prototype initialization.

  • patience (int, optional) – Early stopping patience.

  • restore_best (bool) – Restore best parameters. Default: False.

  • class_weight (dict or 'balanced', optional) – Class weights.

  • gradient_accumulation_steps (int, optional) – Gradient accumulation.

  • ema_decay (float, optional) – EMA decay.

  • freeze_params (list of str, optional) – Frozen parameters.

  • lookahead (dict, optional) – Lookahead config.

  • mixed_precision (str or None, optional) – Mixed precision dtype.

See also

RiemannianSLNG

Base Riemannian localized matrix Neural Gas.

RiemannianDKSLNG

Gaussian kernel variant (no relevance weighting).

RiemannianDKRSMNG

Global omega + relevance kernel variant.

property kernel_bandwidths

Return the learned per-prototype bandwidths.

property relevance_profile

Return the learned relevance weights (normalized via softmax).

predict(X)[source]

Predict using relevance-weighted kernel in local-omega space.

Parameters:

X (array-like of shape (n_samples, n_features_flat))

Returns:

labels

Return type:

array of shape (n_samples,)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores the parameters that achieved the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a freshly initialized optimizer. Cannot be combined with initial_prototypes.

Return type:

self

class prosemble.models.RiemannianDKMSLNG(manifold, latent_dim=None, kernel_latent_dim=None, omega_hat_scale=0.1, beta=10.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, lr_ratio=0.5, tau=0.95, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=False, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

Riemannian Differentiating Kernel Matrix SLNG.

Extends RiemannianSLNG with an exponential kernel in the per-prototype omega-projected tangent space. Each prototype has its own metric matrix \(\Omega_k\), and a shared \(\hat\Lambda = \hat\Omega \hat\Omega^T\) provides further metric adaptation in the projected space:

\[d_\kappa^2(x, w_k) = \exp\left( (\Omega_k \cdot v)^T \hat\Lambda (\Omega_k \cdot v) \right) - 1\]
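Because \(u^T \hat\Omega \hat\Omega^T u = \lVert \hat\Omega^T u \rVert^2\), the shared kernel matrix never has to be formed explicitly. The sketch below uses that identity to evaluate the formula for all K prototypes at once. It is a hypothetical helper, not prosemble's code, and assumes per-prototype tangent vectors are given as rows of V, that each \(\Omega_k\) has shape (latent_dim, d_flat), and that \(\hat\Omega\) has shape (latent_dim, kernel_latent_dim).

```python
import numpy as np


def local_matrix_exp_kernel_dists(V, omegas, omega_hat):
    """Hypothetical helper: d_kappa^2(x, w_k) = exp(u_k^T Lambda_hat u_k) - 1 for all k.

    V         : (K, d_flat) tangent vector of x at each prototype (assumed given)
    omegas    : (K, latent_dim, d_flat) per-prototype Omega_k
    omega_hat : (latent_dim, kernel_latent_dim) shared kernel factor Omega_hat
    """
    U = np.einsum("kld,kd->kl", omegas, V)    # U[k] = Omega_k @ V[k]
    # u^T (Omega_hat Omega_hat^T) u == ||Omega_hat^T u||^2, so Lambda_hat is never formed
    Z = U @ omega_hat                         # (K, kernel_latent_dim)
    return np.expm1(np.sum(Z ** 2, axis=1))


rng = np.random.default_rng(2)
K, latent_dim, d_flat, kdim = 3, 4, 6, 2
V = rng.normal(size=(K, d_flat))
omegas = rng.normal(size=(K, latent_dim, d_flat))
omega_hat = 0.1 * rng.normal(size=(latent_dim, kdim))
dists = local_matrix_exp_kernel_dists(V, omegas, omega_hat)
```

Choosing kernel_latent_dim smaller than latent_dim, as in this example, makes \(\hat\Lambda\) low-rank, since its rank cannot exceed the number of columns of \(\hat\Omega\).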
Parameters:
  • manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance.

  • latent_dim (int, optional) – Projection dimensionality for local omegas. Default: d_flat.

  • kernel_latent_dim (int, optional) – Dimensionality for the kernel’s omega_hat. Default: same as omega’s output dimension (latent_dim).

  • omega_hat_scale (float) – Scale for omega_hat initialization. Default: 0.1.

  • beta (float) – Transfer function steepness.

  • gamma_init (float, optional) – Initial neighborhood range.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step decay factor.

  • tau (float) – Injectivity radius safety factor. Default: 0.95.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold.

  • random_seed (int) – Random seed.

  • optimizer (str or optax optimizer, optional) – Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function.

  • margin (float) – Margin for loss.

  • callbacks (list, optional) – Callback objects.

  • use_scan (bool) – Default: False.

  • batch_size (int, optional) – Mini-batch size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule.

  • lr_scheduler_kwargs (dict, optional) – LR scheduler kwargs.

  • prototypes_initializer (str or callable, optional) – Prototype initialization.

  • patience (int, optional) – Early stopping patience.

  • restore_best (bool) – Restore best parameters. Default: False.

  • class_weight (dict or 'balanced', optional) – Class weights.

  • gradient_accumulation_steps (int, optional) – Gradient accumulation.

  • ema_decay (float, optional) – EMA decay.

  • freeze_params (list of str, optional) – Frozen parameters.

  • lookahead (dict, optional) – Lookahead config.

  • mixed_precision (str or None, optional) – Mixed precision dtype.

See also

RiemannianSLNG

Base Riemannian localized matrix Neural Gas.

RiemannianDKSLNG

Gaussian kernel variant (no matrix kernel).

RiemannianDKRSLNG

Relevance kernel variant.

property omega_hat_matrix

Return the learned kernel \(\hat\Omega\) matrix.

property lambda_hat_matrix

Return \(\hat\Lambda = \hat\Omega \hat\Omega^T\).

predict(X)[source]

Predict using exponential kernel in local-omega-projected space.

Parameters:

X (array-like of shape (n_samples, n_features_flat))

Returns:

labels

Return type:

array of shape (n_samples,)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores the parameters that achieved the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a freshly initialized optimizer. Cannot be combined with initial_prototypes.

Return type:

self

class prosemble.models.RiemannianDKRSTNG(manifold, sigma_init='median', sigma_min=0.001, subspace_dim=None, beta=10.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, lr_ratio=0.5, tau=0.95, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=False, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

Riemannian Differentiating Kernel Relevance STNG.

Extends RiemannianSTNG with a relevance-weighted Gaussian kernel on the tangent subspace residual. Each prototype has an orthonormal subspace basis \(\Omega_k\) and a learnable bandwidth \(\sigma_k\), plus a shared relevance vector \(\lambda\) that weights features of the residual:

\[d_\kappa^2(x, w_k) = 2\left(1 - \exp\left( -\frac{\sum_j \lambda_j r_j^2}{2\sigma_k^2} \right)\right)\]

where \(r = (I - \Omega_k \Omega_k^T) \cdot v\).
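The residual \(r\) can be computed as \(v - \Omega_k(\Omega_k^T v)\), which avoids building the d_flat × d_flat projector. The sketch below shows this together with the relevance-weighted kernel. It is illustrative only and assumes that v is the flattened tangent vector of x at w_k, that \(\Omega_k\) has shape (d_flat, subspace_dim) with orthonormal columns, and that \(\lambda\) has one weight per residual feature (length d_flat). The helper names are hypothetical.

```python
import numpy as np


def subspace_residual(v, basis):
    # (I - Omega_k Omega_k^T) v, computed without forming the d_flat x d_flat projector.
    # basis: (d_flat, subspace_dim) with orthonormal columns (Omega_k).
    return v - basis @ (basis.T @ v)


def residual_relevance_kernel_dist(v, basis, relevance, sigma):
    """Hypothetical helper: relevance-weighted Gaussian kernel on the residual r."""
    r = subspace_residual(v, basis)
    weighted_sq = np.sum(relevance * r ** 2)      # sum_j lambda_j r_j^2
    return 2.0 * (1.0 - np.exp(-weighted_sq / (2.0 * sigma ** 2)))


rng = np.random.default_rng(3)
d_flat, subspace_dim = 6, 2
basis, _ = np.linalg.qr(rng.normal(size=(d_flat, subspace_dim)))  # orthonormal columns
lam = np.full(d_flat, 1.0 / d_flat)
v = rng.normal(size=d_flat)
r = subspace_residual(v, basis)
d = residual_relevance_kernel_dist(v, basis, lam, sigma=1.0)
```

A tangent vector lying entirely inside the prototype's subspace has a zero residual and therefore a kernel distance of zero.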

Parameters:
  • manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance.

  • sigma_init (str or float) – Initialization strategy for per-prototype bandwidths. ‘median’ (default): per-class median of relevance-weighted residuals. ‘mean’: per-class mean. float: fixed value for all prototypes.

  • sigma_min (float) – Lower bound for sigma. Default: 1e-3.

  • subspace_dim (int, optional) – Tangent subspace dimensionality. Default: d_flat - 1.

  • beta (float) – Transfer function steepness.

  • gamma_init (float, optional) – Initial neighborhood range.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step decay factor.

  • tau (float) – Injectivity radius safety factor. Default: 0.95.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold.

  • random_seed (int) – Random seed.

  • optimizer (str or optax optimizer, optional) – Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function.

  • margin (float) – Margin for loss.

  • callbacks (list, optional) – Callback objects.

  • use_scan (bool) – Default: False.

  • batch_size (int, optional) – Mini-batch size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule.

  • lr_scheduler_kwargs (dict, optional) – LR scheduler kwargs.

  • prototypes_initializer (str or callable, optional) – Prototype initialization.

  • patience (int, optional) – Early stopping patience.

  • restore_best (bool) – Restore best parameters. Default: False.

  • class_weight (dict or 'balanced', optional) – Class weights.

  • gradient_accumulation_steps (int, optional) – Gradient accumulation.

  • ema_decay (float, optional) – EMA decay.

  • freeze_params (list of str, optional) – Frozen parameters.

  • lookahead (dict, optional) – Lookahead config.

  • mixed_precision (str or None, optional) – Mixed precision dtype.

See also

RiemannianSTNG

Base Riemannian tangent Neural Gas.

RiemannianDKSTNG

Gaussian kernel variant (no relevance weighting).

property kernel_bandwidths

Return the learned per-prototype bandwidths.

property relevance_profile

Return the learned relevance weights (normalized via softmax).

predict(X)[source]

Predict using relevance-weighted kernel on subspace residual.

Parameters:

X (array-like of shape (n_samples, n_features_flat))

Returns:

labels

Return type:

array of shape (n_samples,)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores the parameters that achieved the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a freshly initialized optimizer. Cannot be combined with initial_prototypes.

Return type:

self

class prosemble.models.RiemannianDKMSTNG(manifold, kernel_latent_dim=None, omega_hat_scale=0.1, subspace_dim=None, beta=10.0, gamma_init=None, gamma_final=0.01, gamma_decay=None, lr_ratio=0.5, tau=0.95, n_prototypes_per_class=1, max_iter=100, lr=0.01, epsilon=1e-06, random_seed=42, optimizer='adam', transfer_fn=None, margin=0.0, callbacks=None, use_scan=False, batch_size=None, lr_scheduler=None, lr_scheduler_kwargs=None, prototypes_initializer=None, patience=None, restore_best=False, class_weight=None, gradient_accumulation_steps=None, ema_decay=None, freeze_params=None, lookahead=None, mixed_precision=None)[source]

Riemannian Differentiating Kernel Matrix STNG.

Extends RiemannianSTNG with an exponential kernel on the tangent subspace residual. Each prototype has an orthonormal subspace basis \(\Omega_k\), and a shared \(\hat\Lambda = \hat\Omega \hat\Omega^T\) provides metric adaptation on the residual:

\[d_\kappa^2(x, w_k) = \exp\left( r^T \hat\Lambda r \right) - 1\]

where \(r = (I - \Omega_k \Omega_k^T) \cdot v\).
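With the default subspace_dim = d_flat - 1, the residual is confined to the single direction orthogonal to the prototype's subspace. The sketch below illustrates this and evaluates the exponential kernel through \(r^T \hat\Lambda r = \lVert \hat\Omega^T r \rVert^2\). It is a hypothetical helper, not prosemble's code, and assumes that \(\Omega_k\) has orthonormal columns of shape (d_flat, subspace_dim), that \(\hat\Omega\) has shape (d_flat, kernel_latent_dim), and that omega_hat_scale multiplies a random initialization.

```python
import numpy as np


def residual_matrix_exp_kernel_dist(v, basis, omega_hat):
    """Hypothetical helper: d_kappa^2 = exp(r^T Lambda_hat r) - 1 with r = (I - Omega_k Omega_k^T) v.

    basis     : (d_flat, subspace_dim) orthonormal columns Omega_k
    omega_hat : (d_flat, kernel_latent_dim), Lambda_hat = omega_hat @ omega_hat.T
    """
    r = v - basis @ (basis.T @ v)
    z = omega_hat.T @ r               # r^T Lambda_hat r == ||omega_hat^T r||^2
    return np.expm1(z @ z)


rng = np.random.default_rng(4)
d_flat = 5
subspace_dim = d_flat - 1                                  # the documented default
q, _ = np.linalg.qr(rng.normal(size=(d_flat, d_flat)))     # square orthogonal matrix
basis, normal = q[:, :subspace_dim], q[:, subspace_dim]    # normal spans the orthogonal complement
omega_hat = 0.1 * rng.normal(size=(d_flat, d_flat))        # kernel_latent_dim = d_flat (default)
v = rng.normal(size=d_flat)
d = residual_matrix_exp_kernel_dist(v, basis, omega_hat)
```

Here the residual equals \((n^T v)\, n\) for the unit normal \(n\), so the kernel depends on v only through that one scalar projection.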

Parameters:
  • manifold (SO, SPD, or Grassmannian) – Riemannian manifold instance.

  • kernel_latent_dim (int, optional) – Dimensionality for the kernel’s omega_hat. Default: d_flat.

  • omega_hat_scale (float) – Scale for omega_hat initialization. Default: 0.1.

  • subspace_dim (int, optional) – Tangent subspace dimensionality. Default: d_flat - 1.

  • beta (float) – Transfer function steepness.

  • gamma_init (float, optional) – Initial neighborhood range.

  • gamma_final (float) – Final neighborhood range. Default: 0.01.

  • gamma_decay (float, optional) – Per-step decay factor.

  • tau (float) – Injectivity radius safety factor. Default: 0.95.

  • n_prototypes_per_class (int) – Number of prototypes per class.

  • max_iter (int) – Maximum training iterations.

  • lr (float) – Learning rate.

  • epsilon (float) – Convergence threshold.

  • random_seed (int) – Random seed.

  • optimizer (str or optax optimizer, optional) – Default: ‘adam’.

  • transfer_fn (callable, optional) – Transfer function.

  • margin (float) – Margin for loss.

  • callbacks (list, optional) – Callback objects.

  • use_scan (bool) – Default: False.

  • batch_size (int, optional) – Mini-batch size.

  • lr_scheduler (str or optax.Schedule, optional) – Learning rate schedule.

  • lr_scheduler_kwargs (dict, optional) – LR scheduler kwargs.

  • prototypes_initializer (str or callable, optional) – Prototype initialization.

  • patience (int, optional) – Early stopping patience.

  • restore_best (bool) – Restore best parameters. Default: False.

  • class_weight (dict or 'balanced', optional) – Class weights.

  • gradient_accumulation_steps (int, optional) – Gradient accumulation.

  • ema_decay (float, optional) – EMA decay.

  • freeze_params (list of str, optional) – Frozen parameters.

  • lookahead (dict, optional) – Lookahead config.

  • mixed_precision (str or None, optional) – Mixed precision dtype.

See also

RiemannianSTNG

Base Riemannian tangent Neural Gas.

RiemannianDKSTNG

Gaussian kernel variant (no matrix kernel).

RiemannianDKRSTNG

Relevance kernel variant.

property omega_hat_matrix

Return the learned kernel \(\hat\Omega\) matrix.

property lambda_hat_matrix

Return \(\hat\Lambda = \hat\Omega \hat\Omega^T\).

predict(X)[source]

Predict using exponential kernel on subspace residual.

Parameters:

X (array-like of shape (n_samples, n_features_flat))

Returns:

labels

Return type:

array of shape (n_samples,)

fit(X, y, initial_prototypes=None, initial_labels=None, validation_data=None, sample_weight=None, resume=False)

Fit the model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data.

  • y (array-like of shape (n_samples,)) – Target labels.

  • initial_prototypes (array-like, optional) – Initial prototype positions, e.g. for warm-starting from another model.

  • initial_labels (array-like, optional) – Labels for initial_prototypes. Required when the number of initial_prototypes differs from the number implied by n_prototypes_per_class.

  • validation_data (tuple of (X_val, y_val), optional) – Validation data for monitoring. When provided with restore_best=True, the model restores the parameters that achieved the lowest validation loss.

  • sample_weight (array-like of shape (n_samples,), optional) – Per-sample weights for the loss function.

  • resume (bool, default=False) – If True, resume training from the model’s current fitted state, using the current prototypes (and other fitted parameters) as the starting point with a freshly initialized optimizer. Cannot be combined with initial_prototypes.

Return type:

self

Utility Models

class prosemble.models.KMeansPlusPlus(n_clusters=3, max_iter=100, epsilon=1e-05, random_seed=None, plot_steps=False)[source]

K-means++ clustering with JAX (Hard C-Means with K-means++ seeding).

K-means++ is a seeding procedure that chooses initial cluster centers spread out across the data, which typically yields faster convergence and better final solutions than uniformly random initialization.

Algorithm:

  1. Choose first center uniformly at random from data points

  2. For each data point x, compute D(x), the distance to the nearest center chosen so far

  3. Choose next center with probability proportional to D(x)²

  4. Repeat steps 2 and 3 until n_clusters centers are chosen

  5. Run standard k-means (HCM) with these initial centers
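The five steps above can be sketched in plain NumPy as follows. This is an illustrative re-implementation, not prosemble's code: distances are squared Euclidean, the function names are hypothetical, and treating epsilon as a bound on the largest center shift between iterations is an assumption.

```python
import numpy as np


def kmeans_plus_plus_init(X, n_clusters, seed=None):
    """Steps 1-4: choose n_clusters initial centers by D(x)^2 sampling."""
    rng = np.random.default_rng(seed)
    n = X.shape[0]
    centers = [X[rng.integers(n)]]                        # step 1: uniform first center
    while len(centers) < n_clusters:
        C = np.asarray(centers)
        # step 2: squared distance D(x)^2 from each point to its nearest chosen center
        d2 = ((X[:, None, :] - C[None, :, :]) ** 2).sum(-1).min(axis=1)
        total = d2.sum()
        # step 3: sample proportionally to D(x)^2 (uniformly if every point is already a center)
        p = d2 / total if total > 0 else np.full(n, 1.0 / n)
        centers.append(X[rng.choice(n, p=p)])
    return np.asarray(centers)


def lloyd(X, centers, max_iter=100, epsilon=1e-5):
    """Step 5: standard k-means (HCM) iterations from the given centers."""
    for _ in range(max_iter):
        labels = ((X[:, None, :] - centers[None]) ** 2).sum(-1).argmin(axis=1)
        new_centers = np.array([
            X[labels == k].mean(axis=0) if np.any(labels == k) else centers[k]
            for k in range(len(centers))
        ])
        shift = np.abs(new_centers - centers).max()
        centers = new_centers
        if shift < epsilon:
            break
    labels = ((X[:, None, :] - centers[None]) ** 2).sum(-1).argmin(axis=1)
    return centers, labels


rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0.0, 0.1, size=(20, 2)),
               rng.normal(10.0, 0.1, size=(20, 2))])
centers, labels = lloyd(X, kmeans_plus_plus_init(X, n_clusters=2, seed=0))
```

Sampling by D(x)² makes far-away points much more likely to become the next center, so on two well-separated blobs the seeds almost always land in different blobs.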

Parameters:
  • n_clusters (int, default=3) – Number of clusters

  • max_iter (int, default=100) – Maximum number of iterations

  • epsilon (float, default=1e-5) – Convergence tolerance

  • random_seed (int, optional) – Random seed for reproducibility

  • plot_steps (bool, default=False) – Whether to enable visualization (requires LiveVisualizer)

centroids_

Cluster centers

Type:

array of shape (n_clusters, n_features)

labels_

Labels of each point

Type:

array of shape (n_samples,)

n_iter_

Number of iterations run

Type:

int

objective_

Final objective function value

Type:

float

See also

HCM

Hard C-Means model used internally after initialization.

fit(X)[source]

Fit K-means++ model to data.

Parameters:

X (array of shape (n_samples, n_features)) – Training data

Returns:

self – Fitted estimator

Return type:

object

predict(X)[source]

Predict cluster labels for samples.

Parameters:

X (array of shape (n_samples, n_features)) – New data to predict

Returns:

labels – Index of the cluster each sample belongs to

Return type:

array of shape (n_samples,)

class prosemble.models.KNN(n_neighbors=5)[source]

K-Nearest Neighbors (KNN) with JAX

KNN is a simple, non-parametric classifier that predicts based on the k nearest training samples.

Algorithm:

  1. For each test sample, compute the distances to all training samples.

  2. Find the n_neighbors nearest training samples.

  3. Predict the label as the mode (most common label) among those neighbors.

  4. Compute the probability as the fraction of those neighbors that carry the predicted label.
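A minimal NumPy sketch of these four steps, for illustration only. The use of squared Euclidean distance and the tie-breaking rule (the smallest label wins, because numpy.unique returns sorted values) are assumptions; prosemble's KNN may differ in both. The function name is hypothetical.

```python
import numpy as np


def knn_predict(X_train, y_train, X_test, n_neighbors=5):
    """Return (labels, probabilities) following the four steps above."""
    # Step 1: squared Euclidean distances, shape (n_test, n_train)
    d2 = ((X_test[:, None, :] - X_train[None, :, :]) ** 2).sum(-1)
    # Step 2: indices of the n_neighbors closest training samples
    nn = np.argsort(d2, axis=1)[:, :n_neighbors]
    labels, probs = [], []
    for row in y_train[nn]:
        values, counts = np.unique(row, return_counts=True)
        best = np.argmax(counts)                   # step 3: mode; ties go to the smallest label
        labels.append(values[best])
        probs.append(counts[best] / n_neighbors)   # step 4: frequency of the predicted label
    return np.array(labels), np.array(probs)


X_train = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1], [5.0, 5.0], [5.1, 5.0]])
y_train = np.array([0, 0, 0, 1, 1])
labels, probs = knn_predict(X_train, y_train, np.array([[0.05, 0.05], [4.9, 5.0]]), n_neighbors=3)
```

In this example the second test point has only two class-1 training samples nearby, so its third neighbor comes from class 0 and its predicted label carries probability 2/3.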

Parameters:

n_neighbors (int, default=5) – Number of neighbors to use

fit(X, y)[source]

Fit KNN model (store training data)

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data

  • y (array-like of shape (n_samples,)) – Target labels

Returns:

self – Fitted estimator

Return type:

object

predict(X)[source]

Predict class labels for samples

Parameters:

X (array-like of shape (n_samples, n_features)) – Data to predict

Returns:

labels – Predicted class labels

Return type:

array of shape (n_samples,)

class prosemble.models.NPC(n_classes=3, max_iter=10, tol=0.8, random_state=None)[source]

Noise Possibilistic C-Means (NPC) with JAX

NPC is a supervised prototype-based classifier that iteratively updates its prototypes until the training accuracy reaches a threshold. Class probabilities are estimated with a softmin over the prototype distances.

Algorithm:

  1. Initialize prototypes (one per class)

  2. Predict labels using nearest prototype

  3. Compute the training accuracy

  4. If the accuracy reaches the threshold tol, or max_iter is reached, stop

  5. Otherwise, recompute the prototypes and repeat from step 2

Softmin function: \(\text{softmin}(x_i) = \exp(-x_i) / \sum_j \exp(-x_j)\)
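
Steps 1 to 4 and the softmin can be sketched in plain Python as follows. Several details are assumptions rather than documented behavior: prototypes start at the class means, distances are squared Euclidean, and the helper names are hypothetical. The prototype-recomputation rule of step 5 is not described above, so it is left out.

```python
import math


def class_means(X, y, n_classes):
    """Hypothetical step 1: one prototype per class, placed at the class mean."""
    protos = []
    for c in range(n_classes):
        members = [x for x, label in zip(X, y) if label == c]
        protos.append([sum(col) / len(members) for col in zip(*members)])
    return protos


def softmin(values):
    """softmin(x_i) = exp(-x_i) / sum_j exp(-x_j), shifted by min(values) for stability."""
    m = min(values)
    exps = [math.exp(-(v - m)) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def npc_predict_proba(X, protos):
    """Softmin over squared Euclidean distances to each class prototype."""
    return [softmin([sum((a - b) ** 2 for a, b in zip(x, p)) for p in protos]) for x in X]


def npc_predict(X, protos):
    """Step 2: the nearest prototype is the class with the largest softmin probability."""
    return [max(range(len(p)), key=p.__getitem__) for p in npc_predict_proba(X, protos)]


def accuracy(y_true, y_pred):
    """Step 3: fraction of correctly classified samples, compared against tol in step 4."""
    return sum(a == b for a, b in zip(y_true, y_pred)) / len(y_true)
```

Shifting by the minimum distance leaves the softmin unchanged, since the common factor cancels between numerator and denominator, but it avoids underflow when all distances are large.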

Parameters:
  • n_classes (int, default=3) – Number of classes

  • max_iter (int, default=10) – Maximum optimization steps

  • tol (float, default=0.8) – Accuracy threshold for convergence

  • random_state (int, optional) – Random seed for reproducibility

fit(X, y)[source]

Fit NPC model to labeled data

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data

  • y (array-like of shape (n_samples,)) – Target labels (integers from 0 to n_classes-1)

Returns:

self – Fitted estimator

Return type:

object

predict(X)[source]

Predict class labels for samples

Parameters:

X (array-like of shape (n_samples, n_features)) – Data to predict

Returns:

labels – Predicted class labels

Return type:

array of shape (n_samples,)

prosemble.models.Kmeans

alias of KMeansPlusPlus

class prosemble.models.SOM(grid_size=None, max_iter=None, learning_rate=0.5, sigma=1.0, random_state=None)[source]

Self-Organizing Maps (SOM) with JAX

SOM is an unsupervised learning algorithm that creates a low-dimensional (typically 2D) representation of high-dimensional data while preserving topological relationships.

Algorithm:

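  1. Initialize a grid of neurons with random weight vectors

  2. Select a random sample from the data

  3. Find its Best Matching Unit (BMU), the neuron whose weight vector is closest to the sample

  4. Move the BMU and its grid neighbors toward the sample

  5. Decay the learning rate and the neighborhood radius

  6. Repeat steps 2 to 5 for each training iteration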
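
The training loop can be sketched in plain Python as an online SOM. This is illustrative only, not prosemble's JAX code. The Gaussian neighborhood, the decay factor `exp(-t / max_iter)` applied to both `learning_rate` and `sigma`, the random initialization in [0, 1), and the helper names `find_bmu` and `train_som` are all assumptions.

```python
import math
import random


def find_bmu(W, x):
    """Grid coordinates of the neuron whose weight vector is closest to x."""
    return min(W, key=lambda n: math.dist(W[n], x))


def train_som(X, grid_size, max_iter, learning_rate=0.5, sigma=1.0, seed=0):
    rng = random.Random(seed)
    n_features = len(X[0])
    # grid_size x grid_size neurons with random weights in [0, 1)
    W = {(r, c): [rng.random() for _ in range(n_features)]
         for r in range(grid_size) for c in range(grid_size)}
    for t in range(max_iter):
        x = rng.choice(X)       # random training sample
        bmu = find_bmu(W, x)    # Best Matching Unit
        # assumed exponential decay of learning rate and neighborhood radius
        decay = math.exp(-t / max_iter)
        lr_t, sigma_t = learning_rate * decay, sigma * decay
        for n in list(W):       # pull the BMU and its grid neighbors toward x
            grid_d2 = (n[0] - bmu[0]) ** 2 + (n[1] - bmu[1]) ** 2
            h = math.exp(-grid_d2 / (2 * sigma_t ** 2))  # Gaussian neighborhood on the grid
            W[n] = [w + lr_t * h * (xi - w) for w, xi in zip(W[n], x)]
    return W
```

The neighborhood is measured in grid coordinates rather than in feature space, which is what makes neurons that are close on the map end up with similar weight vectors.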

Parameters:
  • grid_size (int, optional) – Size of the SOM grid (grid_size x grid_size). If None, computed as int(sqrt(5 * sqrt(n_samples)))

  • max_iter (int, optional) – Number of training iterations. If None, set to 500 * grid_size**2

  • learning_rate (float, default=0.5) – Initial learning rate

  • sigma (float, default=1.0) – Initial neighborhood radius

  • random_state (int, optional) – Random seed for reproducibility
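
The two None defaults can be checked by hand. The sketch below uses a hypothetical `som_defaults` helper and assumes the library truncates with int() exactly as written. For 150 samples it gives a 7 x 7 grid (sqrt(5 * sqrt(150)) is about 7.83, truncated to 7) and 500 * 49 = 24500 iterations. The grid rule appears to follow the common heuristic of roughly 5 * sqrt(n_samples) neurons in total.

```python
import math


def som_defaults(n_samples, grid_size=None, max_iter=None):
    """Evaluate the documented fallbacks for grid_size and max_iter."""
    if grid_size is None:
        # side length of a square grid holding about 5 * sqrt(n_samples) neurons
        grid_size = int(math.sqrt(5 * math.sqrt(n_samples)))
    if max_iter is None:
        # 500 iterations per neuron
        max_iter = 500 * grid_size ** 2
    return grid_size, max_iter
```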

fit(X)[source]

Fit SOM model to data

Parameters:

X (array-like of shape (n_samples, n_features)) – Training data

Returns:

self – Fitted estimator

Return type:

object

predict(X)[source]

Predict labels for samples (requires fitted label map)

Parameters:

X (array-like of shape (n_samples, n_features)) – Data to predict

Returns:

labels – Predicted labels

Return type:

array of shape (n_samples,)

class prosemble.models.BGPC(n_clusters=3, max_iter=100, tol=0.0001, alpha_init=1.0, beta_init=0.1, beta_final=10.0, init='fcm', random_state=None)[source]

Bayesian Graded Possibilistic C-Means (BGPC) with JAX

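BGPC uses exponential membership weighting with alpha and beta parameters that change over the iterations according to fixed schedules (beta moves from beta_init to beta_final).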

Algorithm:

  1. Compute membership weights using exponential distance

  2. Normalize memberships using partition function Z

  3. Update centroids as weighted mean of data

  4. Update alpha and beta according to their schedules

  5. Repeat until convergence
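
The membership rule is not spelled out above, so the sketch below assumes the graded possibilistic form: v = exp(-beta * d), Z = sum of v over clusters, u = v / Z**alpha. With alpha = 1 the memberships of each point sum to 1 (probabilistic), and with alpha = 0 they are unnormalized (possibilistic). The geometric ramp of beta, the linear decay of alpha, the use of explicit initial centroids instead of the `init` option, and the helper names are all hypothetical, not prosemble API.

```python
import math


def bgpc_step(X, C, alpha, beta):
    """One iteration under an assumed graded-possibilistic membership rule."""
    U = []
    for x in X:
        d = [sum((a - b) ** 2 for a, b in zip(x, c)) for c in C]
        logv = [-beta * dj for dj in d]                               # 1. log of exp(-beta * d)
        m = max(logv)
        log_z = m + math.log(sum(math.exp(lv - m) for lv in logv))  # log Z, Z = sum_j v_j
        U.append([math.exp(lv - alpha * log_z) for lv in logv])     # 2. u_j = v_j / Z**alpha
    new_C = []
    for j, c in enumerate(C):                                        # 3. weighted mean of data
        w = [u[j] for u in U]
        total = sum(w)
        if total == 0.0:                                             # all weights underflowed
            new_C.append(list(c))
        else:
            new_C.append([sum(wi * x[f] for wi, x in zip(w, X)) / total
                          for f in range(len(X[0]))])
    return new_C, U


def bgpc(X, C, max_iter=100, tol=1e-4, alpha_init=1.0, beta_init=0.1, beta_final=10.0):
    for t in range(max_iter):
        frac = t / max(max_iter - 1, 1)
        beta = beta_init * (beta_final / beta_init) ** frac  # 4. assumed geometric ramp
        alpha = alpha_init * (1.0 - frac)                    # 4. assumed linear decay
        new_C, U = bgpc_step(X, C, alpha, beta)
        shift = max(math.dist(a, b) for a, b in zip(C, new_C))
        C = new_C
        if shift < tol:                                      # 5. convergence check
            break
    return C, U


def bgpc_predict(X, C):
    """Largest membership; under this rule that is also the nearest centroid."""
    return [min(range(len(C)), key=lambda j: math.dist(x, C[j])) for x in X]
```

For a fixed point, every membership is proportional to exp(-beta * d_j), so the cluster with the largest membership is always the nearest centroid. This is why the predict helper can skip recomputing memberships.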

Parameters:
  • n_clusters (int, default=3) – Number of clusters

  • max_iter (int, default=100) – Maximum number of iterations

  • tol (float, default=1e-4) – Convergence tolerance

  • alpha_init (float, default=1.0) – Initial alpha parameter

  • beta_init (float, default=0.1) – Initial beta parameter (starting value of its schedule)

  • beta_final (float, default=10.0) – Final beta parameter (value reached at the end of its schedule)

  • init (str, default='fcm') – Initialization method: ‘random’, ‘fcm’, or ‘kmeans++’

  • random_state (int, optional) – Random seed for reproducibility

fit(X)[source]

Fit BGPC model to data

Parameters:

X (array-like of shape (n_samples, n_features)) – Training data

Returns:

self – Fitted estimator

Return type:

object

predict(X)[source]

Predict cluster labels for samples

Parameters:

X (array-like of shape (n_samples, n_features)) – Data to predict

Returns:

labels – Cluster labels

Return type:

array of shape (n_samples,)