Advanced Features

Prosemble leverages JAX for GPU acceleration, JIT compilation, and functional training loops. This guide covers all advanced features available for supervised prototype models and fuzzy clustering models.

JIT-Compiled Inference

Model inference (predict, predict_proba) is automatically JIT-compiled. The first call triggers compilation; subsequent calls with the same input shape and dtype reuse the cached compiled function.

from prosemble.models import GLVQ

model = GLVQ(n_prototypes_per_class=2, max_iter=100, lr=0.01)
model.fit(X_train, y_train)

# First call: triggers JIT compilation
preds = model.predict(X_test)

# Subsequent calls: use cached compiled function (fast)
preds = model.predict(X_test)

Training Loop Modes

lax.scan vs Python Loop

By default, supervised models use jax.lax.scan which compiles the entire training loop into a single XLA program. This eliminates Python loop overhead and is the fastest option.

Set use_scan=False to switch to a Python loop, which is required when using callbacks, patience-based early stopping, or validation monitoring.

from prosemble.models import GLVQ

# Default: lax.scan (fastest, no callbacks)
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    use_scan=True,       # default
)
model.fit(X_train, y_train)

# Python loop: needed for callbacks, patience, validation
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    use_scan=False,
)
model.fit(X_train, y_train)

Models that always use Python loops regardless of use_scan:

  • Growing Neural Gas (dynamic topology)

  • Riemannian Neural Gas (manifold operations)

  • Median LVQ (combinatorial M-step)

Fuzzy clustering models use lax.scan by default and fall back to a Python loop when callbacks, patience, or restore_best are enabled.
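
To make the difference between the two loop modes concrete, here is a minimal JAX sketch that runs the same toy gradient descent once as a Python loop and once through jax.lax.scan. The loss, step size, and the name sgd_step are illustrative assumptions and are not part of prosemble.

```python
import jax
import jax.numpy as jnp

def sgd_step(w, _):
    # One gradient step on the toy loss f(w) = sum(w**2) (illustrative only).
    loss, grad = jax.value_and_grad(lambda p: jnp.sum(p ** 2))(w)
    return w - 0.1 * grad, loss

w0 = jnp.array([1.0, -2.0])

# Python loop: one dispatch per iteration, so Python code
# (callbacks, early-stopping checks) can run between steps.
w_loop = w0
for i in range(50):
    w_loop, loss = sgd_step(w_loop, None)

# lax.scan: the 50-step loop is traced once and compiled as a single
# XLA program; no Python code runs between steps.
w_scan, losses = jax.lax.scan(sgd_step, w0, xs=None, length=50)
```

Both variants reach the same parameters; the scan version additionally returns the per-step losses as one stacked array.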

Mini-Batch Training

For large datasets, supervised models support mini-batch training. Each iteration samples a random batch from the training data:

from prosemble.models import GLVQ

model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=200,
    lr=0.01,
    batch_size=64,
)
model.fit(X_large, y_large)
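
The per-iteration sampling described above can be pictured with the following NumPy sketch. It assumes each batch is drawn without replacement from the full training set; how prosemble draws its batches internally is not specified here, so treat this as an illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 4))        # toy training data
y = rng.integers(0, 3, size=1000)     # toy labels

batch_size = 64
for step in range(3):
    # Draw batch_size distinct row indices for this iteration.
    idx = rng.choice(X.shape[0], size=batch_size, replace=False)
    X_batch, y_batch = X[idx], y[idx]
    # ... one gradient update on (X_batch, y_batch) would go here ...
```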

Online / Incremental Learning

Train on streaming data or update an existing model with new batches using partial_fit(). The model must first be fitted via fit(). Each subsequent partial_fit() call then performs a single gradient update and preserves the optimizer state across calls.

from prosemble.models import GLVQ

model = GLVQ(n_prototypes_per_class=2, max_iter=100, lr=0.01)

# Initial training
model.fit(X_train, y_train)

# Incremental updates with new data batches
for X_batch, y_batch in data_stream:
    model.partial_fit(X_batch, y_batch)

# Prototypes and optimizer state are preserved across calls
preds = model.predict(X_test)

All supervised models support partial_fit().

Sample Weighting and Class Balancing

Per-Sample Weights

Pass sample_weight to fit() to assign different importance to individual training samples. Samples with higher weights have more influence on the loss:

import jax.numpy as jnp
from prosemble.models import GLVQ

model = GLVQ(n_prototypes_per_class=2, max_iter=100, lr=0.01)

# Per-sample weights, one per row of X_train (e.g., higher weight for difficult samples)
weights = jnp.ones(X_train.shape[0])
weights = weights.at[2].set(2.0).at[4].set(3.0)   # up-weight samples 2 and 4
model.fit(X_train, y_train, sample_weight=weights)

Class Balancing

For imbalanced datasets, use class_weight to automatically compute per-sample weights inversely proportional to class frequency:

from prosemble.models import GLVQ

# Automatic balanced weighting
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    class_weight='balanced',
)
model.fit(X_train, y_train)

# Or specify weights per class manually
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    class_weight={0: 1.0, 1: 5.0, 2: 1.0},
)
model.fit(X_train, y_train)

With class_weight='balanced', the weight for class c is computed as:

\[w_c = \frac{N}{K \cdot N_c}\]

where N is the total number of samples, K is the number of classes, and N_c is the number of samples in class c.
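
As a worked instance of the formula above, the following sketch computes balanced weights for a two-class problem with 90 and 10 samples. The helper name balanced_class_weights is hypothetical and only mirrors the formula; it is not a prosemble function.

```python
from collections import Counter

def balanced_class_weights(labels):
    # w_c = N / (K * N_c) for every class c present in labels.
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {c: n_samples / (n_classes * n_c) for c, n_c in counts.items()}

labels = [0] * 90 + [1] * 10
weights = balanced_class_weights(labels)
# Class 0: 100 / (2 * 90) = 0.556 (approx.), class 1: 100 / (2 * 10) = 5.0
```

The minority class receives a weight nine times larger than the majority class, so both classes contribute equally to the weighted loss in total.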

Validation and Early Stopping

Validation Monitoring

Pass validation data to fit() to monitor performance on held-out samples during training. When restore_best=True is also set, the model automatically restores the parameters that achieved the lowest validation loss:

from prosemble.models import GLVQ

model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=200,
    lr=0.01,
    restore_best=True,
    use_scan=False,      # required for validation monitoring
)
model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
)

# model.prototypes_ now holds the best params from training
print(f"Best validation loss: {model.best_loss_}")

Patience-Based Early Stopping

Stop training early when no improvement is observed for a given number of consecutive iterations:

from prosemble.models import GLVQ

model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=500,
    lr=0.01,
    patience=20,
    restore_best=True,
    use_scan=False,
)
model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
)

print(f"Stopped at iteration {model.n_iter_}")

Fuzzy clustering models also support patience-based early stopping:

from prosemble.models import FCM

model = FCM(
    n_clusters=3,
    max_iter=500,
    patience=10,
    restore_best=True,
)
model.fit(X)

print(f"Stopped at iteration {model.n_iter_}")
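
The patience rule itself is simple enough to sketch in plain Python. The version below assumes that any strict decrease of the monitored loss counts as an improvement; whether prosemble applies a tolerance is not stated here, and the function name is hypothetical.

```python
def run_with_patience(losses, patience):
    # Track the best loss and count consecutive non-improving iterations.
    best_loss, best_iter, stale = float("inf"), -1, 0
    n_iter = 0
    for it, loss in enumerate(losses):
        n_iter = it + 1
        if loss < best_loss:
            best_loss, best_iter, stale = loss, it, 0
        else:
            stale += 1
            if stale >= patience:
                break          # no improvement for `patience` iterations
    return best_loss, best_iter, n_iter

losses = [1.0, 0.8, 0.7, 0.72, 0.71, 0.73, 0.69]
best_loss, best_iter, n_iter = run_with_patience(losses, patience=3)
# Stops after 6 iterations; the best loss 0.7 was seen at iteration index 2.
```

With restore_best=True, the parameters from the best iteration (index 2 in this toy run) would be the ones kept, rather than those from the last iteration.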

Optimizer Configuration

Prosemble supports 26 optimizers via optax. Pass a string name or a pre-built optax.GradientTransformation to the optimizer parameter:

from prosemble.models import GLVQ

# String name (26 options)
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    optimizer='adamw',
)
model.fit(X_train, y_train)

Available optimizer strings: adam, adamw, adamax, adamaxw, adan, adabelief, amsgrad, radam, lamb, lion, novograd, sgd, sign_sgd, signum, noisy_sgd, lars, rmsprop, adagrad, adadelta, adafactor, sm3, yogi, rprop, fromage, lbfgs, dpsgd.

You can also pass a custom optax optimizer directly:

import optax
from prosemble.models import GLVQ

custom_opt = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate=0.001),
)

model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    optimizer=custom_opt,
)
model.fit(X_train, y_train)

Learning Rate Scheduling

Decay or warm up the learning rate during training using the lr_scheduler parameter. Nine built-in schedules are available:

from prosemble.models import GLVQ

# Cosine decay: lr decays smoothly from 0.01 to 0 over training
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=200,
    lr=0.01,
    lr_scheduler='cosine_decay',
)
model.fit(X_train, y_train)

# Exponential decay with custom decay rate
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=200,
    lr=0.01,
    lr_scheduler='exponential_decay',
    lr_scheduler_kwargs={'decay_rate': 0.95, 'transition_steps': 1},
)
model.fit(X_train, y_train)

# Warmup then cosine decay
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=200,
    lr=0.01,
    lr_scheduler='warmup_cosine_decay',
    lr_scheduler_kwargs={
        'warmup_steps': 20,
        'peak_value': 0.01,
        'end_value': 0.0,
    },
)
model.fit(X_train, y_train)

Available schedules:

Schedule                    Description
exponential_decay           Exponential decay: lr * decay_rate^(step / transition_steps)
cosine_decay                Cosine annealing from lr to 0
warmup_cosine_decay         Linear warmup then cosine decay
warmup_exponential_decay    Linear warmup then exponential decay
warmup_constant             Linear warmup then constant learning rate
polynomial                  Polynomial decay with configurable power
linear                      Linear decay from lr to end_value
piecewise_constant          Step-wise constant schedule with boundaries
sgdr                        Cosine annealing with warm restarts (SGDR)
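
For intuition, two of the schedules listed above can be written out directly. The exponential form follows the formula given for exponential_decay; the cosine form is the standard half-cosine annealing and is assumed here to match cosine_decay. The function names and the total_steps argument are illustrative.

```python
import math

def exponential_decay(lr, step, decay_rate, transition_steps=1):
    # lr * decay_rate^(step / transition_steps)
    return lr * decay_rate ** (step / transition_steps)

def cosine_decay(lr, step, total_steps):
    # Half-cosine annealing from lr (step 0) down to 0 (step total_steps).
    progress = min(step, total_steps) / total_steps
    return lr * 0.5 * (1.0 + math.cos(math.pi * progress))

lr0 = 0.01
exp_10 = exponential_decay(lr0, step=10, decay_rate=0.95)   # about 0.00599
cos_mid = cosine_decay(lr0, step=100, total_steps=200)      # about 0.005
cos_end = cosine_decay(lr0, step=200, total_steps=200)      # 0.0
```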

Lookahead Optimizer

Lookahead maintains two sets of weights: “fast” weights updated every step, and “slow” weights updated by interpolating toward the fast weights every k steps (sync_period). This can improve generalization and reduce variance:

from prosemble.models import GLVQ

model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=200,
    lr=0.01,
    lookahead={
        'sync_period': 6,         # sync slow weights every 6 steps
        'slow_step_size': 0.5,    # interpolation factor
    },
    use_scan=False,               # required for lookahead
)
model.fit(X_train, y_train)

After training, the slow weights (which tend to generalize better) are used for inference.
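
The fast/slow update can be sketched on a scalar parameter as follows. This follows the original Lookahead scheme, in which the fast weights are reset to the slow weights after each synchronization; the function name and the toy gradient are assumptions, not prosemble internals.

```python
def lookahead_sgd(grad_fn, w0, lr, sync_period, slow_step_size, n_steps):
    slow = fast = w0
    for step in range(1, n_steps + 1):
        fast = fast - lr * grad_fn(fast)            # fast weights: every step
        if step % sync_period == 0:
            # Slow weights move part of the way toward the fast weights,
            # then the fast weights restart from the slow weights.
            slow = slow + slow_step_size * (fast - slow)
            fast = slow
    return slow

# Toy loss f(w) = w**2 with gradient 2w.
w = lookahead_sgd(lambda w: 2.0 * w, w0=1.0, lr=0.1,
                  sync_period=6, slow_step_size=0.5, n_steps=12)
```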

Gradient Accumulation

Accumulate gradients over multiple steps before applying an update, effectively increasing the batch size without increasing memory usage:

from prosemble.models import GLVQ

model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=200,
    lr=0.01,
    batch_size=16,
    gradient_accumulation_steps=4,  # effective batch size = 16 * 4 = 64
)
model.fit(X_train, y_train)
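
The equivalence behind the "effective batch size" can be checked numerically. The sketch below assumes accumulated gradients are averaged over the micro-batches and uses a mean-squared-error loss on toy data; for such mean-reduced losses, four micro-batches of 16 give the same gradient as one batch of 64.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(64, 3))
y = rng.normal(size=64)
w = np.zeros(3)

def mse_grad(w, X, y):
    # Gradient of mean((X @ w - y)**2) with respect to w.
    return 2.0 * X.T @ (X @ w - y) / len(y)

# Accumulate over 4 micro-batches of 16 samples, then average.
accum = np.zeros_like(w)
for X_mb, y_mb in zip(np.split(X, 4), np.split(y, 4)):
    accum += mse_grad(w, X_mb, y_mb)
accum /= 4

full = mse_grad(w, X, y)   # gradient of the full batch of 64
```

Only one micro-batch of 16 samples has to be held in memory at a time, which is where the memory saving comes from.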

Parameter Freezing

Freeze specific parameters during training. Frozen parameters receive zero gradients and remain at their initial values:

from prosemble.models import GMLVQ

model = GMLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    freeze_params=['omega'],       # freeze the metric matrix
)
model.fit(X_train, y_train)

# Only prototypes were updated; omega stayed at its initial value

Common use cases:

  • Freeze 'omega' in GMLVQ to train prototypes with a fixed metric

  • Freeze 'prototypes' to learn only the metric adaptation matrix

  • Two-phase training: first train prototypes, then freeze them and train omega
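
Conceptually, freezing amounts to replacing the gradient of the frozen parameters with zero before the update. A minimal sketch with scalar stand-ins for the parameters (the dictionary layout and the sgd_update helper are illustrative assumptions):

```python
def sgd_update(params, grads, lr, frozen=()):
    # Frozen entries get a zero gradient, so they keep their current value.
    return {
        name: value - lr * (0.0 if name in frozen else grads[name])
        for name, value in params.items()
    }

params = {"prototypes": 1.0, "omega": 2.0}
grads = {"prototypes": 0.5, "omega": -0.3}
updated = sgd_update(params, grads, lr=0.1, frozen=("omega",))
# "prototypes" moves to 0.95; "omega" stays at 2.0
```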

Exponential Moving Average

Maintain an exponential moving average of parameters during training. EMA parameters often generalize better than the final training parameters:

from prosemble.models import GLVQ

model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=200,
    lr=0.01,
    ema_decay=0.999,
    use_scan=False,
)
model.fit(X_train, y_train)

The EMA update rule at each step is:

\[\theta_{\text{ema}} = \alpha \cdot \theta_{\text{ema}} + (1 - \alpha) \cdot \theta\]

where \(\alpha\) is the ema_decay factor.
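
A short worked example of the update rule, using a scalar parameter and an EMA initialized at 0 (the initialization is an assumption made for this illustration):

```python
def ema_update(theta_ema, theta, alpha):
    # theta_ema <- alpha * theta_ema + (1 - alpha) * theta
    return alpha * theta_ema + (1.0 - alpha) * theta

theta_ema = 0.0
for theta in [1.0, 1.0, 1.0]:
    theta_ema = ema_update(theta_ema, theta, alpha=0.9)
# After three steps toward a constant 1.0: 1 - 0.9**3 = 0.271
```

The closer the decay factor is to 1, the more slowly the average follows the current parameters; a common rule of thumb is an averaging window of about 1 / (1 - alpha) steps.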

Mixed Precision

Supervised models support built-in mixed precision training via the mixed_precision parameter. Master weights stay in float32; the forward/backward pass runs in lower precision for faster computation and lower memory on GPU.

import jax.numpy as jnp
from prosemble.models import GLVQ

# bfloat16: recommended for training stability
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    mixed_precision='bfloat16',
)
model.fit(X_train, y_train)

# float16: maximum speed, uses loss scaling to prevent underflow
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    mixed_precision='float16',
)
model.fit(X_train, y_train)

# Prototypes remain in float32
assert model.prototypes_.dtype == jnp.float32
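
Why float16 needs loss scaling can be seen with a single number. The sketch below uses NumPy and an arbitrary scale of 1024; the scale value prosemble uses, and whether it is adjusted dynamically, are not specified here.

```python
import numpy as np

grad = 1e-8            # a small gradient value
loss_scale = 1024.0    # illustrative scale factor (assumption)

# Cast directly: the value is below float16's smallest subnormal and becomes 0.
unscaled_f16 = np.float16(grad)

# Scale first: the scaled value is representable in float16.
scaled_f16 = np.float16(grad * loss_scale)

# Unscale in float32, where the small value is representable again.
recovered = np.float32(scaled_f16) / loss_scale
```

bfloat16 keeps the exponent range of float32, which is why it typically does not need this step.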

Checkpointing and Resume

Save and load models for persistence. All fitted parameters, optimizer state, and hyperparameters are preserved:

from prosemble.models import GLVQ

model = GLVQ(n_prototypes_per_class=2, max_iter=100, lr=0.01)
model.fit(X_train, y_train)

# Save
model.save('my_model.npz')

# Load
loaded = GLVQ.load('my_model.npz')

# Resume training from checkpoint
loaded.fit(X_train, y_train, resume=True)

Model Quantization

Reduce model size for deployment by quantizing prototypes and parameters when saving. The model in memory is unchanged; only the saved file is quantized:

from prosemble.models import GLVQ

model = GLVQ(n_prototypes_per_class=2, max_iter=100, lr=0.01)
model.fit(X_train, y_train)

# Quantize to int8 on save (smallest file size)
model.save('model_int8.npz', quantize='int8')

# Quantize to float16 on save (balanced)
model.save('model_f16.npz', quantize='float16')

# Load quantized model
loaded = GLVQ.load('model_int8.npz')
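
To give a sense of what int8 quantization costs in accuracy, here is a generic symmetric per-array scheme in NumPy. This is a common textbook approach and only an assumption about the general idea; the scheme prosemble applies when saving is not described in this guide.

```python
import numpy as np

def quantize_int8(x):
    # Symmetric quantization: map [-max|x|, max|x|] onto [-127, 127].
    scale = float(np.max(np.abs(x))) / 127.0
    scale = scale if scale > 0 else 1.0
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize_int8(q, scale):
    return q.astype(np.float32) * scale

prototypes = np.array([[0.5, -1.27], [0.03, 1.0]], dtype=np.float32)
q, scale = quantize_int8(prototypes)
restored = dequantize_int8(q, scale)
# Storage drops from 4 bytes to 1 byte per value; the rounding error
# per value is at most half the quantization step (scale / 2).
```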

Export for Deployment

Export a JIT-compiled prediction function using jax.export for deployment without the full model or prosemble dependency:

from prosemble.models import GLVQ

model = GLVQ(n_prototypes_per_class=2, max_iter=100, lr=0.01)
model.fit(X_train, y_train)

# Export compiled predict function for batch size 32
exported = model.export_predict(batch_size=32)

# Serialize to bytes
blob = exported.serialize()

# Later, in a separate process without prosemble:
from jax import export
loaded = export.deserialize(blob)
preds = loaded.call(X_batch)

The exported function contains only the compiled XLA computation and the frozen prototype/metric parameters — no Python dependencies needed at inference time.

ONNX Export

Export fitted models to ONNX for deployment without JAX or prosemble. 68 of 71 models are supported, including encoder models (MLP/CNN backbones), one-class classifiers, and fuzzy clustering:

from prosemble.core.onnx_export import export_onnx

model.fit(X_train, y_train)
onnx_model = export_onnx(model, path='model.onnx')

# Run with ONNX Runtime
import onnxruntime as ort
session = ort.InferenceSession('model.onnx')
# X_test_np: the test inputs as a NumPy array, e.g. np.asarray(X_test)
preds = session.run(None, {'X': X_test_np})[0]

See the full guide: ONNX Export.

Prototype Analysis

Prototype Win Ratios

Analyze how often each prototype wins on correctly classified samples. This helps identify “dead” prototypes that never win and may be candidates for removal or reinitialization:

from prosemble.models import GLVQ

model = GLVQ(n_prototypes_per_class=3, max_iter=100, lr=0.01)
model.fit(X_train, y_train)

ratios = model.prototype_win_ratios(X_train, y_train)

for i, r in enumerate(ratios):
    label = model.prototype_labels_[i]
    print(f"Prototype {i} (class {label}): win ratio {r:.3f}")

A win ratio of 0.0 means the prototype never won on any correctly classified sample and is effectively unused.
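
The idea behind the win ratio can be sketched in NumPy. The version below assumes squared Euclidean distance and normalizes by the number of correctly classified samples; the exact definition used by prototype_win_ratios may differ, and the helper name win_ratios is hypothetical.

```python
import numpy as np

def win_ratios(X, y, prototypes, proto_labels):
    # Squared Euclidean distance from every sample to every prototype.
    d = ((X[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)
    winners = d.argmin(axis=1)                  # nearest prototype per sample
    correct = proto_labels[winners] == y        # winner has the right class
    wins = np.bincount(winners[correct], minlength=len(prototypes))
    return wins / max(int(correct.sum()), 1)

X = np.array([[0.0], [0.1], [1.0], [1.1]])
y = np.array([0, 0, 1, 1])
prototypes = np.array([[0.0], [5.0], [1.0]])
proto_labels = np.array([0, 0, 1])
ratios = win_ratios(X, y, prototypes, proto_labels)
# The prototype at 5.0 is far from all data and never wins.
```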

GPU Acceleration

JAX automatically uses available GPUs. Check your device:

import jax
print(jax.devices())          # e.g., [CudaDevice(id=0)]; the repr varies by JAX version
print(jax.default_backend())  # 'gpu', 'cpu', or 'tpu'

All prosemble operations (training, inference, distance computation) run on whatever device JAX is configured to use. No code changes are needed.

Multi-Device Data Parallelism

For large datasets or image models with trainable backbones, prosemble supports data-parallel training across multiple GPUs or TPUs using JAX’s modern sharding API.

import jax
from prosemble.models import GLVQ

# Use all available devices
devices = jax.devices()  # e.g., [CudaDevice(id=0), CudaDevice(id=1), ...]

model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    devices=devices,
    use_scan=False,
)
model.fit(X_train, y_train)

When devices is specified:

  • Training data is sharded along the batch dimension across devices

  • Model parameters and optimizer state are replicated on all devices

  • Gradients are automatically aggregated (all-reduced) across devices

  • After training, parameters are moved back to a single device for inference

Requirements:

  • batch_size (if set) must be divisible by the number of devices

  • The number of training samples should be divisible by the number of devices

  • use_scan=False is recommended for multi-device training
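
One simple way to satisfy the divisibility requirement on the number of training samples is to drop the remainder before calling fit(). The helper below is a hypothetical convenience, not part of prosemble; padding the data would be an alternative.

```python
def trim_to_multiple(n_samples, n_devices):
    # Largest sample count <= n_samples that splits evenly across devices.
    return n_samples - (n_samples % n_devices)

n_keep = trim_to_multiple(1003, 4)
# With 1003 samples and 4 devices, keep the first 1000 rows,
# e.g. X_train[:n_keep], y_train[:n_keep].
```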

Primary use case: Image variants (ImageGLVQ, ImageGMLVQ, ImageGTLVQ, ImageCBC) with trainable CNN backbones on large datasets across multiple GPUs/TPUs.

import jax
from prosemble.models import ImageGLVQ

model = ImageGLVQ(
    n_prototypes_per_class=1,
    max_iter=50,
    lr=0.001,
    devices=jax.devices(),
    gradient_checkpointing=True,  # save memory for deep backbones
    use_scan=False,
)
model.fit(X_images, y_labels)

partial_fit() also works with multi-device models:

model.fit(X_batch1, y_batch1)
model.partial_fit(X_batch2, y_batch2)  # data is automatically sharded

Gradient Checkpointing

Gradient checkpointing (jax.remat) trades compute for memory by recomputing forward activations during the backward pass instead of storing them. This is beneficial when training models with deep backbones (Image, Siamese variants) that would otherwise exhaust GPU memory.

from prosemble.models import GLVQ

model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    gradient_checkpointing=True,  # no effect for plain GLVQ; matters for deep backbones
)
model.fit(X_train, y_train)

When gradient_checkpointing=True:

  • Forward activations are not stored during the forward pass

  • During backpropagation, the forward pass is recomputed to obtain activations needed for gradient computation

  • Memory usage is reduced at the cost of ~33% more compute

  • Results are numerically identical to standard training

This option has no effect on models with shallow loss functions (e.g., standard GLVQ with Euclidean distance). It provides significant memory savings when unfreezing deep CNN backbones in Image or Siamese variants.

Custom Gradient Rules

Prosemble provides infrastructure for subclasses to define custom backward passes using jax.custom_vjp. This is useful for:

  • Optimal transport distances with non-differentiable sorting operations

  • Hard thresholding operations that need surrogate gradients

  • Distances whose standard autodiff gradients are numerically unstable

Subclasses opt in by setting self._use_custom_vjp = True in their __init__ and overriding the _custom_vjp_loss method:

from prosemble.models import GLVQ

class MyModel(GLVQ):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._use_custom_vjp = True

    def _custom_vjp_loss(self, params, X, y, proto_labels):
        # Implement custom forward + backward pass
        ...

When _use_custom_vjp is active, gradient checkpointing is automatically disabled (custom VJP already controls what gets saved/recomputed).
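
Independently of prosemble's hooks, the underlying JAX mechanism looks like this. The sketch defines a hard threshold whose backward pass uses a sigmoid-derivative surrogate gradient; the function names and the choice of surrogate are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def hard_threshold(x):
    # Forward pass: non-differentiable step function.
    return (x > 0).astype(x.dtype)

def _fwd(x):
    # Save x as the residual needed by the backward pass.
    return hard_threshold(x), x

def _bwd(x, g):
    # Surrogate backward pass: use the sigmoid derivative instead of
    # the true gradient, which is zero almost everywhere.
    s = jax.nn.sigmoid(x)
    return (g * s * (1.0 - s),)

hard_threshold.defvjp(_fwd, _bwd)

x = jnp.array([-1.0, 0.5, 2.0])
y = hard_threshold(x)                                    # step outputs
grad = jax.grad(lambda v: hard_threshold(v).sum())(x)    # surrogate gradients
```

Without the custom rule, jax.grad would return zeros here and the thresholded quantity could not be trained by gradient descent.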

Training Callbacks

Monitor training with callbacks. Training history is automatically tracked for all models:

model.fit(X_train, y_train)

# Supervised models
print(model.loss_history_)

# Clustering models
print(model.objective_history_)

Live Visualization

Enable real-time training visualization for clustering models:

from prosemble.models import FCM

model = FCM(
    n_clusters=3,
    max_iter=100,
    plot_steps=True,
)
model.fit(X)