"""
Supervised Vector Quantization One-Class Classification (SVQ-OCC).
Combines Neural Gas representation learning with one-class classification
using per-prototype visibility parameters. Each prototype is equipped with
a local visibility range :math:`\\theta_k` -- data within :math:`\\theta_k`
of prototype :math:`w_k` is classified as target.
The overall cost function is:
.. math::
E(X, W) = \\alpha \\cdot R(X^+, W) + (1 - \\alpha) \\cdot C(X, W, \\Theta)
where :math:`R` is the Neural Gas representation cost over target data
:math:`X^+`, and :math:`C` is a classification cost using per-prototype
responsibilities:
.. math::
r(x, w_k, \\gamma_k, \\theta_k) = p(w_k, \\gamma_k \\mid x) \\cdot
\\mathrm{sgd}_{\\sigma}(d(x, w_k), \\theta_k)
Three classification costs are available:
- Contrastive Score (CS):
:math:`1 - (TP \\cdot TN - FP \\cdot FN) / ((TP + FP)(TN + FN))`
- Brier Score (BS):
:math:`\\mathrm{mean}\\,(y - \\sum_k r)^2`
- Cross Entropy (CE):
:math:`-\\mathrm{mean}\\,[y \\cdot \\log(\\sum r) + (1 - y) \\cdot \\log(1 - \\sum r)]`
Three response probability models :math:`p(w_k \\mid x)` are supported:
- Gaussian: :math:`\\mathrm{softmax}(-\\gamma \\cdot d(x, w_k))`
- Student-t: :math:`(1 + d / \\nu)^{-(\\nu + 1)/2}`, normalized
- Uniform: :math:`1 / K`
References
----------
.. [1] Staps, C., Schubert, L., Kaden, M., Lampe, B., Hermann, W.,
& Villmann, T. (2022). Prototype-based One-Class-Classification
Learning Using Local Representations. IEEE WSOM+ 2022.
"""
import jax
import jax.numpy as jnp
import numpy as np
from prosemble.models.prototype_base import SupervisedPrototypeModel
from prosemble.core.initializers import stratified_selection_init
[docs]
class SVQOCC(SupervisedPrototypeModel):
"""Supervised Vector Quantization One-Class Classification.
Combines Neural Gas representation learning with per-prototype
visibility parameters :math:`\\theta_k` for one-class classification.
Parameters
----------
n_prototypes : int
Number of prototypes for the target class.
target_label : int, optional
Which label is the target (normal) class. Default: auto-detect
as the most frequent class.
alpha : float
Balance between representation (R) and classification (C) cost.
E = alpha * R + (1 - alpha) * C. Default: 0.5.
cost_function : str
Classification cost variant: 'contrastive', 'brier', 'cross_entropy'.
Default: 'contrastive'.
response_type : str
Response probability model: 'gaussian', 'student_t', 'uniform'.
Default: 'gaussian'.
sigma : float
Sigmoid sharpness for differentiable Heaviside approximation.
Smaller = sharper boundary. Default: 0.1.
gamma_resp : float
Response bandwidth for Gaussian probabilistic assignment. Default: 1.0.
nu : float
Degrees of freedom for Student-t response. Default: 1.0.
lambda_init : float, optional
Initial NG neighborhood range. Default: n_prototypes / 2.
lambda_final : float
Final NG neighborhood range. Default: 0.01.
lambda_decay : float, optional
Per-step multiplicative decay for lambda. Default: computed from
max_iter.
max_iter : int
Maximum training iterations.
lr : float
Learning rate.
epsilon : float
Convergence threshold on loss change.
random_seed : int
Random seed for reproducibility.
distance_fn : callable, optional
Distance function (default: squared Euclidean).
optimizer : str or optax optimizer, optional
Optimizer name ('adam', 'sgd') or optax GradientTransformation.
Default: 'adam'.
transfer_fn : callable, optional
Transfer function for loss shaping (default: identity).
margin : float
Margin for loss computation.
callbacks : list, optional
List of Callback objects.
use_scan : bool
If True (default), use jax.lax.scan for training (faster,
JIT-compiled, but runs all max_iter iterations even after
convergence). If False, use a Python for-loop with true early
stopping (no wasted compute after convergence, but slower per
iteration).
batch_size : int, optional
Mini-batch size. If None (default), use full-batch training.
When set, each epoch iterates over shuffled mini-batches of this
size.
lr_scheduler : str or optax.Schedule, optional
Learning rate schedule. Supported strings: 'exponential_decay',
'cosine_decay', 'warmup_cosine_decay', 'warmup_exponential_decay',
'warmup_constant', 'polynomial', 'linear', 'piecewise_constant',
'sgdr'. Or pass a custom optax.Schedule. Default: None.
lr_scheduler_kwargs : dict, optional
Keyword arguments passed to the learning rate scheduler
(e.g. ``decay_rate``, ``transition_steps``). Default: None.
prototypes_initializer : str or callable, optional
How to initialize prototypes. Supported strings:
'stratified_random' (default), 'class_mean',
'class_conditional_mean', 'stratified_noise', 'random_normal',
'uniform', 'zeros', 'ones', 'fill_value'. Or pass a callable
``(X, y, n_per_class, key) -> (protos, labels)``.
patience : int, optional
Number of consecutive epochs with no improvement before stopping.
If None (default), stops after a single non-improving step
(epsilon check). Requires use_scan=False for true early stopping.
restore_best : bool
If True, restore the parameters that achieved the lowest loss
(or validation loss if validation data is provided). Default:
False.
class_weight : dict or 'balanced', optional
Weights for each class. Dict maps class label to weight, e.g.
{0: 1.0, 1: 2.0, 2: 1.5}. 'balanced' auto-computes weights
inversely proportional to class frequencies. Default: None
(uniform).
gradient_accumulation_steps : int, optional
Accumulate gradients over this many steps before applying an
update. Effective batch size = batch_size *
gradient_accumulation_steps. Default: None (no accumulation).
ema_decay : float, optional
Exponential moving average decay for parameters
(0 < ema_decay < 1). After training, model parameters are
replaced with EMA-smoothed values. Typical values: 0.999,
0.9999. Default: None (no EMA).
freeze_params : list of str, optional
List of parameter group names to freeze (zero gradients).
E.g. ['backbone'] to freeze the backbone and only train
prototypes. Default: None (all parameters trainable).
lookahead : dict, optional
Enable lookahead optimizer wrapper. Dict with keys:
- 'sync_period': int (default 6) -- sync every k steps
- 'slow_step_size': float (default 0.5) -- interpolation factor
Default: None (no lookahead).
mixed_precision : str or None, optional
Compute dtype for mixed precision training. 'float16' or
'bfloat16'. Master weights stay in float32; forward/backward
pass runs in lower precision for ~2x speed and ~half memory on
GPU. Float16 uses static loss scaling to prevent gradient
underflow. Default: None (disabled).
"""
_valid_costs = ('contrastive', 'brier', 'cross_entropy')
_valid_responses = ('gaussian', 'student_t', 'uniform')
def __init__(self, n_prototypes=3, target_label=None, alpha=0.5,
cost_function='contrastive', response_type='gaussian',
sigma=0.1, gamma_resp=1.0, nu=1.0,
lambda_init=None, lambda_final=0.01, lambda_decay=None,
max_iter=100, lr=0.01, epsilon=1e-6, random_seed=42,
distance_fn=None, optimizer='adam', transfer_fn=None,
margin=0.0, callbacks=None, use_scan=True, batch_size=None,
lr_scheduler=None, lr_scheduler_kwargs=None,
prototypes_initializer=None, patience=None,
restore_best=False, class_weight=None,
gradient_accumulation_steps=None, ema_decay=None,
freeze_params=None, lookahead=None, mixed_precision=None):
super().__init__(
n_prototypes_per_class=n_prototypes,
max_iter=max_iter, lr=lr, epsilon=epsilon,
random_seed=random_seed, distance_fn=distance_fn,
optimizer=optimizer, transfer_fn=transfer_fn, margin=margin,
callbacks=callbacks, use_scan=use_scan, batch_size=batch_size,
lr_scheduler=lr_scheduler,
lr_scheduler_kwargs=lr_scheduler_kwargs,
prototypes_initializer=prototypes_initializer,
patience=patience, restore_best=restore_best,
class_weight=class_weight,
gradient_accumulation_steps=gradient_accumulation_steps,
ema_decay=ema_decay, freeze_params=freeze_params,
lookahead=lookahead, mixed_precision=mixed_precision,
)
self.n_prototypes = n_prototypes
self.target_label = target_label
self.alpha = alpha
self.sigma = sigma
self.gamma_resp = gamma_resp
self.nu = nu
self.lambda_init = lambda_init
self.lambda_final = lambda_final
self.lambda_decay = lambda_decay
if cost_function not in self._valid_costs:
raise ValueError(
f"cost_function must be one of {self._valid_costs}, "
f"got '{cost_function}'"
)
self.cost_function = cost_function
if response_type not in self._valid_responses:
raise ValueError(
f"response_type must be one of {self._valid_responses}, "
f"got '{response_type}'"
)
self.response_type = response_type
# Fitted attributes
self.thetas_ = None
self.lambda_ = None
self._target_label = None
self._non_target_label = None
# Freeze lambda_ng from optimizer (decayed manually)
if self.freeze_params is None:
self.freeze_params = ['lambda_ng']
elif 'lambda_ng' not in self.freeze_params:
self.freeze_params = list(self.freeze_params) + ['lambda_ng']
def _get_resume_params(self, params):
params['thetas'] = self.thetas_
lam = self.lambda_ if self.lambda_ is not None else (
self._lambda_init_actual if hasattr(self, '_lambda_init_actual')
else 1.0
)
params['lambda_ng'] = jnp.array(lam, dtype=jnp.float32)
return params
def _init_state(self, X, y, key):
# Determine target and non-target labels
classes = jnp.unique(y)
if self.target_label is not None:
self._target_label = int(self.target_label)
else:
# Auto-detect: most frequent class is target
counts = jnp.array([jnp.sum(y == c) for c in classes])
self._target_label = int(classes[jnp.argmax(counts)])
non_target = classes[classes != self._target_label]
self._non_target_label = int(non_target[0]) if len(non_target) > 0 else (
1 - self._target_label
)
# Filter target class data for prototype initialization
target_mask = (y == self._target_label)
X_target = X[target_mask]
y_target = jnp.full(X_target.shape[0], self._target_label, dtype=jnp.int32)
key1, key2 = jax.random.split(key)
# Initialize prototypes from target class data
if self.prototypes_initializer is not None:
prototypes, _ = self._init_prototypes(
X_target, y_target, self.n_prototypes, key1
)
else:
prototypes, _ = stratified_selection_init(
X_target, y_target, self.n_prototypes, key1
)
proto_labels = jnp.full(
self.n_prototypes, self._target_label, dtype=jnp.int32
)
# Initialize thetas: sqrt of mean squared distance per prototype
from prosemble.core.distance import squared_euclidean_distance_matrix
dists = squared_euclidean_distance_matrix(X_target, prototypes)
thetas = jnp.sqrt(jnp.mean(dists, axis=0) + 1e-10)
# Lambda NG setup
lambda_init = (
self.lambda_init if self.lambda_init is not None
else self.n_prototypes / 2.0
)
lambda_init = max(lambda_init, self.lambda_final + 1e-6)
self._lambda_init_actual = lambda_init
if self.lambda_decay is not None:
self._lambda_decay = self.lambda_decay
else:
self._lambda_decay = (
self.lambda_final / lambda_init
) ** (1.0 / self.max_iter)
params = {
'prototypes': prototypes,
'thetas': thetas,
'lambda_ng': jnp.array(lambda_init, dtype=jnp.float32),
}
opt_state = self._optimizer.init(params)
from prosemble.models.prototype_base import SupervisedState
state = SupervisedState(
prototypes=prototypes,
opt_state=opt_state,
loss=jnp.array(float('inf')),
iteration=0,
converged=False,
)
return state, params, proto_labels
def _compute_loss(self, params, X, y, proto_labels):
prototypes = params['prototypes']
thetas = params['thetas']
lambda_ng = params['lambda_ng']
n_protos = prototypes.shape[0]
# Squared Euclidean distances: (n, K)
diff = X[:, None, :] - prototypes[None, :, :]
sq_distances = jnp.sum(diff ** 2, axis=2)
# Target / non-target masks
target_mask = (y == self._target_label)
# ===== Representation cost R (target data only) =====
# Neural Gas ranking over all samples but weighted for target only
order = jnp.argsort(sq_distances, axis=1)
ranks = jnp.argsort(order, axis=1).astype(jnp.float32)
h_ng = jnp.exp(-ranks / (lambda_ng + 1e-10))
R_per_sample = jnp.sum(h_ng * sq_distances, axis=1)
R_per_sample = jnp.where(target_mask, R_per_sample, 0.0)
n_target = jnp.sum(target_mask) + 1e-10
R = jnp.sum(R_per_sample) / n_target
# ===== Classification cost C =====
# Response probability p(w_k | x)
if self.response_type == 'gaussian':
logits = -self.gamma_resp * sq_distances
p_k = jax.nn.softmax(logits, axis=1)
elif self.response_type == 'student_t':
p_unnorm = (1.0 + sq_distances / self.nu) ** (-(self.nu + 1) / 2)
p_k = p_unnorm / (jnp.sum(p_unnorm, axis=1, keepdims=True) + 1e-10)
else: # uniform
p_k = jnp.ones_like(sq_distances) / n_protos
# Sigmoid approximation of Heaviside: sgd_sigma(d, theta_k) = sigma((theta_k - d) / sigma)
# Keep thetas positive
thetas_pos = jnp.maximum(thetas, 1e-6)
heaviside = jax.nn.sigmoid(
(thetas_pos[None, :] - sq_distances) / (self.sigma + 1e-10)
)
# Local responsibility: r(x, w_k) = p(w_k|x) * H(theta_k - d(x, w_k))
responsibility = p_k * heaviside
# Summed responsibility per sample
total_resp = jnp.sum(responsibility, axis=1)
total_resp = jnp.clip(total_resp, 1e-10, 1.0 - 1e-10)
# Binary labels: 1 for target, 0 for non-target
y_binary = target_mask.astype(jnp.float32)
if self.cost_function == 'contrastive':
# Probabilistic confusion matrix
TP = jnp.sum(y_binary * total_resp)
FN = jnp.sum(y_binary) - TP
FP = jnp.sum((1.0 - y_binary) * total_resp)
TN = jnp.sum(1.0 - y_binary) - FP
# CS_W = 1 - (TP·TN - FP·FN) / ((TP+FP)(TN+FN))
numerator = TP * TN - FP * FN
denominator = (TP + FP + 1e-10) * (TN + FN + 1e-10)
C = 1.0 - numerator / denominator
elif self.cost_function == 'brier':
# Brier Score: mean (y - sum_k r)^2
C = jnp.mean((y_binary - total_resp) ** 2)
else: # cross_entropy
# Binary Cross Entropy
C = -jnp.mean(
y_binary * jnp.log(total_resp) +
(1.0 - y_binary) * jnp.log(1.0 - total_resp)
)
return self.alpha * R + (1.0 - self.alpha) * C
def _post_update(self, params):
# Keep thetas positive
thetas = jnp.maximum(params['thetas'], 1e-6)
# Decay lambda_ng
new_lambda = params['lambda_ng'] * self._lambda_decay
new_lambda = jnp.maximum(new_lambda, self.lambda_final)
return {**params, 'thetas': thetas, 'lambda_ng': new_lambda}
def _extract_results(self, params, proto_labels, loss_history, n_iter,
**kwargs):
super()._extract_results(
params, proto_labels, loss_history, n_iter, **kwargs
)
self.thetas_ = jnp.maximum(params['thetas'], 1e-6)
self.lambda_ = float(params['lambda_ng'])
[docs]
def predict(self, X):
"""Predict target or non-target labels.
Parameters
----------
X : array-like of shape (n_samples, n_features)
Returns
-------
labels : array of shape (n_samples,)
target_label for target, non_target_label for outliers.
"""
self._check_fitted()
X = jnp.asarray(X, dtype=jnp.float32)
scores = self.decision_function(X)
return jnp.where(
scores >= 0.5, self._target_label, self._non_target_label
).astype(jnp.int32)
[docs]
def predict_with_reject(self, X, upper=0.5, lower=None, reject_label=-1):
"""Predict with a reject option for uncertain samples.
Samples with scores between lower and upper are rejected
(labeled reject_label) instead of being forced into a class.
Parameters
----------
X : array-like of shape (n_samples, n_features)
upper : float
Scores >= upper are classified as target. Default: 0.5.
lower : float, optional
Scores < lower are classified as non-target. Scores in
[lower, upper) are rejected. Default: same as upper
(no rejection zone, equivalent to predict).
reject_label : int
Label for rejected samples. Default: -1.
Returns
-------
labels : array of shape (n_samples,)
"""
self._check_fitted()
X = jnp.asarray(X, dtype=jnp.float32)
if lower is None:
lower = upper
scores = self.decision_function(X)
labels = jnp.full(scores.shape, reject_label, dtype=jnp.int32)
labels = jnp.where(scores >= upper, self._target_label, labels)
labels = jnp.where(scores < lower, self._non_target_label, labels)
return labels
def predict_proba(self, X):
"""Predict probability of being target class.
Parameters
----------
X : array-like of shape (n_samples, n_features)
Returns
-------
proba : array of shape (n_samples,)
Probability of each sample belonging to the target class.
"""
self._check_fitted()
X = jnp.asarray(X, dtype=jnp.float32)
return self.decision_function(X)
[docs]
def decision_function(self, X):
"""Compute summed responsibility scores.
Scores near 1.0 indicate target class, near 0.0 indicate outlier.
Parameters
----------
X : array-like of shape (n_samples, n_features)
Returns
-------
scores : array of shape (n_samples,)
"""
self._check_fitted()
X = jnp.asarray(X, dtype=jnp.float32)
diff = X[:, None, :] - self.prototypes_[None, :, :]
sq_distances = jnp.sum(diff ** 2, axis=2)
n_protos = self.prototypes_.shape[0]
# Response probability
if self.response_type == 'gaussian':
logits = -self.gamma_resp * sq_distances
p_k = jax.nn.softmax(logits, axis=1)
elif self.response_type == 'student_t':
p_unnorm = (
(1.0 + sq_distances / self.nu) ** (-(self.nu + 1) / 2)
)
p_k = p_unnorm / (
jnp.sum(p_unnorm, axis=1, keepdims=True) + 1e-10
)
else: # uniform
p_k = jnp.ones_like(sq_distances) / n_protos
# Sigmoid Heaviside
heaviside = jax.nn.sigmoid(
(self.thetas_[None, :] - sq_distances) / (self.sigma + 1e-10)
)
# Summed responsibility
responsibility = p_k * heaviside
return jnp.clip(jnp.sum(responsibility, axis=1), 0.0, 1.0)
@property
def visibility_radii(self):
"""Return the learned visibility radii :math:`\\theta_k` for each prototype."""
if self.thetas_ is None:
raise ValueError("Model not fitted. Call fit() first.")
return self.thetas_
def _get_quantizable_attrs(self):
attrs = super()._get_quantizable_attrs()
if self.thetas_ is not None:
attrs.append('thetas_')
return attrs
def _get_fitted_arrays(self):
arrays = super()._get_fitted_arrays()
if self.thetas_ is not None:
arrays['thetas_'] = np.asarray(self.thetas_)
if self.lambda_ is not None:
arrays['lambda_'] = np.asarray(self.lambda_)
if self._target_label is not None:
arrays['_target_label'] = np.asarray(self._target_label)
if self._non_target_label is not None:
arrays['_non_target_label'] = np.asarray(self._non_target_label)
return arrays
def _set_fitted_arrays(self, arrays):
super()._set_fitted_arrays(arrays)
if 'thetas_' in arrays:
self.thetas_ = jnp.asarray(arrays['thetas_'])
if 'lambda_' in arrays:
self.lambda_ = float(arrays['lambda_'])
if '_target_label' in arrays:
self._target_label = int(arrays['_target_label'])
if '_non_target_label' in arrays:
self._non_target_label = int(arrays['_non_target_label'])
def _get_hyperparams(self):
hp = super()._get_hyperparams()
# Remove n_prototypes_per_class — we use n_prototypes instead
hp.pop('n_prototypes_per_class', None)
hp['n_prototypes'] = self.n_prototypes
hp['target_label'] = self.target_label
hp['alpha'] = self.alpha
hp['cost_function'] = self.cost_function
hp['response_type'] = self.response_type
hp['sigma'] = self.sigma
hp['gamma_resp'] = self.gamma_resp
hp['nu'] = self.nu
hp['lambda_init'] = self.lambda_init
hp['lambda_final'] = self.lambda_final
hp['lambda_decay'] = self.lambda_decay
return hp