"""
Siamese LVQ models: SiameseGLVQ, SiameseGMLVQ, SiameseGTLVQ.
In Siamese variants, an MLP backbone transforms BOTH inputs AND
prototypes before computing distances. Prototypes remain in the
original input space and are projected at each step.
This contrasts with LVQMLN, where only inputs are transformed and
prototypes live directly in latent space.
Architecture::
Input (d) ---> backbone ---> latent_x
|
v
Prototype (d) -> backbone -> latent_w distance(latent_x, latent_w)
|
v
LVQ loss
References
----------
.. [1] Villmann, T., et al. (2017). Prototype-based Neural Network
Layers: Incorporating Vector Quantization. arXiv:1812.01214.
"""
import jax
import jax.numpy as jnp
import numpy as np
from prosemble.models.prototype_base import SupervisedPrototypeModel
from prosemble.core.losses import glvq_loss_with_transfer
from prosemble.core.distance import squared_euclidean_distance_matrix
from prosemble.core.competitions import wtac
from prosemble.core.initializers import identity_omega_init, random_omega_init
from prosemble.models.lvqmln import _mlp_init, _mlp_forward
from prosemble.core.utils import orthogonalize
[docs]
class SiameseGLVQ(SupervisedPrototypeModel):
"""Siamese GLVQ — GLVQ with a learned embedding network.
Both inputs and prototypes are transformed through the same MLP
backbone before computing squared Euclidean distances.
Parameters
----------
hidden_sizes : list of int
Hidden layer sizes for the backbone MLP.
latent_dim : int
Dimension of the embedding space.
activation : str
Activation function for the backbone MLP. Supported values:
'sigmoid', 'relu', 'tanh', 'leaky_relu', 'selu'.
beta : float
Transfer function parameter for GLVQ loss.
bb_lr : float, optional
Separate learning rate for the backbone network. If None,
uses the same lr as prototypes. Default: None.
both_path_gradients : bool
If True, compute gradients through both input and prototype
paths. If False, prototype path gradients are stopped.
Default: True.
n_prototypes_per_class : int
Number of prototypes per class.
max_iter : int
Maximum training iterations.
lr : float
Learning rate.
epsilon : float
Convergence threshold on loss change.
random_seed : int
Random seed for reproducibility.
distance_fn : callable, optional
Distance function (default: squared Euclidean).
optimizer : str or optax optimizer, optional
Optimizer name ('adam', 'sgd') or optax GradientTransformation.
Default: 'adam'.
transfer_fn : callable, optional
Transfer function for loss shaping (default: identity).
margin : float
Margin for loss computation.
callbacks : list, optional
List of Callback objects.
use_scan : bool
If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size : int, optional
Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler : str or optax.Schedule, optional
Learning rate schedule. Supported strings: 'exponential_decay',
'cosine_decay', 'warmup_cosine_decay', 'warmup_exponential_decay',
'warmup_constant', 'polynomial', 'linear', 'piecewise_constant',
'sgdr'. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs : dict, optional
Keyword arguments passed to the learning rate scheduler
(e.g. ``decay_rate``, ``transition_steps``). Default: None.
prototypes_initializer : str or callable, optional
How to initialize prototypes. Supported strings: 'stratified_random'
(default), 'class_mean', 'class_conditional_mean', 'stratified_noise',
'random_normal', 'uniform', 'zeros', 'ones', 'fill_value'.
Or pass a callable ``(X, y, n_per_class, key) -> (protos, labels)``.
patience : int, optional
Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best : bool
If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight : dict or 'balanced', optional
Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. 'balanced' auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps : int, optional
Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay : float, optional
Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params : list of str, optional
List of parameter group names to freeze (zero gradients).
E.g. ['backbone'] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead : dict, optional
Enable lookahead optimizer wrapper. Dict with keys:
- 'sync_period': int (default 6) — sync every k steps
- 'slow_step_size': float (default 0.5) — interpolation factor
Default: None (no lookahead).
mixed_precision : str or None, optional
Compute dtype for mixed precision training. 'float16' or 'bfloat16'.
Master weights stay in float32; forward/backward pass runs in lower
precision for ~2x speed and ~half memory on GPU. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
"""
def __init__(self, hidden_sizes=None, latent_dim=2,
activation='sigmoid', beta=10.0, bb_lr=None,
both_path_gradients=True,
n_prototypes_per_class=1, max_iter=100, lr=0.01,
epsilon=1e-6, random_seed=42, distance_fn=None,
optimizer='adam', transfer_fn=None, margin=0.0,
callbacks=None, use_scan=True, batch_size=None,
lr_scheduler=None, lr_scheduler_kwargs=None,
prototypes_initializer=None, patience=None,
restore_best=False, class_weight=None,
gradient_accumulation_steps=None, ema_decay=None,
freeze_params=None, lookahead=None,
mixed_precision=None):
super().__init__(
n_prototypes_per_class=n_prototypes_per_class,
max_iter=max_iter, lr=lr, epsilon=epsilon,
random_seed=random_seed, distance_fn=distance_fn,
optimizer=optimizer, transfer_fn=transfer_fn, margin=margin,
callbacks=callbacks, use_scan=use_scan, batch_size=batch_size,
lr_scheduler=lr_scheduler,
lr_scheduler_kwargs=lr_scheduler_kwargs,
prototypes_initializer=prototypes_initializer,
patience=patience, restore_best=restore_best,
class_weight=class_weight,
gradient_accumulation_steps=gradient_accumulation_steps,
ema_decay=ema_decay, freeze_params=freeze_params,
lookahead=lookahead, mixed_precision=mixed_precision,
)
self.hidden_sizes = hidden_sizes or [10]
self.latent_dim = latent_dim
self.activation = activation
self.beta = beta
self.bb_lr = bb_lr
self.both_path_gradients = both_path_gradients
self.backbone_params_ = None
if bb_lr is not None:
self._optimizer = self._build_multi_lr_optimizer(
self._optimizer_spec, self.lr, bb_lr
)
def _build_multi_lr_optimizer(self, optimizer, proto_lr, bb_lr):
"""Build optimizer with separate learning rates for prototypes and backbone."""
import optax
if not isinstance(optimizer, str):
return optimizer
proto_opt = self._build_optimizer(optimizer, proto_lr)
bb_opt = self._build_optimizer(optimizer, bb_lr)
return optax.multi_transform(
{'prototypes': proto_opt, 'backbone': bb_opt},
param_labels=lambda params: {k: k for k in params},
)
def _get_resume_params(self, params):
params['backbone'] = self.backbone_params_
return params
def _init_state(self, X, y, key):
n_features = X.shape[1]
layer_sizes = [n_features] + list(self.hidden_sizes) + [self.latent_dim]
key1, key2 = jax.random.split(key)
backbone_params = _mlp_init(key1, layer_sizes, self.activation)
# Prototypes in input space (they get projected through backbone)
prototypes, proto_labels = self._init_prototypes(
X, y, self.n_prototypes_per_class, key2
)
params = {
'prototypes': prototypes,
'backbone': backbone_params,
}
opt_state = self._optimizer.init(params)
from prosemble.models.prototype_base import SupervisedState
state = SupervisedState(
prototypes=prototypes,
opt_state=opt_state,
loss=jnp.array(float('inf')),
iteration=0,
converged=False,
)
return state, params, proto_labels
def _compute_loss(self, params, X, y, proto_labels):
backbone = params['backbone']
latent_x = _mlp_forward(backbone, X, self.activation)
latent_w = _mlp_forward(backbone, params['prototypes'], self.activation)
if not self.both_path_gradients:
latent_w = jax.lax.stop_gradient(latent_w)
distances = squared_euclidean_distance_matrix(latent_x, latent_w)
return glvq_loss_with_transfer(
distances, y, proto_labels,
transfer_fn=self.transfer_fn,
margin=self.margin,
beta=self.beta,
)
def _extract_results(self, params, proto_labels, loss_history, n_iter, **kwargs):
super()._extract_results(params, proto_labels, loss_history, n_iter, **kwargs)
self.backbone_params_ = params['backbone']
[docs]
def predict(self, X):
self._check_fitted()
X = jnp.asarray(X, dtype=jnp.float32)
latent_x = _mlp_forward(self.backbone_params_, X, self.activation)
latent_w = _mlp_forward(self.backbone_params_, self.prototypes_, self.activation)
distances = squared_euclidean_distance_matrix(latent_x, latent_w)
return wtac(distances, self.prototype_labels_)
[docs]
def predict_proba(self, X):
self._check_fitted()
X = jnp.asarray(X, dtype=jnp.float32)
latent_x = _mlp_forward(self.backbone_params_, X, self.activation)
latent_w = _mlp_forward(self.backbone_params_, self.prototypes_, self.activation)
distances = squared_euclidean_distance_matrix(latent_x, latent_w)
from prosemble.core.pooling import stratified_min_pooling
class_dists = stratified_min_pooling(
distances, self.prototype_labels_, self.n_classes_
)
return jax.nn.softmax(-class_dists, axis=1)
def transform(self, X):
"""Transform data through the backbone."""
self._check_fitted()
X = jnp.asarray(X, dtype=jnp.float32)
return _mlp_forward(self.backbone_params_, X, self.activation)
def _check_fitted(self):
if self.prototypes_ is None or self.backbone_params_ is None:
from prosemble.models.base import NotFittedError
raise NotFittedError("Model not fitted. Call fit() first.")
def _get_hyperparams(self):
hp = super()._get_hyperparams()
hp.update({
'hidden_sizes': self.hidden_sizes,
'latent_dim': self.latent_dim,
'activation': self.activation,
'beta': self.beta,
})
return hp
[docs]
class SiameseGMLVQ(SupervisedPrototypeModel):
"""Siamese GMLVQ — GMLVQ with a learned embedding network.
Both inputs and prototypes are transformed through the same MLP,
then distances are computed using a learned :math:`\\Omega` matrix in the
latent space:
.. math::
d = \\|\\Omega(f(x) - f(w))\\|^2
Parameters
----------
hidden_sizes : list of int
Hidden layer sizes for the backbone MLP.
latent_dim : int
Dimension of the backbone output (embedding space).
omega_dim : int, optional
Omega mapping dimension (number of rows in Omega). If None,
uses latent_dim (square matrix). Default: None.
activation : str
Activation function for the backbone MLP. Supported values:
'sigmoid', 'relu', 'tanh', 'leaky_relu', 'selu'.
beta : float
Transfer function parameter for GLVQ loss.
bb_lr : float, optional
Separate learning rate for the backbone network. If None,
uses the same lr as prototypes. Default: None.
both_path_gradients : bool
If True, compute gradients through both input and prototype
paths. If False, prototype path gradients are stopped.
Default: True.
n_prototypes_per_class : int
Number of prototypes per class.
max_iter : int
Maximum training iterations.
lr : float
Learning rate.
epsilon : float
Convergence threshold on loss change.
random_seed : int
Random seed for reproducibility.
distance_fn : callable, optional
Distance function (default: squared Euclidean).
optimizer : str or optax optimizer, optional
Optimizer name ('adam', 'sgd') or optax GradientTransformation.
Default: 'adam'.
transfer_fn : callable, optional
Transfer function for loss shaping (default: identity).
margin : float
Margin for loss computation.
callbacks : list, optional
List of Callback objects.
use_scan : bool
If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size : int, optional
Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler : str or optax.Schedule, optional
Learning rate schedule. Supported strings: 'exponential_decay',
'cosine_decay', 'warmup_cosine_decay', 'warmup_exponential_decay',
'warmup_constant', 'polynomial', 'linear', 'piecewise_constant',
'sgdr'. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs : dict, optional
Keyword arguments passed to the learning rate scheduler
(e.g. ``decay_rate``, ``transition_steps``). Default: None.
prototypes_initializer : str or callable, optional
How to initialize prototypes. Supported strings: 'stratified_random'
(default), 'class_mean', 'class_conditional_mean', 'stratified_noise',
'random_normal', 'uniform', 'zeros', 'ones', 'fill_value'.
Or pass a callable ``(X, y, n_per_class, key) -> (protos, labels)``.
patience : int, optional
Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best : bool
If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight : dict or 'balanced', optional
Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. 'balanced' auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps : int, optional
Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay : float, optional
Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params : list of str, optional
List of parameter group names to freeze (zero gradients).
E.g. ['backbone'] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead : dict, optional
Enable lookahead optimizer wrapper. Dict with keys:
- 'sync_period': int (default 6) — sync every k steps
- 'slow_step_size': float (default 0.5) — interpolation factor
Default: None (no lookahead).
mixed_precision : str or None, optional
Compute dtype for mixed precision training. 'float16' or 'bfloat16'.
Master weights stay in float32; forward/backward pass runs in lower
precision for ~2x speed and ~half memory on GPU. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
"""
def __init__(self, hidden_sizes=None, latent_dim=2,
omega_dim=None, activation='sigmoid', beta=10.0,
bb_lr=None, both_path_gradients=True,
n_prototypes_per_class=1, max_iter=100, lr=0.01,
epsilon=1e-6, random_seed=42, distance_fn=None,
optimizer='adam', transfer_fn=None, margin=0.0,
callbacks=None, use_scan=True, batch_size=None,
lr_scheduler=None, lr_scheduler_kwargs=None,
prototypes_initializer=None, patience=None,
restore_best=False, class_weight=None,
gradient_accumulation_steps=None, ema_decay=None,
freeze_params=None, lookahead=None,
mixed_precision=None):
super().__init__(
n_prototypes_per_class=n_prototypes_per_class,
max_iter=max_iter, lr=lr, epsilon=epsilon,
random_seed=random_seed, distance_fn=distance_fn,
optimizer=optimizer, transfer_fn=transfer_fn, margin=margin,
callbacks=callbacks, use_scan=use_scan, batch_size=batch_size,
lr_scheduler=lr_scheduler,
lr_scheduler_kwargs=lr_scheduler_kwargs,
prototypes_initializer=prototypes_initializer,
patience=patience, restore_best=restore_best,
class_weight=class_weight,
gradient_accumulation_steps=gradient_accumulation_steps,
ema_decay=ema_decay, freeze_params=freeze_params,
lookahead=lookahead, mixed_precision=mixed_precision,
)
self.hidden_sizes = hidden_sizes or [10]
self.latent_dim = latent_dim
self.omega_dim = omega_dim
self.activation = activation
self.beta = beta
self.bb_lr = bb_lr
self.both_path_gradients = both_path_gradients
self.backbone_params_ = None
self.omega_ = None
if bb_lr is not None:
self._optimizer = self._build_multi_lr_optimizer(
self._optimizer_spec, self.lr, bb_lr
)
def _build_multi_lr_optimizer(self, optimizer, proto_lr, bb_lr):
"""Build optimizer with separate learning rates for prototypes and backbone."""
import optax
if not isinstance(optimizer, str):
return optimizer
proto_opt = self._build_optimizer(optimizer, proto_lr)
bb_opt = self._build_optimizer(optimizer, bb_lr)
return optax.multi_transform(
{'prototypes': proto_opt, 'backbone': bb_opt, 'omega': proto_opt},
param_labels=lambda params: {k: k for k in params},
)
def _get_resume_params(self, params):
params['backbone'] = self.backbone_params_
params['omega'] = self.omega_
return params
def _init_state(self, X, y, key):
n_features = X.shape[1]
layer_sizes = [n_features] + list(self.hidden_sizes) + [self.latent_dim]
omega_dim = self.omega_dim or self.latent_dim
key1, key2 = jax.random.split(key)
backbone_params = _mlp_init(key1, layer_sizes, self.activation)
prototypes, proto_labels = self._init_prototypes(
X, y, self.n_prototypes_per_class, key2
)
omega = identity_omega_init(self.latent_dim, omega_dim)
params = {
'prototypes': prototypes,
'backbone': backbone_params,
'omega': omega,
}
opt_state = self._optimizer.init(params)
from prosemble.models.prototype_base import SupervisedState
state = SupervisedState(
prototypes=prototypes,
opt_state=opt_state,
loss=jnp.array(float('inf')),
iteration=0,
converged=False,
)
return state, params, proto_labels
def _compute_loss(self, params, X, y, proto_labels):
backbone = params['backbone']
omega = params['omega']
latent_x = _mlp_forward(backbone, X, self.activation)
latent_w = _mlp_forward(backbone, params['prototypes'], self.activation)
if not self.both_path_gradients:
latent_w = jax.lax.stop_gradient(latent_w)
# Omega distance in latent space
diff = latent_x[:, None, :] - latent_w[None, :, :]
projected = jnp.einsum('npd,dl->npl', diff, omega)
distances = jnp.sum(projected ** 2, axis=2)
return glvq_loss_with_transfer(
distances, y, proto_labels,
transfer_fn=self.transfer_fn,
margin=self.margin,
beta=self.beta,
)
def _extract_results(self, params, proto_labels, loss_history, n_iter, **kwargs):
super()._extract_results(params, proto_labels, loss_history, n_iter, **kwargs)
self.backbone_params_ = params['backbone']
self.omega_ = params['omega']
[docs]
def predict(self, X):
self._check_fitted()
X = jnp.asarray(X, dtype=jnp.float32)
latent_x = _mlp_forward(self.backbone_params_, X, self.activation)
latent_w = _mlp_forward(self.backbone_params_, self.prototypes_, self.activation)
diff = latent_x[:, None, :] - latent_w[None, :, :]
projected = jnp.einsum('npd,dl->npl', diff, self.omega_)
distances = jnp.sum(projected ** 2, axis=2)
return wtac(distances, self.prototype_labels_)
def transform(self, X):
"""Transform data through the backbone."""
self._check_fitted()
X = jnp.asarray(X, dtype=jnp.float32)
return _mlp_forward(self.backbone_params_, X, self.activation)
@property
def lambda_matrix(self):
"""Return Lambda = Omega^T Omega in latent space."""
if self.omega_ is None:
raise ValueError("Model not fitted.")
return self.omega_.T @ self.omega_
def _check_fitted(self):
if self.prototypes_ is None or self.backbone_params_ is None:
from prosemble.models.base import NotFittedError
raise NotFittedError("Model not fitted. Call fit() first.")
def _get_hyperparams(self):
hp = super()._get_hyperparams()
hp.update({
'hidden_sizes': self.hidden_sizes,
'latent_dim': self.latent_dim,
'activation': self.activation,
'beta': self.beta,
})
if self.omega_dim is not None:
hp['omega_dim'] = self.omega_dim
return hp
[docs]
class SiameseGTLVQ(SupervisedPrototypeModel):
"""Siamese GTLVQ — GTLVQ with a learned embedding network.
Both inputs and prototypes are transformed through the same MLP,
then tangent distances are computed in the latent space using
per-prototype subspace bases.
Parameters
----------
hidden_sizes : list of int
Hidden layer sizes for the backbone MLP.
latent_dim : int
Dimension of the backbone output (embedding space).
subspace_dim : int
Tangent subspace dimension per prototype. Each prototype gets a
learned orthonormal basis of this rank in latent space.
activation : str
Activation function for the backbone MLP. Supported values:
'sigmoid', 'relu', 'tanh', 'leaky_relu', 'selu'.
beta : float
Transfer function parameter for GLVQ loss.
bb_lr : float, optional
Separate learning rate for the backbone network. If None,
uses the same lr as prototypes. Default: None.
both_path_gradients : bool
If True, compute gradients through both input and prototype
paths. If False, prototype path gradients are stopped.
Default: True.
n_prototypes_per_class : int
Number of prototypes per class.
max_iter : int
Maximum training iterations.
lr : float
Learning rate.
epsilon : float
Convergence threshold on loss change.
random_seed : int
Random seed for reproducibility.
distance_fn : callable, optional
Distance function (default: squared Euclidean).
optimizer : str or optax optimizer, optional
Optimizer name ('adam', 'sgd') or optax GradientTransformation.
Default: 'adam'.
transfer_fn : callable, optional
Transfer function for loss shaping (default: identity).
margin : float
Margin for loss computation.
callbacks : list, optional
List of Callback objects.
use_scan : bool
If True (default), use jax.lax.scan for training (faster, JIT-compiled,
but runs all max_iter iterations even after convergence).
If False, use a Python for-loop with true early stopping (no wasted
compute after convergence, but slower per iteration).
batch_size : int, optional
Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this size.
lr_scheduler : str or optax.Schedule, optional
Learning rate schedule. Supported strings: 'exponential_decay',
'cosine_decay', 'warmup_cosine_decay', 'warmup_exponential_decay',
'warmup_constant', 'polynomial', 'linear', 'piecewise_constant',
'sgdr'. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs : dict, optional
Keyword arguments passed to the learning rate scheduler
(e.g. ``decay_rate``, ``transition_steps``). Default: None.
prototypes_initializer : str or callable, optional
How to initialize prototypes. Supported strings: 'stratified_random'
(default), 'class_mean', 'class_conditional_mean', 'stratified_noise',
'random_normal', 'uniform', 'zeros', 'ones', 'fill_value'.
Or pass a callable ``(X, y, n_per_class, key) -> (protos, labels)``.
patience : int, optional
Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step (epsilon
check). Requires use_scan=False for true early stopping.
restore_best : bool
If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default: False.
class_weight : dict or 'balanced', optional
Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. 'balanced' auto-computes weights
inversely proportional to class frequencies. Default: None (uniform).
gradient_accumulation_steps : int, optional
Accumulate gradients over this many steps before applying an update.
Effective batch size = batch_size * gradient_accumulation_steps.
Default: None (no accumulation).
ema_decay : float, optional
Exponential moving average decay for parameters (0 < ema_decay < 1).
After training, model parameters are replaced with EMA-smoothed values.
Typical values: 0.999, 0.9999. Default: None (no EMA).
freeze_params : list of str, optional
List of parameter group names to freeze (zero gradients).
E.g. ['backbone'] to freeze the backbone and only train prototypes.
Default: None (all parameters trainable).
lookahead : dict, optional
Enable lookahead optimizer wrapper. Dict with keys:
- 'sync_period': int (default 6) — sync every k steps
- 'slow_step_size': float (default 0.5) — interpolation factor
Default: None (no lookahead).
mixed_precision : str or None, optional
Compute dtype for mixed precision training. 'float16' or 'bfloat16'.
Master weights stay in float32; forward/backward pass runs in lower
precision for ~2x speed and ~half memory on GPU. Float16 uses static
loss scaling to prevent gradient underflow. Default: None (disabled).
"""
def __init__(self, hidden_sizes=None, latent_dim=4,
subspace_dim=2, activation='sigmoid', beta=10.0,
bb_lr=None, both_path_gradients=True,
n_prototypes_per_class=1, max_iter=100, lr=0.01,
epsilon=1e-6, random_seed=42, distance_fn=None,
optimizer='adam', transfer_fn=None, margin=0.0,
callbacks=None, use_scan=True, batch_size=None,
lr_scheduler=None, lr_scheduler_kwargs=None,
prototypes_initializer=None, patience=None,
restore_best=False, class_weight=None,
gradient_accumulation_steps=None, ema_decay=None,
freeze_params=None, lookahead=None,
mixed_precision=None):
super().__init__(
n_prototypes_per_class=n_prototypes_per_class,
max_iter=max_iter, lr=lr, epsilon=epsilon,
random_seed=random_seed, distance_fn=distance_fn,
optimizer=optimizer, transfer_fn=transfer_fn, margin=margin,
callbacks=callbacks, use_scan=use_scan, batch_size=batch_size,
lr_scheduler=lr_scheduler,
lr_scheduler_kwargs=lr_scheduler_kwargs,
prototypes_initializer=prototypes_initializer,
patience=patience, restore_best=restore_best,
class_weight=class_weight,
gradient_accumulation_steps=gradient_accumulation_steps,
ema_decay=ema_decay, freeze_params=freeze_params,
lookahead=lookahead, mixed_precision=mixed_precision,
)
self.hidden_sizes = hidden_sizes or [10]
self.latent_dim = latent_dim
self.subspace_dim = subspace_dim
self.activation = activation
self.beta = beta
self.bb_lr = bb_lr
self.both_path_gradients = both_path_gradients
self.backbone_params_ = None
self.omegas_ = None
if bb_lr is not None:
self._optimizer = self._build_multi_lr_optimizer(
self._optimizer_spec, self.lr, bb_lr
)
def _build_multi_lr_optimizer(self, optimizer, proto_lr, bb_lr):
"""Build optimizer with separate learning rates for prototypes and backbone."""
import optax
if not isinstance(optimizer, str):
return optimizer
proto_opt = self._build_optimizer(optimizer, proto_lr)
bb_opt = self._build_optimizer(optimizer, bb_lr)
return optax.multi_transform(
{'prototypes': proto_opt, 'backbone': bb_opt, 'omegas': proto_opt},
param_labels=lambda params: {k: k for k in params},
)
def _get_resume_params(self, params):
params['backbone'] = self.backbone_params_
params['omegas'] = self.omegas_
return params
def _init_state(self, X, y, key):
n_features = X.shape[1]
layer_sizes = [n_features] + list(self.hidden_sizes) + [self.latent_dim]
key1, key2, key3 = jax.random.split(key, 3)
backbone_params = _mlp_init(key1, layer_sizes, self.activation)
prototypes, proto_labels = self._init_prototypes(
X, y, self.n_prototypes_per_class, key2
)
n_protos = prototypes.shape[0]
# Per-prototype tangent bases in latent space
keys = jax.random.split(key3, n_protos)
omegas = jnp.stack([
random_omega_init(self.latent_dim, self.subspace_dim, k) for k in keys
]) # (p, latent_dim, subspace_dim)
params = {
'prototypes': prototypes,
'backbone': backbone_params,
'omegas': omegas,
}
opt_state = self._optimizer.init(params)
from prosemble.models.prototype_base import SupervisedState
state = SupervisedState(
prototypes=prototypes,
opt_state=opt_state,
loss=jnp.array(float('inf')),
iteration=0,
converged=False,
)
return state, params, proto_labels
def _compute_loss(self, params, X, y, proto_labels):
backbone = params['backbone']
omegas = params['omegas'] # (p, latent_dim, subspace_dim)
latent_x = _mlp_forward(backbone, X, self.activation)
latent_w = _mlp_forward(backbone, params['prototypes'], self.activation)
if not self.both_path_gradients:
latent_w = jax.lax.stop_gradient(latent_w)
# Tangent distance in latent space
diff = latent_x[:, None, :] - latent_w[None, :, :] # (n, p, latent_dim)
proj = jnp.einsum('npd,pds->nps', diff, omegas) # (n, p, s)
recon = jnp.einsum('nps,pds->npd', proj, omegas) # (n, p, latent_dim)
tang_diff = diff - recon
distances = jnp.sum(tang_diff ** 2, axis=2) # (n, p)
return glvq_loss_with_transfer(
distances, y, proto_labels,
transfer_fn=self.transfer_fn,
margin=self.margin,
beta=self.beta,
)
def _post_update(self, params):
"""Re-orthogonalize tangent bases."""
if 'omegas' not in params:
return params
omegas = jax.vmap(orthogonalize)(params['omegas'])
return {**params, 'omegas': omegas}
def _extract_results(self, params, proto_labels, loss_history, n_iter, **kwargs):
super()._extract_results(params, proto_labels, loss_history, n_iter, **kwargs)
self.backbone_params_ = params['backbone']
self.omegas_ = params['omegas']
[docs]
def predict(self, X):
self._check_fitted()
X = jnp.asarray(X, dtype=jnp.float32)
latent_x = _mlp_forward(self.backbone_params_, X, self.activation)
latent_w = _mlp_forward(self.backbone_params_, self.prototypes_, self.activation)
diff = latent_x[:, None, :] - latent_w[None, :, :]
proj = jnp.einsum('npd,pds->nps', diff, self.omegas_)
recon = jnp.einsum('nps,pds->npd', proj, self.omegas_)
tang_diff = diff - recon
distances = jnp.sum(tang_diff ** 2, axis=2)
return wtac(distances, self.prototype_labels_)
def transform(self, X):
"""Transform data through the backbone."""
self._check_fitted()
X = jnp.asarray(X, dtype=jnp.float32)
return _mlp_forward(self.backbone_params_, X, self.activation)
def _check_fitted(self):
if self.prototypes_ is None or self.backbone_params_ is None:
from prosemble.models.base import NotFittedError
raise NotFittedError("Model not fitted. Call fit() first.")
def _get_hyperparams(self):
hp = super()._get_hyperparams()
hp.update({
'hidden_sizes': self.hidden_sizes,
'latent_dim': self.latent_dim,
'subspace_dim': self.subspace_dim,
'activation': self.activation,
'beta': self.beta,
})
return hp