Advanced Features

Prosemble leverages JAX for GPU acceleration, JIT compilation, and functional training loops. This guide covers all advanced features available for supervised prototype models and fuzzy clustering models.

JIT-Compiled Inference

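Model inference (predict, predict_proba) is automatically JIT-compiled for maximum speed. The first call triggers compilation. Later calls with inputs of the same shape and dtype reuse the cached compilation. An input with a new shape, such as a different batch size, triggers a recompile.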

from prosemble.models import GLVQ

model = GLVQ(n_prototypes_per_class=2, max_iter=100, lr=0.01)
model.fit(X_train, y_train)

# First call: triggers JIT compilation
preds = model.predict(X_test)

# Subsequent calls: use cached compiled function (fast)
preds = model.predict(X_test)
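You can see the caching rule with plain jax.jit. A Python side effect inside a jitted function runs only while JAX traces it, so a counter records every (re)compilation. This is generic JAX behavior, not prosemble's code. nearest_prototype is a hypothetical stand-in for predict.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces (compiles) the function

@jax.jit
def nearest_prototype(X, prototypes):  # hypothetical stand-in for predict
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    # Squared Euclidean distances, shape (n_samples, n_prototypes)
    d = jnp.sum((X[:, None, :] - prototypes[None, :, :]) ** 2, axis=-1)
    return jnp.argmin(d, axis=1)

prototypes = jnp.array([[0.0, 0.0], [1.0, 1.0]])

nearest_prototype(jnp.zeros((8, 2)), prototypes)   # traces and compiles
nearest_prototype(jnp.ones((8, 2)), prototypes)    # same shape/dtype: cache hit
nearest_prototype(jnp.zeros((16, 2)), prototypes)  # new batch size: retraces
print(trace_count)  # 2
```

If inference batches vary in size, padding them to a fixed size avoids repeated compilation.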

Training Loop Modes

lax.scan vs Python Loop

By default, supervised models use jax.lax.scan, which compiles the entire training loop into a single XLA program. This removes per-iteration Python overhead and is the fastest option.

Set use_scan=False to switch to a Python loop, which is required when using callbacks, patience-based early stopping, or validation monitoring.

from prosemble.models import GLVQ

# Default: lax.scan (fastest, no callbacks)
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    use_scan=True,       # default
)
model.fit(X_train, y_train)

# Python loop: needed for callbacks, patience, validation
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    use_scan=False,
)
model.fit(X_train, y_train)

Models that always use Python loops regardless of use_scan:

  • Growing Neural Gas (dynamic topology)

  • Riemannian Neural Gas (manifold operations)

  • Median LVQ (combinatorial M-step)

Fuzzy clustering models use lax.scan by default and fall back to a Python loop when callbacks, patience, or restore_best are enabled.
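The sketch below contrasts the two loop styles on a toy least-squares problem. It is not prosemble's internal loop, and loss_fn, train_scan and train_python are hypothetical helpers. The scan version compiles all steps into one program. The Python version returns to the interpreter after every step, which lets it run a callback and stop early.

```python
import jax
import jax.numpy as jnp

def loss_fn(w, X, y):
    # Toy least-squares loss standing in for a model loss (hypothetical)
    return jnp.mean((X @ w - y) ** 2)

grad_fn = jax.grad(loss_fn)

def train_scan(w0, X, y, lr, n_steps):
    # The whole loop is traced once and compiled into a single XLA program
    def step(w, _):
        w = w - lr * grad_fn(w, X, y)
        return w, loss_fn(w, X, y)  # (carry, per-step output = loss history)
    return jax.lax.scan(step, w0, xs=None, length=n_steps)

def train_python(w0, X, y, lr, n_steps, callback=None):
    # Each iteration returns to Python, so callbacks and early exits are possible
    w, history = w0, []
    for i in range(n_steps):
        w = w - lr * grad_fn(w, X, y)
        history.append(float(loss_fn(w, X, y)))
        if callback is not None and callback(i, history[-1]):
            break  # e.g., a patience rule decided to stop
    return w, jnp.array(history)

X = jax.random.normal(jax.random.PRNGKey(0), (32, 3))
y = X @ jnp.array([1.0, -2.0, 0.5])
w0 = jnp.zeros(3)

# n_steps (argument 4) must be static because it fixes the scan length
w_scan, hist_scan = jax.jit(train_scan, static_argnums=(4,))(w0, X, y, 0.1, 50)
w_loop, hist_loop = train_python(w0, X, y, 0.1, 50)
```

scan traces its body only once. A Python callback or `break` placed inside `step` would therefore run only during tracing, not on every iteration. This is why callbacks and patience need the Python loop.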

Mini-Batch Training

For large datasets, supervised models support mini-batch training. Each iteration samples a random batch from the training data:

from prosemble.models import GLVQ

model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=200,
    lr=0.01,
    batch_size=64,
)
model.fit(X_large, y_large)
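The following shows one common way to draw a random mini-batch per iteration in JAX. It assumes uniform sampling without replacement, with a fresh key derived from the iteration number. Whether prosemble samples with or without replacement is not stated here, and sample_batch is a hypothetical helper.

```python
import jax
import jax.numpy as jnp

def sample_batch(key, X, y, batch_size):
    # Hypothetical helper: draw batch_size distinct row indices uniformly at random
    idx = jax.random.choice(key, X.shape[0], shape=(batch_size,), replace=False)
    return X[idx], y[idx]

X = jnp.arange(1000.0).reshape(500, 2)  # row i is [2i, 2i + 1]
y = jnp.arange(500) % 3

key = jax.random.PRNGKey(0)
for it in range(3):
    # A fresh, reproducible key per iteration
    Xb, yb = sample_batch(jax.random.fold_in(key, it), X, y, batch_size=64)
    print(it, Xb.shape, yb.shape)  # shapes: (64, 2) (64,)
```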

Online / Incremental Learning

Use partial_fit() to train on streaming data or to update an existing model with new batches. The model must first be fitted via fit(). Each later partial_fit() call performs a single gradient update and keeps the optimizer state across calls.

from prosemble.models import GLVQ

model = GLVQ(n_prototypes_per_class=2, max_iter=100, lr=0.01)

# Initial training
model.fit(X_train, y_train)

# Incremental updates with new data batches
for X_batch, y_batch in data_stream:
    model.partial_fit(X_batch, y_batch)

# Prototypes and optimizer state are preserved across calls
preds = model.predict(X_test)

All supervised models support partial_fit().

Sample Weighting and Class Balancing

Per-Sample Weights

Pass sample_weight to fit() to assign different importance to individual training samples. Samples with higher weights have more influence on the loss:

import jax.numpy as jnp
from prosemble.models import GLVQ

model = GLVQ(n_prototypes_per_class=2, max_iter=100, lr=0.01)

# Per-sample weights, one per training sample (e.g., higher weight for difficult
# samples); '...' stands for the remaining entries
weights = jnp.array([1.0, 1.0, 2.0, 1.0, 3.0, ...])
model.fit(X_train, y_train, sample_weight=weights)
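A common way for per-sample weights to enter the loss is a weighted average of per-sample losses. The sketch assumes normalization by the sum of the weights; some implementations divide by the number of samples instead, and prosemble's choice is not stated here. weighted_mean_loss is a hypothetical helper.

```python
import jax.numpy as jnp

def weighted_mean_loss(per_sample_loss, sample_weight):
    # Hypothetical helper: samples with larger weights contribute proportionally more
    return jnp.sum(sample_weight * per_sample_loss) / jnp.sum(sample_weight)

losses = jnp.array([0.2, 0.4, 1.0])
weights = jnp.array([1.0, 1.0, 2.0])
print(weighted_mean_loss(losses, weights))  # ≈ 0.65, i.e. (0.2 + 0.4 + 2 * 1.0) / 4
```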

Class Balancing

For imbalanced datasets, use class_weight to automatically compute per-sample weights inversely proportional to class frequency:

from prosemble.models import GLVQ

# Automatic balanced weighting
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    class_weight='balanced',
)
model.fit(X_train, y_train)

# Or specify weights per class manually
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    class_weight={0: 1.0, 1: 5.0, 2: 1.0},
)
model.fit(X_train, y_train)

With class_weight='balanced', the weight for class c is computed as:

\[w_c = \frac{N}{K \cdot N_c}\]

where N is the total number of samples, K is the number of classes, and N_c is the number of samples in class c.
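The formula can be computed in a few lines. balanced_class_weights and to_sample_weights are hypothetical helpers, not prosemble functions. Under this formula every class ends up with the same total weight, N/K.

```python
import numpy as np

def balanced_class_weights(y):
    # Hypothetical helper: w_c = N / (K * N_c) for each class c present in y
    classes, counts = np.unique(y, return_counts=True)
    N, K = len(y), len(classes)
    return {int(c): N / (K * n_c) for c, n_c in zip(classes, counts)}

def to_sample_weights(y, class_weight):
    # Each sample inherits the weight of its class
    return np.array([class_weight[int(label)] for label in y])

y = np.array([0, 0, 0, 1])
cw = balanced_class_weights(y)  # class 0: 4 / (2 * 3) ≈ 0.667, class 1: 4 / (2 * 1) = 2.0
sw = to_sample_weights(y, cw)   # per-sample weights ≈ [0.667, 0.667, 0.667, 2.0]
```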

Validation and Early Stopping

Validation Monitoring

Pass validation data to fit() to monitor performance on held-out samples during training. Combined with restore_best=True, the model automatically restores the parameters that achieved the lowest validation loss:

from prosemble.models import GLVQ

model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=200,
    lr=0.01,
    restore_best=True,
    use_scan=False,      # required for validation monitoring
)
model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
)

# model.prototypes_ now holds the best params from training
print(f"Best validation loss: {model.best_loss_}")

Patience-Based Early Stopping

Stop training early when the monitored loss has not improved for patience consecutive iterations:

from prosemble.models import GLVQ

model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=500,
    lr=0.01,
    patience=20,
    restore_best=True,
    use_scan=False,
)
model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
)

print(f"Stopped at iteration {model.n_iter_}")

Fuzzy clustering models also support patience-based early stopping:

from prosemble.models import FCM

model = FCM(
    n_clusters=3,
    max_iter=500,
    patience=10,
    restore_best=True,
)
model.fit(X)

print(f"Stopped at iteration {model.n_iter_}")
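The patience rule can be sketched on a recorded loss sequence. The sketch assumes that "improvement" means a strictly lower loss than the best so far, with no minimum delta. prosemble's exact comparison may differ, and early_stopping is a hypothetical helper. With restore_best=True, the parameters from best_iter are the ones kept.

```python
def early_stopping(losses, patience):
    """Hypothetical helper: return (best_iter, stop_iter) for monitored losses."""
    best_loss, best_iter, wait = float("inf"), -1, 0
    for i, loss in enumerate(losses):
        if loss < best_loss:  # improvement: remember it, reset the counter
            best_loss, best_iter, wait = loss, i, 0
        else:
            wait += 1
            if wait >= patience:  # patience exhausted: stop here
                return best_iter, i
    return best_iter, len(losses) - 1  # ran to max_iter

losses = [1.0, 0.8, 0.6, 0.65, 0.7, 0.62, 0.9]
best_iter, stop_iter = early_stopping(losses, patience=3)
print(best_iter, stop_iter)  # 2 5
```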

Optimizer Configuration

Prosemble supports 26 optimizers via optax. Pass a string name or a pre-built optax.GradientTransformation to the optimizer parameter:

from prosemble.models import GLVQ

# String name (26 options)
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    optimizer='adamw',
)
model.fit(X_train, y_train)

Available optimizer strings: adam, adamw, adamax, adamaxw, adan, adabelief, amsgrad, radam, lamb, lion, novograd, sgd, sign_sgd, signum, noisy_sgd, lars, rmsprop, adagrad, adadelta, adafactor, sm3, yogi, rprop, fromage, lbfgs, dpsgd.

You can also pass a custom optax optimizer directly:

import optax
from prosemble.models import GLVQ

custom_opt = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate=0.001),
)

model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    optimizer=custom_opt,
)
model.fit(X_train, y_train)

Learning Rate Scheduling

Decay or warm up the learning rate during training using the lr_scheduler parameter. Nine built-in schedules are available:

from prosemble.models import GLVQ

# Cosine decay: lr decays smoothly from 0.01 to 0 over training
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=200,
    lr=0.01,
    lr_scheduler='cosine_decay',
)
model.fit(X_train, y_train)

# Exponential decay with custom decay rate
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=200,
    lr=0.01,
    lr_scheduler='exponential_decay',
    lr_scheduler_kwargs={'decay_rate': 0.95, 'transition_steps': 1},
)
model.fit(X_train, y_train)

# Warmup then cosine decay
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=200,
    lr=0.01,
    lr_scheduler='warmup_cosine_decay',
    lr_scheduler_kwargs={
        'warmup_steps': 20,
        'peak_value': 0.01,
        'end_value': 0.0,
    },
)
model.fit(X_train, y_train)

Available schedules:

Schedule                    Description
exponential_decay           Exponential decay: lr * decay_rate^(step / transition_steps)
cosine_decay                Cosine annealing from lr to 0
warmup_cosine_decay         Linear warmup, then cosine decay
warmup_exponential_decay    Linear warmup, then exponential decay
warmup_constant             Linear warmup, then constant learning rate
polynomial                  Polynomial decay with configurable power
linear                      Linear decay from lr to end_value
piecewise_constant          Step-wise constant schedule with boundaries
sgdr                        Cosine annealing with warm restarts (SGDR)
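The three schedules used in the examples can be written out as plain functions. They follow the usual formulas, and the one for exponential_decay matches the table. The warmup is assumed to start from 0 and the cosine phase to span the remaining steps; prosemble's exact conventions may differ. All function names here are hypothetical.

```python
import math

def exponential_decay(step, lr, decay_rate, transition_steps):
    # lr * decay_rate^(step / transition_steps)
    return lr * decay_rate ** (step / transition_steps)

def cosine_decay(step, lr, total_steps):
    # Smoothly anneal from lr (step 0) to 0 (step total_steps)
    t = min(step, total_steps) / total_steps
    return 0.5 * lr * (1.0 + math.cos(math.pi * t))

def warmup_cosine_decay(step, peak_value, warmup_steps, total_steps, end_value=0.0):
    if step < warmup_steps:
        # Linear warmup from 0 to peak_value (assumed starting value)
        return peak_value * step / warmup_steps
    # Cosine decay from peak_value to end_value over the remaining steps
    span = total_steps - warmup_steps
    t = min(step - warmup_steps, span) / span
    return end_value + 0.5 * (peak_value - end_value) * (1.0 + math.cos(math.pi * t))

print(exponential_decay(2, 0.01, 0.95, 1))     # ≈ 0.009025
print(cosine_decay(100, 0.01, 200))            # ≈ 0.005 (halfway)
print(warmup_cosine_decay(10, 0.01, 20, 200))  # ≈ 0.005 (mid-warmup)
```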

Lookahead Optimizer

Lookahead maintains two sets of weights. The “fast” weights are updated at every step. The “slow” weights move part of the way toward the fast weights every sync_period steps, with the fraction set by slow_step_size. This can improve generalization and reduce the variance of the updates:

from prosemble.models import GLVQ

model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=200,
    lr=0.01,
    lookahead={
        'sync_period': 6,         # sync slow weights every 6 steps
        'slow_step_size': 0.5,    # interpolation factor
    },
    use_scan=False,               # required for lookahead
)
model.fit(X_train, y_train)

After training, the slow weights, which tend to generalize better, are used for inference.
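The Lookahead update can be sketched around plain SGD on a quadratic bowl. lookahead_sgd is a hypothetical helper, not prosemble's code. Its parameter names mirror the sync_period and slow_step_size keys above, and optax also ships an optax.lookahead wrapper that uses the same two names.

```python
import numpy as np

def lookahead_sgd(grad_fn, x0, lr=0.1, sync_period=6, slow_step_size=0.5, n_steps=60):
    # Hypothetical helper: Lookahead wrapped around plain SGD
    slow = np.array(x0, dtype=float)
    fast = slow.copy()
    for t in range(1, n_steps + 1):
        fast = fast - lr * grad_fn(fast)  # inner ("fast") optimizer step
        if t % sync_period == 0:
            # Move the slow weights part of the way toward the fast weights,
            # then restart the fast weights from the new slow weights
            slow = slow + slow_step_size * (fast - slow)
            fast = slow.copy()
    return slow  # the slow weights are what inference would use

# Quadratic bowl f(x) = 0.5 * ||x||^2, whose gradient is x
x_final = lookahead_sgd(lambda x: x, x0=[4.0, -2.0])
```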

Gradient Accumulation

Accumulate gradients over multiple steps before applying an update, effectively increasing the batch size without increasing memory usage:

from prosemble.models import GLVQ

model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=200,
    lr=0.01,
    batch_size=16,
    gradient_accumulation_steps=4,  # effective batch size = 16 * 4 = 64
)
model.fit(X_train, y_train)
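Accumulation gives the large-batch gradient exactly when the loss is a mean and the micro-batches have equal size. The check below compares one batch of 64 against 4 accumulated micro-batches of 16 on a toy loss, not a prosemble model. optax.MultiSteps is an existing optax wrapper that implements this pattern for an arbitrary optimizer.

```python
import jax
import jax.numpy as jnp

def loss_fn(w, X, y):
    return jnp.mean((X @ w - y) ** 2)  # toy loss: mean over the (micro-)batch

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (64, 5))
y = jax.random.normal(key_y, (64,))
w = jnp.ones(5)

# Full batch of 64 in one pass
g_full = jax.grad(loss_fn)(w, X, y)

# 4 micro-batches of 16: only one micro-batch's activations are live at a time
accum = jnp.zeros_like(w)
for Xm, ym in zip(jnp.split(X, 4), jnp.split(y, 4)):
    accum = accum + jax.grad(loss_fn)(w, Xm, ym)
g_accum = accum / 4  # average of the micro-batch gradients

print(jnp.allclose(g_full, g_accum, atol=1e-5))  # True
```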

Parameter Freezing

Freeze specific parameters during training. Frozen parameters receive zero gradients and remain at their initial values:

from prosemble.models import GMLVQ

model = GMLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    freeze_params=['omega'],       # freeze the metric matrix
)
model.fit(X_train, y_train)

# Only prototypes were updated; omega stayed at its initial value

Common use cases:

  • Freeze 'omega' in GMLVQ to train prototypes with a fixed metric

  • Freeze 'prototypes' to learn only the metric adaptation matrix

  • Two-phase training: first train prototypes, then freeze them and train omega
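The sketch below freezes a parameter group by zeroing its update. The parameter dict and the GMLVQ-like toy loss are hypothetical, and this is not prosemble's mechanism. Masking the update after the optimizer, rather than only the gradient, is the safer choice with decoupled weight decay (e.g., adamw). Weight decay would otherwise still shrink a parameter whose gradient is zero.

```python
import jax
import jax.numpy as jnp

def apply_frozen_mask(updates, frozen):
    # Hypothetical helper: zero the update of every parameter group in `frozen`
    return {name: jnp.zeros_like(u) if name in frozen else u
            for name, u in updates.items()}

def loss_fn(params, X):
    # Toy GMLVQ-like distance ||Omega (x - w)||^2 to a single prototype
    diff = (X - params["prototypes"]) @ params["omega"].T
    return jnp.mean(jnp.sum(diff ** 2, axis=1))

params = {"prototypes": jnp.array([1.0, -1.0]), "omega": jnp.eye(2)}
X = jnp.array([[0.0, 0.0], [2.0, 1.0]])
omega_init = params["omega"]

for _ in range(10):
    grads = jax.grad(loss_fn)(params, X)
    updates = jax.tree_util.tree_map(lambda g: -0.1 * g, grads)  # plain SGD
    updates = apply_frozen_mask(updates, frozen={"omega"})
    params = jax.tree_util.tree_map(lambda p, u: p + u, params, updates)
```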

Exponential Moving Average

Maintain an exponential moving average of parameters during training. EMA parameters often generalize better than the final training parameters:

from prosemble.models import GLVQ

model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=200,
    lr=0.01,
    ema_decay=0.999,
    use_scan=False,
)
model.fit(X_train, y_train)

The EMA update rule at each step is:

\[\theta_{\text{ema}} = \alpha \cdot \theta_{\text{ema}} + (1 - \alpha) \cdot \theta\]

where \(\alpha\) is the ema_decay factor.
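The update rule behaves as follows over a short run. The EMA roughly averages the last 1/(1 - α) steps, which is 1000 steps for α = 0.999. In a run as short as max_iter=200 in the example, the EMA therefore still carries much of its initial value unless bias correction is applied. This guide does not say whether prosemble corrects for that, so a smaller decay may suit short runs. ema_update is a hypothetical helper.

```python
def ema_update(theta_ema, theta, alpha):
    # theta_ema <- alpha * theta_ema + (1 - alpha) * theta
    return alpha * theta_ema + (1.0 - alpha) * theta

alpha, n_steps = 0.999, 200
theta_ema = 0.0  # the EMA starts at the initial parameter value (0.0 here)
for _ in range(n_steps):
    theta_ema = ema_update(theta_ema, 1.0, alpha)  # the parameter sits at 1.0

print(theta_ema)                            # 1 - 0.999**200 ≈ 0.181
print(theta_ema / (1 - alpha ** n_steps))   # bias-corrected: ≈ 1.0
```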

Mixed Precision

Supervised models support built-in mixed precision training via the mixed_precision parameter. Master weights stay in float32; the forward/backward pass runs in lower precision for faster computation and lower memory on GPU.

import jax.numpy as jnp
from prosemble.models import GLVQ

# bfloat16 — recommended for training stability
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    mixed_precision='bfloat16',
)
model.fit(X_train, y_train)

# float16 — maximum speed, uses loss scaling to prevent underflow
model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    mixed_precision='float16',
)
model.fit(X_train, y_train)

# Prototypes remain in float32
assert model.prototypes_.dtype == jnp.float32
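The master-weight pattern looks like this in plain JAX: the forward computation is cast to bfloat16 while the parameters stay float32. This is a sketch of the idea, not prosemble's implementation, and loss_fn is hypothetical. float16 would also need loss scaling, where the loss is multiplied by a constant S before differentiation and the gradients are divided by S. That step is omitted here.

```python
import jax
import jax.numpy as jnp

def loss_fn(prototypes, X, compute_dtype=jnp.bfloat16):
    # Hypothetical loss: cast float32 master weights and inputs to the compute
    # dtype, so the forward (and hence backward) arithmetic runs in bfloat16
    p = prototypes.astype(compute_dtype)
    x = X.astype(compute_dtype)
    d = jnp.sum((x[:, None, :] - p[None, :, :]) ** 2, axis=-1)
    # Return the scalar loss in float32 for a stable reduction result
    return jnp.mean(jnp.min(d, axis=1)).astype(jnp.float32)

prototypes = jnp.array([[0.0, 0.0], [1.0, 1.0]], dtype=jnp.float32)
X = jax.random.normal(jax.random.PRNGKey(0), (128, 2))

grads = jax.grad(loss_fn)(prototypes, X)
# Gradients come back in the master-weight dtype: the transpose of the
# float32 -> bfloat16 cast converts the cotangent back to float32
prototypes = prototypes - 0.01 * grads
print(grads.dtype, prototypes.dtype)  # float32 float32
```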

Checkpointing and Resume

Save and load models for persistence. All fitted parameters, optimizer state, and hyperparameters are preserved:

from prosemble.models import GLVQ

model = GLVQ(n_prototypes_per_class=2, max_iter=100, lr=0.01)
model.fit(X_train, y_train)

# Save
model.save('my_model.npz')

# Load
loaded = GLVQ.load('my_model.npz')

# Resume training from checkpoint
loaded.fit(X_train, y_train, resume=True)

Model Quantization

Reduce model size for deployment by quantizing prototypes and parameters when saving. The model in memory is unchanged; only the saved file is quantized:

model.fit(X_train, y_train)

# Quantize to int8 on save (smallest file size)
model.save('model_int8.npz', quantize='int8')

# Quantize to float16 on save (balanced)
model.save('model_f16.npz', quantize='float16')

# Load quantized model
loaded = GLVQ.load('model_int8.npz')
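One common int8 scheme is symmetric per-tensor quantization. Each float is stored as an 8-bit integer plus one shared scale, which makes the array about 4x smaller than float32. This guide does not document prosemble's exact scheme (per-tensor or per-channel, symmetric or asymmetric), so treat this as an illustration. quantize_int8 and dequantize_int8 are hypothetical helpers.

```python
import numpy as np

def quantize_int8(w):
    # Symmetric per-tensor quantization: map [-max|w|, max|w|] onto [-127, 127]
    scale = float(np.max(np.abs(w))) / 127.0
    if scale == 0.0:
        scale = 1.0  # all-zero tensor: any scale works
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize_int8(q, scale):
    return q.astype(np.float32) * scale

w = np.random.default_rng(0).normal(size=(6, 4)).astype(np.float32)
q, scale = quantize_int8(w)
w_restored = dequantize_int8(q, scale)
# Rounding error per entry is at most half a quantization step
print(np.max(np.abs(w - w_restored)) <= scale / 2 + 1e-6)  # True
```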

Export for Deployment

Export a JIT-compiled prediction function using jax.export for deployment without the full model or prosemble dependency:

from prosemble.models import GLVQ

model = GLVQ(n_prototypes_per_class=2, max_iter=100, lr=0.01)
model.fit(X_train, y_train)

# Export compiled predict function for batch size 32
exported = model.export_predict(batch_size=32)

# Serialize to bytes
blob = exported.serialize()

# Later, in a separate process (JAX is required, prosemble is not):
import jax
loaded = jax.export.deserialize(blob)
preds = loaded.call(X_batch)  # X_batch must match the exported shape (32 samples)

The exported artifact contains only the serialized, lowered computation (StableHLO) together with the frozen prototype/metric parameters. prosemble is not needed at inference time, but JAX is required to deserialize and run the function.
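Here is a self-contained illustration of the underlying jax.export round trip. A hypothetical nearest-prototype predict function is exported for a fixed input signature, serialized, and called again after deserialization. It requires a JAX version that provides jax.export, and it is not what export_predict does internally. Like export_predict(batch_size=32), the input shape is fixed at export time, so a smaller final batch has to be padded.

```python
import jax
import jax.numpy as jnp
from jax import export

prototypes = jnp.array([[0.0, 0.0], [3.0, 3.0]])  # frozen, fitted values (toy)
prototype_labels = jnp.array([0, 1])

def predict(X):
    # Hypothetical predict: prototypes/labels are closed over and baked in as constants
    d = jnp.sum((X[:, None, :] - prototypes[None, :, :]) ** 2, axis=-1)
    return prototype_labels[jnp.argmin(d, axis=1)]

# Export for a fixed signature: a batch of 32 float32 samples with 2 features
exported = export.export(jax.jit(predict))(
    jax.ShapeDtypeStruct((32, 2), jnp.float32))
blob = exported.serialize()  # bytes-like, can be written to disk

# In the serving process: only JAX is needed to deserialize and call it
restored = export.deserialize(blob)
X_batch = jnp.concatenate([jnp.zeros((16, 2)), jnp.full((16, 2), 3.0)])
preds = restored.call(X_batch)
```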

ONNX Export

Export fitted models to ONNX for deployment without JAX or prosemble. 68 of 71 models are supported, including encoder models (MLP/CNN backbones), one-class classifiers, and fuzzy clustering:

from prosemble.core.onnx_export import export_onnx

model.fit(X_train, y_train)
onnx_model = export_onnx(model, path='model.onnx')

# Run with ONNX Runtime
import onnxruntime as ort
session = ort.InferenceSession('model.onnx')
preds = session.run(None, {'X': X_test_np})[0]

See the full guide: ONNX Export.

Prototype Analysis

Prototype Win Ratios

Analyze how often each prototype wins on correctly classified samples. This helps identify “dead” prototypes that never win and may be candidates for removal or reinitialization:

from prosemble.models import GLVQ

model = GLVQ(n_prototypes_per_class=3, max_iter=100, lr=0.01)
model.fit(X_train, y_train)

ratios = model.prototype_win_ratios(X_train, y_train)

for i, r in enumerate(ratios):
    label = model.prototype_labels_[i]
    print(f"Prototype {i} (class {label}): win ratio {r:.3f}")

A win ratio of 0.0 means the prototype never won on any correctly classified sample and is effectively unused.
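A win ratio can be computed roughly as follows. The sketch assumes squared Euclidean distance and normalization by the total number of correctly classified samples. prosemble may normalize differently (for example, per class), and prototype_win_ratios here is a hypothetical reimplementation, not the library method.

```python
import numpy as np

def prototype_win_ratios(X, y, prototypes, prototype_labels):
    # Squared Euclidean distance from every sample to every prototype
    d = ((X[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)
    winner = d.argmin(axis=1)                # index of the closest prototype
    correct = prototype_labels[winner] == y  # correctly classified samples
    wins = np.bincount(winner[correct], minlength=len(prototypes))
    return wins / max(int(correct.sum()), 1)  # share of the correct wins

X = np.array([[0.0, 0.1], [0.2, 0.0], [5.0, 5.0], [5.1, 4.9]])
y = np.array([0, 0, 1, 1])
prototypes = np.array([[0.0, 0.0], [9.0, 9.0], [5.0, 5.0]])
prototype_labels = np.array([0, 0, 1])

ratios = prototype_win_ratios(X, y, prototypes, prototype_labels)
# Prototype 1 never wins, so its ratio is 0.0 (a "dead" prototype)
```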

GPU Acceleration

JAX automatically uses available GPUs. Check your device:

import jax
print(jax.devices())        # e.g. [GpuDevice(id=0)]; the exact repr varies by JAX version
print(jax.default_backend())  # 'gpu' or 'cpu'

All prosemble operations (training, inference, distance computation) run on whatever device JAX is configured to use. No code changes needed.

Multi-Device Data Parallelism

For large datasets or image models with trainable backbones, prosemble supports data-parallel training across multiple GPUs or TPUs using JAX’s modern sharding API.

import jax
from prosemble.models import GLVQ

# Use all available devices
devices = jax.devices()  # e.g., [GpuDevice(id=0), GpuDevice(id=1), ...]

model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    devices=devices,
    use_scan=False,
)
model.fit(X_train, y_train)

When devices is specified:

  • Training data is sharded along the batch dimension across devices

  • Model parameters and optimizer state are replicated on all devices

  • Gradients are automatically aggregated (all-reduced) across devices

  • After training, parameters are moved back to a single device for inference

Requirements:

  • batch_size (if set) must be divisible by the number of devices

  • The number of training samples should be divisible by the number of devices

  • use_scan=False is recommended for multi-device training
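The bullets above correspond to JAX's public sharding API, sketched here with a toy loss rather than prosemble's own code. The data is split along the batch dimension, the weights are replicated, and under jit XLA computes per-shard gradients and inserts the all-reduce. The sketch also runs on a single device, where the mesh has one entry.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
data_sharding = NamedSharding(mesh, P("data"))  # split the batch dimension
replicated = NamedSharding(mesh, P())           # full copy on every device

def loss_fn(w, X, y):
    return jnp.mean((X @ w - y) ** 2)  # toy loss

n_dev = len(devices)
X = jax.random.normal(jax.random.PRNGKey(0), (8 * n_dev, 3))  # divisible by n_dev
y = X @ jnp.array([1.0, 2.0, 3.0])
w = jnp.zeros(3)

X_sh = jax.device_put(X, data_sharding)
y_sh = jax.device_put(y, data_sharding)
w_rep = jax.device_put(w, replicated)

# Under jit, per-shard gradients are combined by an automatically inserted all-reduce
grad_sharded = jax.jit(jax.grad(loss_fn))(w_rep, X_sh, y_sh)
grad_single = jax.grad(loss_fn)(w, X, y)
print(jnp.allclose(grad_sharded, grad_single, atol=1e-4))  # True
```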

Primary use case: Image variants (ImageGLVQ, ImageGMLVQ, ImageGTLVQ, ImageCBC) with trainable CNN backbones on large datasets across multiple GPUs/TPUs.

from prosemble.models import ImageGLVQ

model = ImageGLVQ(
    n_prototypes_per_class=1,
    max_iter=50,
    lr=0.001,
    devices=jax.devices(),
    gradient_checkpointing=True,  # save memory for deep backbones
    use_scan=False,
)
model.fit(X_images, y_labels)

partial_fit() also works with multi-device models:

model.fit(X_batch1, y_batch1)
model.partial_fit(X_batch2, y_batch2)  # data is automatically sharded

Gradient Checkpointing

Gradient checkpointing (jax.remat) trades compute for memory by recomputing forward activations during the backward pass instead of storing them. This is beneficial when training models with deep backbones (Image, Siamese variants) that would otherwise exhaust GPU memory.

from prosemble.models import GLVQ

model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    gradient_checkpointing=True,
)
model.fit(X_train, y_train)

When gradient_checkpointing=True:

  • Forward activations are not stored during the forward pass

  • During backpropagation, the forward pass is recomputed to obtain activations needed for gradient computation

  • Memory usage is reduced at the cost of extra compute: roughly one additional forward pass per step, often estimated at ~33% more compute

  • Gradients are mathematically identical to standard training (up to floating-point rounding)

This option brings little benefit for models with shallow loss functions (e.g., standard GLVQ with Euclidean distance, as in the example above). It provides significant memory savings when unfreezing deep CNN backbones in Image or Siamese variants.
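A small check confirms that rematerialization leaves the gradients unchanged. It uses jax.checkpoint (of which jax.remat is an alias) on a toy stack of tanh layers standing in for a deep backbone. The memory saving itself only becomes visible under jit with large activations. forward and block are hypothetical.

```python
import jax
import jax.numpy as jnp

def block(h, W):
    # One layer of a toy "deep backbone" (hypothetical)
    return jnp.tanh(h @ W)

def forward(params, x, remat):
    layer = jax.checkpoint(block) if remat else block  # jax.remat is an alias
    h = x
    for W in params:
        h = layer(h, W)  # with remat, these activations are recomputed on backward
    return jnp.sum(h ** 2)

keys = jax.random.split(jax.random.PRNGKey(0), 8)
params = [0.3 * jax.random.normal(k, (16, 16)) for k in keys]
x = jnp.ones((4, 16))

g_plain = jax.grad(lambda p: forward(p, x, False))(params)
g_remat = jax.grad(lambda p: forward(p, x, True))(params)
print(all(jnp.allclose(a, b, atol=1e-5) for a, b in zip(g_plain, g_remat)))  # True
```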

Custom Gradient Rules

Prosemble provides infrastructure for subclasses to define custom backward passes using jax.custom_vjp. This is useful for:

  • Optimal transport distances with non-differentiable sorting operations

  • Hard thresholding operations that need surrogate gradients

  • Distances with numerically unstable standard autodiff

Subclasses opt in by setting self._use_custom_vjp = True in their __init__ and overriding the _custom_vjp_loss method:

from prosemble.models import GLVQ

class MyModel(GLVQ):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._use_custom_vjp = True

    def _custom_vjp_loss(self, params, X, y, proto_labels):
        # Implement custom forward + backward pass
        ...

When _use_custom_vjp is active, gradient checkpointing is automatically disabled (custom VJP already controls what gets saved/recomputed).
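The surrogate-gradient use case can be shown with a standalone jax.custom_vjp function. The forward pass is a hard step, whose true derivative is zero almost everywhere, and the backward pass substitutes the derivative of a sigmoid. This illustrates the JAX mechanism only. It is not the _custom_vjp_loss signature, and hard_threshold is hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def hard_threshold(x):
    # Forward: step function (hypothetical example op)
    return (x > 0).astype(x.dtype)

def hard_threshold_fwd(x):
    return hard_threshold(x), x  # save x as the residual for the backward pass

def hard_threshold_bwd(x, g):
    # Surrogate gradient: derivative of sigmoid(x), a smooth stand-in for the step
    s = jax.nn.sigmoid(x)
    return (g * s * (1.0 - s),)  # one cotangent per primal input

hard_threshold.defvjp(hard_threshold_fwd, hard_threshold_bwd)

x = jnp.array([-2.0, 0.0, 2.0])
print(hard_threshold(x))  # [0. 0. 1.]
grads = jax.grad(lambda v: jnp.sum(hard_threshold(v)))(x)  # nonzero surrogate grads
```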

Training Callbacks

Monitor training with callbacks. Training history is automatically tracked for all models:

model.fit(X_train, y_train)

# Supervised models
print(model.loss_history_)

# Clustering models
print(model.objective_history_)

Live Visualization

Enable real-time training visualization for clustering models:

from prosemble.models import FCM

model = FCM(
    n_clusters=3,
    max_iter=100,
    plot_steps=True,
)
model.fit(X)

Custom LVQ Optimizers

Prosemble provides three custom optax-compatible optimizers designed specifically for the geometry and parameter structure of LVQ models. These can be composed with standard optax transformations via optax.chain().

Per-Group Gradient Clipping

Different parameter types (prototypes, omega matrices, relevances) have different natural scales. A single global clip either under-constrains large parameters or over-constrains small ones. Per-group clipping clips each parameter group independently:

import optax
from prosemble.core.optimizers import per_group_clip

optimizer = optax.chain(
    per_group_clip({'prototypes': 1.0, 'omega': 0.5, 'sigmas': 0.1}),
    optax.adam(0.01),
)

from prosemble.models import GMLVQ
model = GMLVQ(
    n_prototypes_per_class=2,
    max_iter=100,
    lr=0.01,
    optimizer=optimizer,
)
model.fit(X_train, y_train)

Hypergradient Descent

Adaptive per-parameter learning rates via gradient correlation (Baydin et al. 2017). If consecutive gradients point in the same direction, the learning rate increases; if they oscillate, it decreases:

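\[\eta_k^{t+1} = \text{clip}_{[\eta_{\min},\, \eta_{\max}]}\left( \eta_k^t + \beta \cdot \langle g_k^t, g_k^{t-1} \rangle \right)\]

where \(g_k^t\) is the gradient for parameter k at step t, \(\beta\) is the hypergradient step size (hyper_lr), and the clip bounds are min_lr and max_lr.
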
from prosemble.core.optimizers import hypergradient_descent
from prosemble.models import GLVQ

model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=200,
    lr=0.01,
    optimizer=hypergradient_descent(
        init_lr=0.01,
        hyper_lr=1e-4,
        min_lr=1e-6,
        max_lr=0.1,
    ),
)
model.fit(X_train, y_train)
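The update rule can be traced numerically on a quadratic. For brevity this sketch adapts a single global learning rate, whereas prosemble adapts rates per parameter. hypergradient_sgd is a hypothetical helper whose arguments mirror the parameters in the example. The deliberately tiny init_lr grows quickly because consecutive gradients stay aligned.

```python
import numpy as np

def hypergradient_sgd(grad_fn, x0, init_lr=0.01, hyper_lr=1e-3,
                      min_lr=1e-6, max_lr=0.5, n_steps=50):
    # Hypothetical helper: SGD with one hypergradient-adapted learning rate
    x = np.array(x0, dtype=float)
    lr, g_prev = init_lr, np.zeros_like(x)
    lrs = []
    for _ in range(n_steps):
        g = grad_fn(x)
        # Aligned consecutive gradients (positive dot product) raise the lr;
        # oscillating ones (negative dot product) lower it
        lr = float(np.clip(lr + hyper_lr * np.dot(g, g_prev), min_lr, max_lr))
        x = x - lr * g
        g_prev = g
        lrs.append(lr)
    return x, lrs

# Quadratic f(x) = 0.5 * ||x||^2 with gradient x
x_final, lrs = hypergradient_sgd(lambda x: x, x0=[10.0, -10.0])
```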

Riemannian Nesterov Momentum

Nesterov accelerated gradient achieves an \(O(1/t^2)\) convergence rate on smooth convex problems, versus \(O(1/t)\) for vanilla gradient descent. It is designed for Riemannian models, where projection back onto the manifold is handled by _post_update():

from prosemble.core.optimizers import riemannian_nesterov
from prosemble.models import GLVQ

model = GLVQ(
    n_prototypes_per_class=2,
    max_iter=200,
    lr=0.01,
    optimizer=riemannian_nesterov(
        learning_rate=0.01,
        momentum=0.9,
    ),
)
model.fit(X_train, y_train)
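The Nesterov update evaluates the gradient at a look-ahead point. This Euclidean sketch shows that step and omits the manifold projection that the text assigns to _post_update(). It does not show riemannian_nesterov's internals, and nesterov_sgd is hypothetical. For Euclidean models, optax.sgd(learning_rate, momentum=0.9, nesterov=True) offers the same update.

```python
import numpy as np

def nesterov_sgd(grad_fn, x0, learning_rate=0.1, momentum=0.9, n_steps=200):
    # Hypothetical helper: Nesterov momentum in Euclidean space
    x = np.array(x0, dtype=float)
    v = np.zeros_like(x)
    for _ in range(n_steps):
        # Evaluate the gradient at the look-ahead point x + momentum * v
        g = grad_fn(x + momentum * v)
        v = momentum * v - learning_rate * g
        x = x + v
        # A Riemannian model would project x back onto its manifold here
    return x

x_final = nesterov_sgd(lambda x: x, x0=[5.0, -3.0])  # f(x) = 0.5 * ||x||^2
```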

Reject Option

The RejectOptionMixin adds calibrated rejection to any supervised prototype classifier. When the model is uncertain about a prediction (the sample lies near the decision boundary), it abstains instead of guessing.

The confidence measure is the GLVQ relative margin:

\[\text{confidence}(x) = \frac{d^-(x) - d^+(x)}{d^-(x) + d^+(x)}\]

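Here \(d^+(x)\) is the distance from x to the nearest prototype, which determines the predicted label. \(d^-(x)\) is the distance to the nearest prototype of any other class, so the score lies in [0, 1]. Values near 0 indicate the decision boundary; values near 1 indicate high confidence.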

from prosemble.models import GLVQ

model = GLVQ(n_prototypes_per_class=2, max_iter=100, lr=0.01)
model.fit(X_train, y_train)

# Predict with rejection (uncertain samples get label -1)
labels = model.predict_with_rejection(X_test, threshold=0.1)
rejected = labels == -1
print(f"Rejected {rejected.sum()} of {len(labels)} samples")

# Compute confidence scores
conf = model.confidence(X_test)

# Find the optimal threshold minimizing Chow's risk
threshold = model.optimal_threshold(
    X_val, y_val,
    cost_reject=0.5,    # rejection costs half of a misclassification
    cost_error=1.0,
)

# Accuracy-coverage curve
thresholds, accuracies, coverages = model.accuracy_coverage_curve(X_val, y_val)
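Below is a small reimplementation of the confidence score and a Chow-risk threshold search. It assumes squared Euclidean distances and uses the observed confidence values as candidate thresholds, choosing the one that minimizes the empirical cost of rejections plus accepted errors. relative_margin_confidence and chow_optimal_threshold are hypothetical helpers, not the library methods.

```python
import numpy as np

def relative_margin_confidence(X, prototypes, prototype_labels):
    d = ((X[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)
    winner = d.argmin(axis=1)
    pred = prototype_labels[winner]
    d_plus = d[np.arange(len(X)), winner]  # distance to the nearest prototype
    # Distance to the nearest prototype whose label differs from the prediction
    other = np.where(prototype_labels[None, :] != pred[:, None], d, np.inf)
    d_minus = other.min(axis=1)
    return pred, (d_minus - d_plus) / (d_minus + d_plus)

def chow_optimal_threshold(conf, correct, cost_reject=0.5, cost_error=1.0):
    # Empirical Chow risk: cost of rejected samples + cost of accepted errors
    best_t, best_risk = 0.0, np.inf
    for t in np.unique(np.concatenate([[0.0], conf])):
        accepted = conf >= t
        risk = (cost_reject * np.mean(~accepted)
                + cost_error * np.mean(accepted & ~correct))
        if risk < best_risk:
            best_t, best_risk = float(t), risk
    return best_t

prototypes = np.array([[0.0, 0.0], [2.0, 0.0]])
prototype_labels = np.array([0, 1])
X_val = np.array([[0.1, 0.0], [0.9, 0.0], [1.1, 0.0], [1.9, 0.0]])
y_val = np.array([0, 1, 0, 1])  # the two near-boundary samples are misclassified

pred, conf = relative_margin_confidence(X_val, prototypes, prototype_labels)
threshold = chow_optimal_threshold(conf, pred == y_val)
labels = np.where(conf >= threshold, pred, -1)  # -1 marks a rejection
```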

Curriculum Learning

Self-paced curriculum learning adaptively weights training samples by difficulty. Early in training, the model focuses on easy samples; as training progresses, harder samples are gradually introduced.

The difficulty measure is the per-sample GLVQ loss (mu-ratio). Samples with loss below the current threshold keep full or nearly full weight. Samples above it are down-weighted or excluded, depending on the weighting mode.

from prosemble.core.curriculum import (
    curriculum_weights, curriculum_threshold, apply_curriculum_to_loss,
)

# Compute per-sample difficulty weights
# (per_sample_losses: array of per-sample GLVQ losses)
threshold = curriculum_threshold(
    iteration=50,
    max_iter=200,
    init_threshold=0.3,   # initially only easy samples
    final_threshold=1.0,  # eventually all samples
    schedule='linear',
)
weights = curriculum_weights(per_sample_losses, threshold, mode='soft')

# Or use the combined pipeline
weighted_loss = apply_curriculum_to_loss(
    per_sample_losses,
    iteration=50,
    max_iter=200,
    init_threshold=0.3,
    final_threshold=1.0,
)

Three weighting modes are available:

  • 'hard': binary — include (weight=1) or exclude (weight=0)

  • 'soft': smooth sigmoid transition at the threshold

  • 'linear': linearly decrease weight from 1 to 0 as loss increases

Three scheduling strategies:

  • 'linear': threshold increases linearly with iteration

  • 'exponential': slow start, fast growth

  • 'cosine': smooth cosine schedule from init to final threshold
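The schedules and weighting modes can be sketched with simple formulas that match the descriptions above. These exact formulas are assumptions, not taken from prosemble. The soft mode's temperature parameter is hypothetical, and the "exponential" schedule is read as geometric interpolation. threshold_schedule and difficulty_weights are hypothetical helpers.

```python
import math
import numpy as np

def threshold_schedule(iteration, max_iter, init, final, schedule="linear"):
    t = min(iteration / max_iter, 1.0)  # training progress in [0, 1]
    if schedule == "linear":
        return init + (final - init) * t
    if schedule == "exponential":
        # Geometric interpolation (requires init > 0): slow start, fast growth
        return init * (final / init) ** t
    if schedule == "cosine":
        return init + (final - init) * 0.5 * (1.0 - math.cos(math.pi * t))
    raise ValueError(f"unknown schedule: {schedule}")

def difficulty_weights(losses, threshold, mode="soft", temperature=0.05):
    losses = np.asarray(losses, dtype=float)
    if mode == "hard":
        return (losses <= threshold).astype(float)  # include (1) or exclude (0)
    if mode == "soft":
        # Sigmoid centred on the threshold; temperature (hypothetical) sets the width
        return 1.0 / (1.0 + np.exp((losses - threshold) / temperature))
    if mode == "linear":
        # Weight 1 at zero loss, falling linearly to 0 at the threshold (assumed)
        return np.clip(1.0 - losses / threshold, 0.0, 1.0)
    raise ValueError(f"unknown mode: {mode}")

thr = threshold_schedule(50, 200, init=0.3, final=1.0, schedule="linear")  # ≈ 0.475
w = difficulty_weights([0.1, 0.5, 0.9], thr, mode="hard")  # keeps only the easiest sample
```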

Prototype Regularization

Prototype Diversity (DPP)

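Prevent same-class prototypes from collapsing onto each other with a penalty inspired by determinantal point processes (DPP). It uses the log-determinant of a kernel (similarity) matrix built from pairwise prototype distances. This log-determinant drops sharply as prototypes coincide, so keeping it large encourages spread: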

from prosemble.core.regularization import prototype_diversity_loss

# Add as a regularization term to the loss
div_loss = prototype_diversity_loss(
    prototypes, proto_labels,
    sigma_div=1.0,
)
total_loss = classification_loss + 0.01 * div_loss
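One way to build such a penalty is to take the negative log-determinant of an RBF kernel over each class's prototypes. The RBF kernel, the jitter term, and the summation over classes are assumptions, and prosemble's prototype_diversity_loss may differ. diversity_penalty is a hypothetical helper.

```python
import jax.numpy as jnp

def diversity_penalty(prototypes, proto_labels, sigma_div=1.0, jitter=1e-6):
    """Hypothetical: negative log-det of an RBF kernel over same-class prototypes."""
    penalty = 0.0
    for c in jnp.unique(proto_labels).tolist():
        P = prototypes[proto_labels == c]
        sq = jnp.sum((P[:, None, :] - P[None, :, :]) ** 2, axis=-1)
        K = jnp.exp(-sq / (2.0 * sigma_div ** 2))  # similarity kernel
        K = K + jitter * jnp.eye(P.shape[0])       # keep K positive definite
        _, logdet = jnp.linalg.slogdet(K)
        # Collapsed prototypes make K nearly singular: logdet -> -inf, penalty grows
        penalty = penalty - logdet
    return penalty

labels = jnp.array([0, 0, 1, 1])
spread = jnp.array([[0.0, 0.0], [3.0, 0.0], [0.0, 5.0], [3.0, 5.0]])
collapsed = jnp.array([[0.0, 0.0], [0.1, 0.0], [0.0, 5.0], [0.1, 5.0]])
print(diversity_penalty(spread, labels) < diversity_penalty(collapsed, labels))  # True
```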

Sparse Relevance Regularization

For relevance learning models (GRLVQ), encourage sparsity in the learned relevance profile via proximal gradient operators:

from prosemble.core.regularization import (
    sparse_relevance_proximal,
    elastic_net_proximal,
)

# L1 soft-thresholding (promotes exact zeros)
sparse_rel = sparse_relevance_proximal(
    relevances, l1_weight=0.01, lr=0.01,
)

# Elastic net (L1 + L2 combination)
sparse_rel = elastic_net_proximal(
    relevances, l1_weight=0.01, l2_weight=0.001, lr=0.01,
)
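These are the standard proximal operators behind the two calls. L1 soft-thresholding shrinks every entry by lr * l1_weight and sets small entries exactly to zero. The elastic-net version additionally divides by 1 + lr * l2_weight, under the assumed convention that the L2 term is (l2_weight / 2) * ||r||^2. Whether prosemble renormalizes the relevances afterwards is not stated here. Both functions are hypothetical re-implementations.

```python
import numpy as np

def soft_threshold(r, l1_weight, lr):
    # Proximal operator of lr * l1_weight * ||r||_1: shrink toward 0, clip at 0
    return np.sign(r) * np.maximum(np.abs(r) - lr * l1_weight, 0.0)

def elastic_net_prox(r, l1_weight, l2_weight, lr):
    # Proximal operator of lr * (l1 * ||r||_1 + (l2 / 2) * ||r||_2^2) (assumed convention)
    return soft_threshold(r, l1_weight, lr) / (1.0 + lr * l2_weight)

relevances = np.array([0.5, 0.3, 0.19995, 0.00005])
sparse = soft_threshold(relevances, l1_weight=0.01, lr=0.01)  # last entry -> exactly 0.0
enet = elastic_net_prox(relevances, l1_weight=0.01, l2_weight=0.001, lr=0.01)
```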

Geodesic Interpolation

For Riemannian models, geodesic utilities enable visualization and analysis of decision boundaries on curved manifolds.

from prosemble.core.geodesic import (
    geodesic_interpolation,
    geodesic_midpoint,
    decision_boundary_point,
    prototype_geodesic_distances,
    inter_class_geodesics,
)
from prosemble.core.manifolds import SO

manifold = SO(3)

# Compute geodesic path between two SO(3) prototypes
path = geodesic_interpolation(manifold, proto_a, proto_b, n_points=50)

# Find the decision boundary along the geodesic
boundary_pt, t_boundary = decision_boundary_point(
    manifold, proto_a, proto_b,
)
print(f"Boundary at t={t_boundary:.3f}")  # t=0.5 for symmetric manifolds

# Pairwise geodesic distance matrix between all prototypes
dist_matrix = prototype_geodesic_distances(
    manifold, prototypes, proto_labels,
)

# All inter-class geodesics with boundary locations
geodesics = inter_class_geodesics(
    manifold, prototypes, proto_labels,
)
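For SO(3) with its usual bi-invariant metric, the geodesic between rotations A and B is γ(t) = A exp(t log(AᵀB)). The distance is the rotation angle of AᵀB. Because γ has constant speed, the point equidistant from two prototypes (the nearest-prototype decision boundary between them) lies at t = 0.5. The NumPy sketch below uses hypothetical helpers (so3_log, so3_exp, geodesic, so3_distance, rot_z). It only illustrates the idea behind geodesic_interpolation and decision_boundary_point and is not their implementation.

```python
import numpy as np

def so3_log(R):
    # Matrix logarithm of a rotation (valid for rotation angles below pi)
    angle = np.arccos(np.clip((np.trace(R) - 1.0) / 2.0, -1.0, 1.0))
    if angle < 1e-10:
        return np.zeros((3, 3))
    return angle / (2.0 * np.sin(angle)) * (R - R.T)

def so3_exp(W):
    # Rodrigues' formula for a skew-symmetric 3x3 matrix W
    w = np.array([W[2, 1], W[0, 2], W[1, 0]])
    theta = np.linalg.norm(w)
    if theta < 1e-12:
        return np.eye(3) + W
    K = W / theta
    return np.eye(3) + np.sin(theta) * K + (1.0 - np.cos(theta)) * (K @ K)

def geodesic(A, B, t):
    # gamma(0) = A, gamma(1) = B, constant speed along the shortest rotation path
    return A @ so3_exp(t * so3_log(A.T @ B))

def so3_distance(A, B):
    # Geodesic distance = rotation angle of A^T B
    return np.linalg.norm(so3_log(A.T @ B)) / np.sqrt(2.0)

def rot_z(phi):
    c, s = np.cos(phi), np.sin(phi)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

A, B = np.eye(3), rot_z(2.0)
path = [geodesic(A, B, t) for t in np.linspace(0.0, 1.0, 50)]
mid = geodesic(A, B, 0.5)  # equidistant from A and B: the two-prototype boundary point
```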