Advanced Features
Prosemble leverages JAX for GPU acceleration, JIT compilation, and functional training loops. This guide covers all advanced features available for supervised prototype models and fuzzy clustering models.
JIT-Compiled Inference
Model inference (predict, predict_proba) is automatically JIT-compiled
for speed. The first call triggers compilation; subsequent calls with inputs
of the same shape and dtype reuse the cached compilation (a new input shape
or dtype triggers a fresh compilation).
from prosemble.models import GLVQ
model = GLVQ(n_prototypes_per_class=2, max_iter=100, lr=0.01)
model.fit(X_train, y_train)
# First call: triggers JIT compilation
preds = model.predict(X_test)
# Subsequent calls: use cached compiled function (fast)
preds = model.predict(X_test)
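The caching behaviour comes from jax.jit itself. The sketch below is not prosemble's predict; it uses a hypothetical nearest-prototype function, nearest_prototype_predict, with a Python counter that only increments while JAX traces the function, so you can see exactly when compilation happens and when the cache is reused.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces (and then compiles) the function


@jax.jit
def nearest_prototype_predict(X, prototypes, proto_labels):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    # Squared Euclidean distances, shape (n_samples, n_prototypes)
    d = jnp.sum((X[:, None, :] - prototypes[None, :, :]) ** 2, axis=-1)
    return proto_labels[jnp.argmin(d, axis=1)]


prototypes = jnp.array([[0.0, 0.0], [1.0, 1.0]])
proto_labels = jnp.array([0, 1])
X = jnp.array([[0.1, 0.2], [0.9, 0.8], [0.2, 0.0]])

preds = nearest_prototype_predict(X, prototypes, proto_labels)  # traces + compiles
preds = nearest_prototype_predict(X, prototypes, proto_labels)  # reuses the cache
print(preds, trace_count)  # [0 1 0] 1

# A different batch size is a new input shape, so JAX traces again
nearest_prototype_predict(jnp.concatenate([X, X]), prototypes, proto_labels)
print(trace_count)  # 2
```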
Training Loop Modes
lax.scan vs Python Loop
By default, supervised models use jax.lax.scan, which compiles the
entire training loop into a single XLA program. This eliminates per-iteration
Python overhead and is usually the fastest option.
Set use_scan=False to switch to a Python loop, which is required
when using callbacks, patience-based early stopping, or validation
monitoring.
from prosemble.models import GLVQ
# Default: lax.scan (fastest, no callbacks)
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    use_scan=True,  # default
)
model.fit(X_train, y_train)
# Python loop: needed for callbacks, patience, validation
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    use_scan=False,
)
model.fit(X_train, y_train)
Models that always use Python loops regardless of use_scan:
Growing Neural Gas (dynamic topology)
Riemannian Neural Gas (manifold operations)
Median LVQ (combinatorial M-step)
Fuzzy clustering models use lax.scan by default and fall back to a
Python loop when callbacks, patience, or restore_best are enabled.
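To make the difference concrete, here is the same toy gradient-descent loop (a least-squares model, not an LVQ model) written both ways. The lax.scan version is traced once and compiled into one program; the Python version dispatches one step at a time, which is what leaves room for callbacks or early-stopping checks between iterations. Both produce the same losses.

```python
import jax
import jax.numpy as jnp

X = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = jnp.array([1.0, 2.0, 3.0])
lr, n_iter = 0.1, 100


def loss_fn(w):
    return jnp.mean((X @ w - y) ** 2)


@jax.jit
def step(w, _):
    # One gradient-descent update; returns (new carry, per-step output)
    w = w - lr * jax.grad(loss_fn)(w)
    return w, loss_fn(w)


@jax.jit
def train_scan(w0):
    # The whole loop is traced once and compiled into a single XLA program
    return jax.lax.scan(step, w0, xs=None, length=n_iter)


w_scan, scan_losses = train_scan(jnp.zeros(2))

# Python loop: one dispatch per iteration, so Python code can run in between
w_loop, loop_losses = jnp.zeros(2), []
for it in range(n_iter):
    w_loop, loss = step(w_loop, None)
    loop_losses.append(loss)  # a callback or patience check could run here
```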
Mini-Batch Training
For large datasets, supervised models support mini-batch training. Each iteration samples a random batch from the training data:
from prosemble.models import GLVQ
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=200,
    lr=0.01,
    batch_size=64,
)
model.fit(X_large, y_large)
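A minimal sketch of the sampling step, assuming batches are drawn uniformly without replacement and a fresh PRNG key is split off per iteration; prosemble's exact scheme (for example, epoch-based shuffling) is not specified here, and the helper name sample_batch is hypothetical.

```python
import jax
import jax.numpy as jnp


def sample_batch(key, X, y, batch_size):
    """Draw batch_size distinct rows uniformly at random (hypothetical helper)."""
    idx = jax.random.choice(key, X.shape[0], shape=(batch_size,), replace=False)
    return X[idx], y[idx]


X = jnp.arange(20.0).reshape(10, 2)  # row i is [2i, 2i + 1]
y = jnp.arange(10)

key = jax.random.PRNGKey(0)
for it in range(3):
    key, subkey = jax.random.split(key)  # fresh randomness each iteration
    Xb, yb = sample_batch(subkey, X, y, batch_size=4)
    # ... compute the loss and gradient on (Xb, yb) only ...
```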
Online / Incremental Learning
Train on streaming data or update an existing model with new batches
using partial_fit(). The model must be fitted first via fit(),
then each partial_fit() call performs a single gradient update while
preserving optimizer state across calls.
from prosemble.models import GLVQ
model = GLVQ(n_prototypes_per_class=2, max_iter=100, lr=0.01)
# Initial training
model.fit(X_train, y_train)
# Incremental updates with new data batches
for X_batch, y_batch in data_stream:
    model.partial_fit(X_batch, y_batch)
# Prototypes and optimizer state are preserved across calls
preds = model.predict(X_test)
All supervised models support partial_fit().
Sample Weighting and Class Balancing
Per-Sample Weights
Pass sample_weight to fit() to assign different importance to
individual training samples. Samples with higher weights have more
influence on the loss:
import jax.numpy as jnp
from prosemble.models import GLVQ
model = GLVQ(n_prototypes_per_class=2, max_iter=100, lr=0.01)
# Per-sample weights (e.g., higher weight for difficult samples)
weights = jnp.array([1.0, 1.0, 2.0, 1.0, 3.0, ...])
model.fit(X_train, y_train, sample_weight=weights)
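One common way to apply per-sample weights is a weighted mean of the per-sample losses, normalized by the total weight. The sketch assumes that normalization (prosemble may normalize differently, for example by the number of samples), and weighted_mean_loss is a hypothetical helper.

```python
import jax.numpy as jnp


def weighted_mean_loss(per_sample_loss, sample_weight):
    # Weighted average: samples with larger weights pull harder on the gradient
    return jnp.sum(sample_weight * per_sample_loss) / jnp.sum(sample_weight)


losses = jnp.array([0.2, 0.4, 1.0])
print(weighted_mean_loss(losses, jnp.ones(3)))                 # plain mean, about 0.533
print(weighted_mean_loss(losses, jnp.array([1.0, 1.0, 3.0])))  # about 0.72
```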
Class Balancing
For imbalanced datasets, use class_weight to automatically compute
per-sample weights inversely proportional to class frequency:
from prosemble.models import GLVQ
# Automatic balanced weighting
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    class_weight='balanced',
)
model.fit(X_train, y_train)
# Or specify weights per class manually
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    class_weight={0: 1.0, 1: 5.0, 2: 1.0},
)
model.fit(X_train, y_train)
With class_weight='balanced', the weight for class c is computed as:
\[ w_c = \frac{N}{K \, N_c} \]
where \(N\) is the total number of samples, \(K\) is the number of classes, and \(N_c\) is the number of samples in class \(c\).
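The balanced rule, w_c = N / (K * N_c), can be sketched in a few lines of JAX; balanced_sample_weights is a hypothetical helper that expands the per-class weights into one weight per sample.

```python
import jax.numpy as jnp


def balanced_sample_weights(y):
    # w_c = N / (K * N_c), then each sample gets the weight of its class
    classes, counts = jnp.unique(y, return_counts=True)
    n_samples, n_classes = y.shape[0], classes.shape[0]
    class_weight = n_samples / (n_classes * counts)
    return class_weight[jnp.searchsorted(classes, y)]


y = jnp.array([0, 0, 0, 1])
w = balanced_sample_weights(y)
# class 0: 4 / (2 * 3) = 0.667 per sample, class 1: 4 / (2 * 1) = 2.0
```

With these weights the total weight still equals N (summing N_c * N / (K * N_c) over the K classes gives N), so the overall loss scale is unchanged; only the balance between classes shifts.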
Validation and Early Stopping
Validation Monitoring
Pass validation data to fit() to monitor performance on held-out
samples during training. Combined with restore_best=True, the model
automatically restores the parameters that achieved the lowest
validation loss:
from prosemble.models import GLVQ
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=200,
    lr=0.01,
    restore_best=True,
    use_scan=False,  # required for validation monitoring
)
model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
)
# model.prototypes_ now holds the best params from training
print(f"Best validation loss: {model.best_loss_}")
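A minimal sketch of the bookkeeping behind restore_best, assuming "best" means the strictly lowest validation loss seen so far. BestParamsTracker is hypothetical, and the parameter values are stand-ins for real training steps.

```python
import jax.numpy as jnp


class BestParamsTracker:
    """Hypothetical helper: remember the parameters with the lowest validation loss."""

    def __init__(self):
        self.best_loss = float("inf")
        self.best_params = None

    def update(self, params, val_loss):
        if val_loss < self.best_loss:
            self.best_loss = float(val_loss)
            # JAX arrays are immutable, so keeping a reference is a safe snapshot
            self.best_params = params
            return True
        return False


tracker = BestParamsTracker()
for it, val_loss in enumerate([0.9, 0.5, 0.6, 0.4, 0.45]):
    params = {"prototypes": jnp.full((2, 2), float(it))}  # stand-in for real training
    tracker.update(params, val_loss)
# restore_best=True then amounts to: params = tracker.best_params
```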
Patience-Based Early Stopping
Stop training early when no improvement is observed for a given number of consecutive iterations:
from prosemble.models import GLVQ
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=500,
    lr=0.01,
    patience=20,
    restore_best=True,
    use_scan=False,
)
model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
)
print(f"Stopped at iteration {model.n_iter_}")
Fuzzy clustering models also support patience-based early stopping:
from prosemble.models import FCM
model = FCM(
    n_clusters=3,
    max_iter=500,
    patience=10,
    restore_best=True,
)
model.fit(X)
print(f"Stopped at iteration {model.n_iter_}")
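The patience rule itself is a small counter. This sketch is hypothetical: it assumes an improvement is any strict decrease of the monitored value (validation loss for supervised models; presumably the clustering objective for FCM) and that the counter resets on every improvement.

```python
def early_stopping_iteration(losses, patience):
    """Index of the iteration at which training stops (hypothetical helper)."""
    best, wait = float("inf"), 0
    for it, loss in enumerate(losses):
        if loss < best:
            best, wait = loss, 0  # improvement: reset the counter
        else:
            wait += 1
            if wait >= patience:
                return it
    return len(losses) - 1  # never triggered: ran to max_iter


val_losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.9, 0.5]
print(early_stopping_iteration(val_losses, patience=3))  # 5
```

The final 0.5 is never reached with patience=3: a small patience can stop before a late improvement, which is one reason to pair it with restore_best.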
Optimizer Configuration
Prosemble supports 26 optimizers via optax. Pass a string name or a
pre-built optax.GradientTransformation to the optimizer parameter:
from prosemble.models import GLVQ
# String name (26 options)
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    optimizer='adamw',
)
model.fit(X_train, y_train)
Available optimizer strings: adam, adamw, adamax, adamaxw,
adan, adabelief, amsgrad, radam, lamb, lion,
novograd, sgd, sign_sgd, signum, noisy_sgd, lars,
rmsprop, adagrad, adadelta, adafactor, sm3, yogi,
rprop, fromage, lbfgs, dpsgd.
You can also pass a custom optax optimizer directly:
import optax
custom_opt = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate=0.001),
)
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    optimizer=custom_opt,
)
model.fit(X_train, y_train)
Learning Rate Scheduling
Decay or warm up the learning rate during training using the
lr_scheduler parameter. Nine built-in schedules are available:
from prosemble.models import GLVQ
# Cosine decay: lr decays smoothly from 0.01 to 0 over training
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=200,
    lr=0.01,
    lr_scheduler='cosine_decay',
)
model.fit(X_train, y_train)
# Exponential decay with custom decay rate
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=200,
    lr=0.01,
    lr_scheduler='exponential_decay',
    lr_scheduler_kwargs={'decay_rate': 0.95, 'transition_steps': 1},
)
model.fit(X_train, y_train)
# Warmup then cosine decay
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=200,
    lr=0.01,
    lr_scheduler='warmup_cosine_decay',
    lr_scheduler_kwargs={
        'warmup_steps': 20,
        'peak_value': 0.01,
        'end_value': 0.0,
    },
)
model.fit(X_train, y_train)
Available schedules:
| Schedule | Description |
|---|---|
| exponential_decay | Exponential decay: \(\text{lr} \cdot \text{decay\_rate}^{\,t / \text{transition\_steps}}\) at iteration \(t\) |
| cosine_decay | Cosine annealing from lr to 0 over training |
| warmup_cosine_decay | Linear warmup then cosine decay |
| | Linear warmup then exponential decay |
| | Linear warmup then constant learning rate |
| | Polynomial decay with configurable power |
| | Linear decay |
| | Step-wise constant schedule with boundaries |
| | Cosine annealing with warm restarts (SGDR) |
Lookahead Optimizer
Lookahead maintains two sets of weights: “fast” weights updated every step, and “slow” weights updated by interpolating toward the fast weights every k steps (k = sync_period). This can improve generalization and reduce the variance of the updates:
from prosemble.models import GLVQ
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=200,
    lr=0.01,
    lookahead={
        'sync_period': 6,  # sync slow weights every 6 steps
        'slow_step_size': 0.5,  # interpolation factor
    },
    use_scan=False,  # required for lookahead
)
model.fit(X_train, y_train)
After training, the slow weights (which typically generalize better) are used for inference.
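A minimal sketch of the lookahead rule on plain gradient descent, following the sync_period / slow_step_size semantics described above. It is hypothetical code, not prosemble's implementation; optax also ships optax.lookahead with the same two parameter names, but whether prosemble uses it is not stated here.

```python
import jax
import jax.numpy as jnp


def lookahead_gd(grad_fn, w0, lr=0.1, sync_period=6, slow_step_size=0.5, n_steps=60):
    """Plain-GD lookahead: fast weights take k steps, slow weights move part-way."""
    slow = fast = w0
    for step in range(1, n_steps + 1):
        fast = fast - lr * grad_fn(fast)  # inner (fast) update
        if step % sync_period == 0:
            slow = slow + slow_step_size * (fast - slow)  # interpolate toward fast
            fast = slow  # restart the fast weights from the slow weights
    return slow  # the slow weights are what would be used for inference


target = jnp.array([1.0, -2.0])
grad_fn = jax.grad(lambda w: 0.5 * jnp.sum((w - target) ** 2))
w_slow = lookahead_gd(grad_fn, jnp.zeros(2))
```

On this quadratic, each sync period shrinks the fast error by 0.9 ** 6, and the slow weights then close half of the remaining gap, so the slow error contracts by 1 - 0.5 * (1 - 0.9 ** 6) per period.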
Gradient Accumulation
Accumulate gradients over multiple steps before applying an update, effectively increasing the batch size without increasing memory usage:
from prosemble.models import GLVQ
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=200,
    lr=0.01,
    batch_size=16,
    gradient_accumulation_steps=4,  # effective batch size = 16 * 4 = 64
)
model.fit(X_train, y_train)
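Why accumulation matches a larger batch can be checked on a toy least-squares loss: averaging the gradients of four equal-sized micro-batches reproduces the full-batch gradient, while only one micro-batch needs to be processed at a time. This is a sketch, not prosemble's internals.

```python
import jax
import jax.numpy as jnp


def loss(w, X, y):
    return jnp.mean((X @ w - y) ** 2)


grad_fn = jax.grad(loss)
X = jax.random.normal(jax.random.PRNGKey(0), (64, 3))
y = X @ jnp.array([1.0, -2.0, 0.5])
w = jnp.zeros(3)

# Accumulate over 4 micro-batches of 16, then average before a single update
accum_steps = 4
accum = jnp.zeros_like(w)
for X_mb, y_mb in zip(jnp.split(X, accum_steps), jnp.split(y, accum_steps)):
    accum = accum + grad_fn(w, X_mb, y_mb)
accum_grad = accum / accum_steps

full_grad = grad_fn(w, X, y)  # gradient on the full batch of 64, for comparison
```

The equality relies on a mean-reduced loss and equal-sized micro-batches. In plain optax, optax.MultiSteps wraps an optimizer to apply this pattern.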
Parameter Freezing
Freeze specific parameters during training. Frozen parameters receive zero gradients and remain at their initial values:
from prosemble.models import GMLVQ
model = GMLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    freeze_params=['omega'],  # freeze the metric matrix
)
model.fit(X_train, y_train)
# Only prototypes were updated; omega stayed at its initial value
Common use cases:
Freeze 'omega' in GMLVQ to train prototypes with a fixed metric
Freeze 'prototypes' to learn only the metric adaptation matrix
Two-phase training: first train prototypes, then freeze them and train omega
Exponential Moving Average
Maintain an exponential moving average of parameters during training. EMA parameters often generalize better than the final training parameters:
from prosemble.models import GLVQ
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=200,
    lr=0.01,
    ema_decay=0.999,
    use_scan=False,
)
model.fit(X_train, y_train)
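The EMA update rule at each step is:
\[ \theta_{\text{EMA}} \leftarrow \alpha \, \theta_{\text{EMA}} + (1 - \alpha) \, \theta \]
where \(\theta\) denotes the current parameters and \(\alpha\) is the ema_decay factor.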
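A minimal sketch of this update applied to a parameter pytree, assuming the average starts from the initial parameters; ema_update is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def ema_update(ema_params, params, decay):
    # theta_EMA <- decay * theta_EMA + (1 - decay) * theta, leaf by leaf
    return jax.tree_util.tree_map(
        lambda e, p: decay * e + (1.0 - decay) * p, ema_params, params
    )


ema = {"prototypes": jnp.zeros(2)}  # EMA starts from the initial parameters
for _ in range(3):
    params = {"prototypes": jnp.ones(2)}  # stand-in for the post-update parameters
    ema = ema_update(ema, params, decay=0.5)
# 0 -> 0.5 -> 0.75 -> 0.875
```

With ema_decay=0.999 the average has an effective horizon of roughly 1 / (1 - 0.999) = 1000 steps. Under this sketch's initialization, after max_iter=200 about 0.999 ** 200, roughly 0.82, of the weight would still sit on the initial parameters, unless the implementation applies a bias correction (not specified here).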
Mixed Precision
Supervised models support built-in mixed precision training via the
mixed_precision parameter. Master weights stay in float32; the
forward/backward pass runs in lower precision for faster computation
and lower memory on GPU.
import jax.numpy as jnp
from prosemble.models import GLVQ
# bfloat16: recommended for training stability
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    mixed_precision='bfloat16',
)
model.fit(X_train, y_train)
# float16: maximum speed, uses loss scaling to prevent underflow
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    mixed_precision='float16',
)
model.fit(X_train, y_train)
# Prototypes remain in float32
assert model.prototypes_.dtype == jnp.float32
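A minimal sketch of this pattern, not prosemble's implementation: float32 master weights are cast to the compute dtype inside the loss, so jax.grad returns float32 gradients for the master copy. The float16 path uses a static loss scale of 1024, an illustrative value; dynamic loss scaling is also common.

```python
import jax
import jax.numpy as jnp


def loss_fn(master_w, X, y, compute_dtype):
    # Cast float32 master weights and inputs to the compute dtype for the forward pass
    pred = X.astype(compute_dtype) @ master_w.astype(compute_dtype)
    err = pred.astype(jnp.float32) - y  # reduce the loss in float32
    return jnp.mean(err ** 2)


def mp_grad(master_w, X, y, compute_dtype, loss_scale=1.0):
    # Static loss scaling: scale up before backprop, unscale the float32 gradient
    g = jax.grad(lambda w: loss_fn(w, X, y, compute_dtype) * loss_scale)(master_w)
    return g / loss_scale


X = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.array([0.5, -0.25])  # float32 master weights

g_bf16 = mp_grad(w, X, y, jnp.bfloat16)
g_f16 = mp_grad(w, X, y, jnp.float16, loss_scale=1024.0)
g_f32 = mp_grad(w, X, y, jnp.float32)  # full-precision reference
```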
Checkpointing and Resume
Save and load models for persistence. All fitted parameters, optimizer state, and hyperparameters are preserved:
from prosemble.models import GLVQ
model = GLVQ(n_prototypes_per_class=2, max_iter=100, lr=0.01)
model.fit(X_train, y_train)
# Save
model.save('my_model.npz')
# Load
loaded = GLVQ.load('my_model.npz')
# Resume training from checkpoint
loaded.fit(X_train, y_train, resume=True)
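The .npz extension suggests NumPy's archive format. As an illustration only (the key names and layout are assumptions, and nested optimizer state would first need flattening, for example with jax.tree_util), a flat dictionary of arrays round-trips like this:

```python
import os
import tempfile

import numpy as np
import jax.numpy as jnp


def save_npz(path, arrays):
    # np.savez stores each named array; JAX arrays are converted to NumPy first
    np.savez(path, **{name: np.asarray(value) for name, value in arrays.items()})


def load_npz(path):
    with np.load(path) as data:
        return {name: jnp.asarray(data[name]) for name in data.files}


state = {
    "prototypes": jnp.array([[0.0, 1.0], [2.0, 3.0]]),
    "prototype_labels": jnp.array([0, 1]),
    "lr": jnp.asarray(0.01),  # hyperparameters can be stored as 0-d arrays
}
path = os.path.join(tempfile.mkdtemp(), "checkpoint.npz")
save_npz(path, state)
restored = load_npz(path)
```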
Model Quantization
Reduce model size for deployment by quantizing prototypes and parameters when saving. The model in memory is unchanged; only the saved file is quantized:
model.fit(X_train, y_train)
# Quantize to int8 on save (smallest file size)
model.save('model_int8.npz', quantize='int8')
# Quantize to float16 on save (half the float32 size, finer than int8)
model.save('model_f16.npz', quantize='float16')
# Load quantized model
loaded = GLVQ.load('model_int8.npz')
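The int8 scheme prosemble uses is not specified here. A common choice, shown as a hypothetical sketch, is symmetric per-tensor quantization: store int8 values plus one float32 scale, giving a worst-case round-trip error of half the scale per entry.

```python
import numpy as np


def quantize_int8(x):
    """Symmetric per-tensor int8 quantization (one illustrative scheme)."""
    max_abs = float(np.max(np.abs(x)))
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, np.float32(scale)


def dequantize_int8(q, scale):
    return q.astype(np.float32) * scale


rng = np.random.default_rng(0)
prototypes = rng.normal(size=(6, 4)).astype(np.float32)
q, scale = quantize_int8(prototypes)
restored = dequantize_int8(q, scale)
print(prototypes.nbytes, q.nbytes)  # 96 24: a quarter of the float32 size
```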
Export for Deployment
Export a JIT-compiled prediction function using jax.export for
deployment without the full model or prosemble dependency:
from prosemble.models import GLVQ
model = GLVQ(n_prototypes_per_class=2, max_iter=100, lr=0.01)
model.fit(X_train, y_train)
# Export compiled predict function for batch size 32
exported = model.export_predict(batch_size=32)
# Serialize to bytes
blob = exported.serialize()
# Later, in a separate process without prosemble:
import jax
loaded = jax.export.deserialize(blob)
preds = loaded.call(X_batch)  # X_batch must match the exported batch size (32)
The exported function contains only the compiled XLA computation and the frozen prototype/metric parameters. Prosemble is not needed at inference time; only JAX is required to deserialize and call it.
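The mechanism can be reproduced with jax.export directly. The sketch exports a hypothetical nearest-prototype predict function; the closed-over prototype arrays become constants in the exported program. By default the artifact targets the platform it was exported on (jax.export.export accepts a platforms argument for other targets) and accepts only the input shape it was exported with.

```python
import jax
import jax.numpy as jnp
from jax import export

prototypes = jnp.array([[0.0, 0.0], [1.0, 1.0]])  # frozen parameters
proto_labels = jnp.array([0, 1])


def predict(X):
    # Closed-over arrays are embedded in the exported program as constants
    d = jnp.sum((X[:, None, :] - prototypes[None, :, :]) ** 2, axis=-1)
    return proto_labels[jnp.argmin(d, axis=1)]


# Fixed input signature: a batch of 4 samples with 2 features
spec = jax.ShapeDtypeStruct((4, 2), jnp.float32)
exported = export.export(jax.jit(predict))(spec)
blob = exported.serialize()  # bytearray, e.g. to write to disk

# Elsewhere: only JAX is needed to rehydrate and run it
restored = export.deserialize(blob)
X_batch = jnp.array([[0.1, 0.0], [0.9, 1.0], [0.2, 0.3], [1.2, 0.8]], dtype=jnp.float32)
preds = restored.call(X_batch)
```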
ONNX Export
Export fitted models to ONNX for deployment without JAX or prosemble. 68 of 71 models are supported, including encoder models (MLP/CNN backbones), one-class classifiers, and fuzzy clustering:
from prosemble.core.onnx_export import export_onnx
model.fit(X_train, y_train)
onnx_model = export_onnx(model, path='model.onnx')
# Run with ONNX Runtime
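import numpy as np
import onnxruntime as ort
session = ort.InferenceSession('model.onnx')
X_test_np = np.asarray(X_test, dtype=np.float32)  # ONNX Runtime takes NumPy arrays
preds = session.run(None, {'X': X_test_np})[0]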
See the full guide: ONNX Export.
Prototype Analysis
Prototype Win Ratios
Analyze how often each prototype wins on correctly classified samples. This helps identify “dead” prototypes that never win and may be candidates for removal or reinitialization:
from prosemble.models import GLVQ
model = GLVQ(n_prototypes_per_class=3, max_iter=100, lr=0.01)
model.fit(X_train, y_train)
ratios = model.prototype_win_ratios(X_train, y_train)
for i, r in enumerate(ratios):
    label = model.prototype_labels_[i]
    print(f"Prototype {i} (class {label}): win ratio {r:.3f}")
A win ratio of 0.0 means the prototype never won on any correctly classified sample and is effectively unused.
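The normalization behind prototype_win_ratios is not spelled out here. This hypothetical sketch assumes each ratio is the fraction of all correctly classified samples whose nearest (squared Euclidean) prototype is that prototype, so the ratios sum to 1 whenever at least one sample is classified correctly.

```python
import jax.numpy as jnp


def win_ratios(X, y, prototypes, proto_labels):
    """Share of correctly classified samples won by each prototype (hypothetical)."""
    d = jnp.sum((X[:, None, :] - prototypes[None, :, :]) ** 2, axis=-1)
    winner = jnp.argmin(d, axis=1)
    correct = proto_labels[winner] == y
    wins = jnp.bincount(winner, weights=correct.astype(jnp.float32),
                        length=prototypes.shape[0])
    return wins / jnp.maximum(correct.sum(), 1)


prototypes = jnp.array([[0.0, 0.0], [5.0, 5.0], [0.1, 0.0]])
proto_labels = jnp.array([0, 1, 0])
X = jnp.array([[0.0, 0.1], [5.0, 5.0], [4.9, 5.0]])
y = jnp.array([0, 1, 1])
ratios = win_ratios(X, y, prototypes, proto_labels)  # prototype 2 never wins
```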
GPU Acceleration
JAX automatically uses available GPUs. Check your device:
import jax
print(jax.devices()) # [GpuDevice(id=0)]
print(jax.default_backend())  # 'cpu', 'gpu', or 'tpu'
All prosemble operations (training, inference, distance computation) run on whatever device JAX is configured to use. No code changes needed.
Multi-Device Data Parallelism
For large datasets or image models with trainable backbones, prosemble supports data-parallel training across multiple GPUs or TPUs using JAX’s modern sharding API.
import jax
from prosemble.models import GLVQ
# Use all available devices
devices = jax.devices() # e.g., [GpuDevice(id=0), GpuDevice(id=1), ...]
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    devices=devices,
    use_scan=False,
)
model.fit(X_train, y_train)
When devices is specified:
Training data is sharded along the batch dimension across devices
Model parameters and optimizer state are replicated on all devices
Gradients are automatically aggregated (all-reduced) across devices
After training, parameters are moved back to a single device for inference
Requirements:
batch_size (if set) must be divisible by the number of devices
The number of training samples should be divisible by the number of devices
use_scan=False is recommended for multi-device training
Primary use case: Image variants (ImageGLVQ, ImageGMLVQ,
ImageGTLVQ, ImageCBC) with trainable CNN backbones on large datasets
across multiple GPUs/TPUs.
from prosemble.models import ImageGLVQ
model = ImageGLVQ(
    n_prototypes_per_class=1,
    max_iter=50,
    lr=0.001,
    devices=jax.devices(),
    gradient_checkpointing=True,  # save memory for deep backbones
    use_scan=False,
)
model.fit(X_images, y_labels)
partial_fit() also works with multi-device models:
model.fit(X_batch1, y_batch1)
model.partial_fit(X_batch2, y_batch2) # data is automatically sharded
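A minimal sketch of the pattern described above, written against JAX's jax.sharding API rather than prosemble's internals: the batch is split along a 'data' mesh axis, the parameters are replicated, and jit compiles the gradient so the reduction over the full batch spans all devices. It also runs on a single CPU device, where the mesh has one entry.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
mesh = Mesh(np.array(devices), axis_names=("data",))
data_sharding = NamedSharding(mesh, P("data"))  # split axis 0 (batch) across devices
replicated = NamedSharding(mesh, P())  # full copy on every device

n = 4 * len(devices)  # batch size divisible by the device count
X_host = jnp.arange(n * 2, dtype=jnp.float32).reshape(n, 2) / n
y_host = jnp.ones(n)
X = jax.device_put(X_host, data_sharding)
y = jax.device_put(y_host, data_sharding)
w = jax.device_put(jnp.zeros(2), replicated)


@jax.jit
def grad_fn(w, X, y):
    # The mean over the sharded batch makes XLA insert the cross-device reduction
    return jax.grad(lambda w: jnp.mean((X @ w - y) ** 2))(w)


g = grad_fn(w, X, y)
g_reference = jax.grad(lambda w: jnp.mean((X_host @ w - y_host) ** 2))(jnp.zeros(2))
```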
Gradient Checkpointing
Gradient checkpointing (jax.remat) trades compute for memory by
recomputing forward activations during the backward pass instead of
storing them. This is beneficial when training models with deep backbones
(Image, Siamese variants) that would otherwise exhaust GPU memory.
from prosemble.models import GLVQ
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    gradient_checkpointing=True,
)
model.fit(X_train, y_train)
When gradient_checkpointing=True:
Forward activations are not stored during the forward pass
During backpropagation, the forward pass is recomputed to obtain activations needed for gradient computation
Memory usage is reduced at the cost of roughly 33% more compute (about one extra forward pass)
Results match standard training up to floating-point rounding
This option has little effect on models with shallow loss functions (e.g., standard GLVQ with Euclidean distance). It provides significant memory savings when unfreezing deep CNN backbones in Image or Siamese variants.
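In JAX, this is jax.checkpoint (also available as jax.remat). The sketch wraps each layer of a hypothetical 8-layer tanh network and checks that the gradients agree with the un-checkpointed version; the memory saving itself is not observable at this toy size.

```python
import jax
import jax.numpy as jnp


def layer(x, w):
    return jnp.tanh(x @ w)


def deep_loss(ws, x, use_remat):
    # With remat, values inside each layer are recomputed in the backward pass
    f = jax.checkpoint(layer) if use_remat else layer
    for w in ws:
        x = f(x, w)
    return jnp.sum(x ** 2)


keys = jax.random.split(jax.random.PRNGKey(0), 8)
ws = [0.3 * jax.random.normal(k, (16, 16)) for k in keys]
x = jnp.ones((4, 16))

g_plain = jax.grad(deep_loss)(ws, x, False)
g_remat = jax.grad(deep_loss)(ws, x, True)
```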
Custom Gradient Rules
Prosemble provides infrastructure for subclasses to define custom
backward passes using jax.custom_vjp. This is useful for:
Optimal transport distances with non-differentiable sorting operations
Hard thresholding operations that need surrogate gradients
Distances with numerically unstable standard autodiff
Subclasses opt in by setting self._use_custom_vjp = True in their
__init__ and overriding the _custom_vjp_loss method:
from prosemble.models import GLVQ

class MyModel(GLVQ):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._use_custom_vjp = True

    def _custom_vjp_loss(self, params, X, y, proto_labels):
        # Implement custom forward + backward pass
        ...
When _use_custom_vjp is active, gradient checkpointing is
automatically disabled (custom VJP already controls what gets
saved/recomputed).
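For reference, this is what a jax.custom_vjp rule looks like on its own, using one of the cases listed above: a hard threshold whose true derivative is zero almost everywhere, paired with a sigmoid surrogate gradient (a common straight-through-style choice, assumed here for illustration). It is standalone JAX, independent of _custom_vjp_loss.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def hard_threshold(x):
    return (x > 0).astype(x.dtype)


def hard_threshold_fwd(x):
    # Forward pass: return the output plus residuals needed by the backward pass
    return (x > 0).astype(x.dtype), x


def hard_threshold_bwd(x, g):
    # Backward pass: substitute the derivative of a sigmoid for the zero derivative
    s = jax.nn.sigmoid(x)
    return (g * s * (1.0 - s),)


hard_threshold.defvjp(hard_threshold_fwd, hard_threshold_bwd)

x = jnp.array([0.0, 2.0, -1.0])
print(hard_threshold(x))  # [0. 1. 0.]
grads = jax.grad(lambda v: jnp.sum(hard_threshold(v)))(x)  # nonzero surrogate
```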
Training Callbacks
Monitor training with callbacks. Training history is automatically tracked for all models:
model.fit(X_train, y_train)
# Supervised models
print(model.loss_history_)
# Clustering models
print(model.objective_history_)
Live Visualization
Enable real-time training visualization for clustering models:
from prosemble.models import FCM
model = FCM(
    n_clusters=3,
    max_iter=100,
    plot_steps=True,
)
model.fit(X)
Custom LVQ Optimizers
Prosemble provides three custom optax-compatible optimizers designed
specifically for the geometry and parameter structure of LVQ models.
These can be composed with standard optax transformations via
optax.chain().
Per-Group Gradient Clipping
Different parameter types (prototypes, omega matrices, relevances) have different natural scales. A single global clip either under-constrains large parameters or over-constrains small ones. Per-group clipping clips each parameter group independently:
import optax
from prosemble.core.optimizers import per_group_clip
optimizer = optax.chain(
    per_group_clip({'prototypes': 1.0, 'omega': 0.5, 'sigmas': 0.1}),
    optax.adam(0.01),
)
from prosemble.models import GMLVQ
model = GMLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    optimizer=optimizer,
)
model.fit(X_train, y_train)
Hypergradient Descent
Adaptive per-parameter learning rates via gradient correlation (Baydin et al. 2017). If consecutive gradients point in the same direction, the learning rate increases; if they oscillate, it decreases:
from prosemble.core.optimizers import hypergradient_descent
from prosemble.models import GLVQ
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=200,
    lr=0.01,
    optimizer=hypergradient_descent(
        init_lr=0.01,
        hyper_lr=1e-4,
        min_lr=1e-6,
        max_lr=0.1,
    ),
)
model.fit(X_train, y_train)
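The update rule from Baydin et al. (2017) applied to plain SGD fits in a few lines. This is not prosemble's hypergradient_descent; it reuses the parameter names from the example above (init_lr, hyper_lr, min_lr, max_lr) with their evident meanings and keeps one learning rate per parameter entry.

```python
import jax
import jax.numpy as jnp


def hypergradient_sgd(grad_fn, w, init_lr=0.01, hyper_lr=1e-4,
                      min_lr=1e-6, max_lr=0.1, n_steps=50):
    """SGD-HD sketch: per-parameter learning rates adapted by g_t * g_{t-1}."""
    lr = jnp.full_like(w, init_lr)
    prev_g = jnp.zeros_like(w)
    for _ in range(n_steps):
        g = grad_fn(w)
        # Aligned consecutive gradients (positive product) raise lr; oscillation lowers it
        lr = jnp.clip(lr + hyper_lr * g * prev_g, min_lr, max_lr)
        w = w - lr * g
        prev_g = g
    return w, lr


target = jnp.array([10.0, -10.0])
grad_fn = jax.grad(lambda w: 0.5 * jnp.sum((w - target) ** 2))
w_final, lr_final = hypergradient_sgd(grad_fn, jnp.zeros(2))
```

Here the gradients keep the same sign throughout, so each learning rate grows from 0.01 and is capped at max_lr.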
Riemannian Nesterov Momentum
Nesterov accelerated gradient, which on smooth convex objectives attains
an \(O(1/t^2)\) convergence rate versus \(O(1/t)\) for vanilla gradient
descent. Designed for use with Riemannian models where manifold projection
is handled by _post_update():
from prosemble.core.optimizers import riemannian_nesterov
from prosemble.models import GLVQ
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=200,
    lr=0.01,
    optimizer=riemannian_nesterov(
        learning_rate=0.01,
        momentum=0.9,
    ),
)
model.fit(X_train, y_train)
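A minimal sketch of Nesterov momentum combined with a retraction, on the unit sphere (a simple manifold chosen for illustration, not one of prosemble's manifolds). The objective -a·x is minimized on the sphere at x = a / |a|. Transporting the momentum vector between tangent spaces is skipped for brevity; the final normalization plays the role the text assigns to _post_update().

```python
import jax
import jax.numpy as jnp


def nesterov_on_sphere(f, x, lr=0.1, momentum=0.9, n_steps=300):
    """Nesterov momentum with a normalization retraction onto the unit sphere."""
    grad_f = jax.grad(f)
    v = jnp.zeros_like(x)
    for _ in range(n_steps):
        y = x + momentum * v
        y = y / jnp.linalg.norm(y)  # look-ahead point, retracted to the sphere
        g = grad_f(y)
        g = g - jnp.dot(g, y) * y  # Riemannian gradient: project onto the tangent space
        v = momentum * v - lr * g
        x = x + v
        x = x / jnp.linalg.norm(x)  # retraction back onto the sphere
    return x


a = jnp.array([0.6, 0.8])
x_star = nesterov_on_sphere(lambda x: -jnp.dot(a, x), jnp.array([1.0, 0.0]))
```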
Reject Option
The RejectOptionMixin adds calibrated rejection to any supervised
prototype classifier. When the model is uncertain about a prediction
(the sample lies near the decision boundary), it abstains instead of
guessing.
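The confidence measure is the GLVQ relative margin:
\[ \text{conf}(x) = \frac{d_2(x) - d_1(x)}{d_2(x) + d_1(x)} \]
where \(d_1(x)\) is the distance from \(x\) to its closest prototype and \(d_2(x)\) is the distance to the closest prototype of a different class. Values near 0 indicate the decision boundary; values near 1 indicate high confidence.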
from prosemble.models import GLVQ
model = GLVQ(n_prototypes_per_class=2, max_iter=100, lr=0.01)
model.fit(X_train, y_train)
# Predict with rejection (uncertain samples get label -1)
labels = model.predict_with_rejection(X_test, threshold=0.1)
rejected = labels == -1
print(f"Rejected {rejected.sum()} of {len(labels)} samples")
# Compute confidence scores
conf = model.confidence(X_test)
# Find the optimal threshold minimizing Chow's risk
threshold = model.optimal_threshold(
    X_val, y_val,
    cost_reject=0.5,  # rejection costs half of a misclassification
    cost_error=1.0,
)
# Accuracy-coverage curve
thresholds, accuracies, coverages = model.accuracy_coverage_curve(X_val, y_val)
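A minimal sketch of both pieces. It assumes the confidence is (d2 - d1) / (d2 + d1), with d1 the squared Euclidean distance to the nearest prototype and d2 the distance to the nearest prototype of another class, and that Chow's risk is estimated empirically as cost_reject times the rejection rate plus cost_error times the rate of accepted errors, evaluated at each observed confidence value. The helper names are hypothetical and prosemble's exact search may differ.

```python
import jax.numpy as jnp


def relative_margin_confidence(X, prototypes, proto_labels):
    """Predicted labels and (d2 - d1) / (d2 + d1) confidence per sample."""
    d = jnp.sum((X[:, None, :] - prototypes[None, :, :]) ** 2, axis=-1)
    pred = proto_labels[jnp.argmin(d, axis=1)]
    d1 = jnp.min(d, axis=1)  # closest prototype overall
    other_class = proto_labels[None, :] != pred[:, None]
    d2 = jnp.min(jnp.where(other_class, d, jnp.inf), axis=1)  # closest rival class
    return pred, (d2 - d1) / (d2 + d1)


def chow_threshold(conf, correct, cost_reject=0.5, cost_error=1.0):
    """Threshold minimizing empirical Chow risk over candidate thresholds."""
    candidates = jnp.concatenate([jnp.array([0.0]), jnp.sort(conf)])
    rejected = conf[None, :] < candidates[:, None]  # (n_thresholds, n_samples)
    accepted_errors = ~rejected & ~correct[None, :]
    risk = cost_reject * rejected.mean(axis=1) + cost_error * accepted_errors.mean(axis=1)
    return candidates[jnp.argmin(risk)]


prototypes = jnp.array([[0.0, 0.0], [2.0, 0.0]])
proto_labels = jnp.array([0, 1])
X = jnp.array([[0.1, 0.0], [1.0, 0.0]])
pred, conf = relative_margin_confidence(X, prototypes, proto_labels)
# conf[0] is close to 1 (near prototype 0); conf[1] is 0 (exactly on the boundary)

t = chow_threshold(jnp.array([0.05, 0.1, 0.6, 0.8]),
                   jnp.array([False, False, True, True]))
labels = jnp.where(conf < t, -1, pred)  # -1 marks rejected samples
```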
Curriculum Learning
Self-paced curriculum learning adaptively weights training samples by difficulty. Early in training, the model focuses on easy samples; as training progresses, harder samples are gradually introduced.
The difficulty measure is the per-sample GLVQ loss (mu-ratio). Samples with loss below the current threshold keep full (or near-full) weight; those above it are down-weighted or excluded.
from prosemble.core.curriculum import (
    curriculum_weights, curriculum_threshold, apply_curriculum_to_loss,
)
# Compute per-sample difficulty weights
threshold = curriculum_threshold(
    iteration=50,
    max_iter=200,
    init_threshold=0.3,  # initially only easy samples
    final_threshold=1.0,  # eventually all samples
    schedule='linear',
)
# per_sample_losses: per-sample GLVQ losses (mu values) for the current batch
weights = curriculum_weights(per_sample_losses, threshold, mode='soft')
# Or use the combined pipeline
weighted_loss = apply_curriculum_to_loss(
    per_sample_losses,
    iteration=50,
    max_iter=200,
    init_threshold=0.3,
    final_threshold=1.0,
)
Three weighting modes are available:
'hard': binary, include (weight=1) or exclude (weight=0)
'soft': smooth sigmoid transition at the threshold
'linear': linearly decrease weight from 1 to 0 as loss increases
Three scheduling strategies:
'linear': threshold increases linearly with iteration
'exponential': slow start, fast growth
'cosine': smooth cosine schedule from init to final threshold
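To make the schedules and weighting modes concrete, here is a hypothetical re-implementation of two schedules ('linear', 'cosine') and two modes ('hard', 'soft'). The exact formulas prosemble uses, including any temperature for the soft sigmoid, are not given in this guide, so treat these as assumptions.

```python
import math

import jax
import jax.numpy as jnp


def threshold_schedule(iteration, max_iter, init_threshold, final_threshold, schedule="linear"):
    """Hypothetical threshold schedules (only 'linear' and 'cosine' are sketched)."""
    p = iteration / max_iter  # training progress in [0, 1]
    if schedule == "linear":
        frac = p
    elif schedule == "cosine":
        frac = 0.5 * (1.0 - math.cos(math.pi * p))  # smooth ease-in / ease-out
    else:
        raise ValueError(f"schedule not sketched: {schedule}")
    return init_threshold + (final_threshold - init_threshold) * frac


def sample_weights(per_sample_losses, threshold, mode="soft", temperature=0.05):
    if mode == "hard":
        return (per_sample_losses <= threshold).astype(jnp.float32)
    if mode == "soft":
        # Sigmoid centred on the threshold; temperature is an assumed sharpness knob
        return jax.nn.sigmoid((threshold - per_sample_losses) / temperature)
    raise ValueError(f"mode not sketched: {mode}")


losses = jnp.array([-0.5, 0.1, 0.4, 0.9])  # e.g. GLVQ mu values in [-1, 1]
t = threshold_schedule(50, 200, init_threshold=0.3, final_threshold=1.0)  # 0.475
w_hard = sample_weights(losses, t, mode="hard")
w_soft = sample_weights(losses, t, mode="soft")
weighted_loss = jnp.sum(w_soft * losses) / jnp.sum(w_soft)
```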
Prototype Regularization
Prototype Diversity (DPP)
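Prevent same-class prototypes from collapsing onto each other using a penalty inspired by determinantal point processes (DPPs). The log-determinant of a kernel matrix built from pairwise prototype distances shrinks as prototypes move together, so favouring a large log-determinant encourages spread: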
from prosemble.core.regularization import prototype_diversity_loss
# Add as a regularization term to the loss
div_loss = prototype_diversity_loss(
    prototypes, proto_labels,
    sigma_div=1.0,
)
total_loss = classification_loss + 0.01 * div_loss
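One way to realize such a penalty, as a hypothetical sketch (the Gaussian kernel, the role of sigma_div as its bandwidth, and the jitter term are assumptions, not prosemble's exact definition): build a kernel over each class's prototypes and take the negative log-determinant, which becomes large as two prototypes coincide.

```python
import jax.numpy as jnp


def diversity_penalty(prototypes, proto_labels, sigma_div=1.0, jitter=1e-3):
    """Negative log-det of an RBF kernel over same-class prototypes (hypothetical)."""
    penalty = 0.0
    for c in jnp.unique(proto_labels):
        P = prototypes[proto_labels == c]
        sq_dist = jnp.sum((P[:, None, :] - P[None, :, :]) ** 2, axis=-1)
        K = jnp.exp(-sq_dist / (2.0 * sigma_div ** 2))
        K = K + jitter * jnp.eye(P.shape[0])  # keeps K positive definite if prototypes coincide
        _, logdet = jnp.linalg.slogdet(K)
        penalty = penalty - logdet  # collapsed prototypes -> det near 0 -> large penalty
    return penalty


labels = jnp.array([0, 0])
collapsed = diversity_penalty(jnp.array([[0.0, 0.0], [0.0, 0.0]]), labels)
spread = diversity_penalty(jnp.array([[0.0, 0.0], [3.0, 0.0]]), labels)
```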
Sparse Relevance Regularization
For relevance learning models (GRLVQ), encourage sparsity in the learned relevance profile via proximal gradient operators:
from prosemble.core.regularization import (
    sparse_relevance_proximal,
    elastic_net_proximal,
)
# L1 soft-thresholding (promotes exact zeros)
sparse_rel = sparse_relevance_proximal(
    relevances, l1_weight=0.01, lr=0.01,
)
# Elastic net (L1 + L2 combination)
sparse_rel = elastic_net_proximal(
    relevances, l1_weight=0.01, l2_weight=0.001, lr=0.01,
)
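The standard proximal operators behind these two functions have closed forms. This sketch assumes the L2 term is weighted as 0.5 * l2_weight * |r|^2 (conventions differ by a factor of 2) and omits any renormalization of the relevance profile, which relevance-learning models often apply afterwards.

```python
import jax.numpy as jnp


def soft_threshold(r, l1_weight, lr):
    # Proximal operator of lr * l1_weight * ||r||_1: shrink toward 0, clip at exactly 0
    return jnp.sign(r) * jnp.maximum(jnp.abs(r) - lr * l1_weight, 0.0)


def elastic_net_prox(r, l1_weight, l2_weight, lr):
    # Proximal operator of lr * (l1_weight * ||r||_1 + 0.5 * l2_weight * ||r||_2^2)
    return soft_threshold(r, l1_weight, lr) / (1.0 + lr * l2_weight)


relevances = jnp.array([0.5, 0.00005, -0.3])
print(soft_threshold(relevances, l1_weight=0.01, lr=0.01))  # middle entry becomes exactly 0
```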
Geodesic Interpolation
For Riemannian models, geodesic utilities enable visualization and analysis of decision boundaries on curved manifolds.
from prosemble.core.geodesic import (
    geodesic_interpolation,
    geodesic_midpoint,
    decision_boundary_point,
    prototype_geodesic_distances,
    inter_class_geodesics,
)
from prosemble.core.manifolds import SO
manifold = SO(3)
# Compute geodesic path between two SO(3) prototypes
path = geodesic_interpolation(manifold, proto_a, proto_b, n_points=50)
# Find the decision boundary along the geodesic
boundary_pt, t_boundary = decision_boundary_point(
    manifold, proto_a, proto_b,
)
print(f"Boundary at t={t_boundary:.3f}")  # t=0.5 for symmetric manifolds
# Pairwise geodesic distance matrix between all prototypes
dist_matrix = prototype_geodesic_distances(
    manifold, prototypes, proto_labels,
)
# All inter-class geodesics with boundary locations
geodesics = inter_class_geodesics(
    manifold, prototypes, proto_labels,
)
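Independent of prosemble's SO class, the geometry can be sketched with Rodrigues' formula: the geodesic from A to B is A exp(t log(A^T B)), and the geodesic distance is the rotation angle of A^T B. The helpers below are hypothetical. The check that the t = 0.5 point is equidistant from both endpoints shows why, with plain geodesic distances, a two-prototype boundary sits at the geodesic midpoint.

```python
import jax.numpy as jnp


def hat(w):
    # so(3) element (skew-symmetric matrix) from an axis-angle vector
    return jnp.array([[0.0, -w[2], w[1]],
                      [w[2], 0.0, -w[0]],
                      [-w[1], w[0], 0.0]])


def so3_exp(w):
    # Rodrigues' formula
    theta = jnp.linalg.norm(w)
    if float(theta) < 1e-12:
        return jnp.eye(3)
    K = hat(w / theta)
    return jnp.eye(3) + jnp.sin(theta) * K + (1.0 - jnp.cos(theta)) * (K @ K)


def so3_log(R):
    # Inverse of so3_exp; assumes the rotation angle is strictly below pi
    theta = jnp.arccos(jnp.clip((jnp.trace(R) - 1.0) / 2.0, -1.0, 1.0))
    if float(theta) < 1e-12:
        return jnp.zeros(3)
    axis = jnp.array([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]])
    return theta / (2.0 * jnp.sin(theta)) * axis


def so3_geodesic(A, B, t):
    # Point at fraction t along the geodesic from A (t=0) to B (t=1)
    return A @ so3_exp(t * so3_log(A.T @ B))


def so3_distance(A, B):
    return float(jnp.linalg.norm(so3_log(A.T @ B)))  # rotation angle in radians


def rot_z(angle):
    c, s = jnp.cos(angle), jnp.sin(angle)
    return jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


A, B = jnp.eye(3), rot_z(jnp.pi / 2)
mid = so3_geodesic(A, B, 0.5)
```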