# Source code for prosemble.core.losses

"""
Loss functions for prototype-based learning.

All loss functions are differentiable by jax.grad and JIT-compatible.
They use jnp.where masking (not boolean indexing) for d+/d- extraction.
"""

import jax.numpy as jnp
from jax import jit
from functools import partial

from .activations import identity


# --- Helpers ---

@jit
def _get_dp_dm(distances, target_labels, prototype_labels):
    """Extract d+ (min same-class) and d- (min different-class) distances.

    Parameters
    ----------
    distances : array of shape (n, p)
        Distance from each sample to each prototype.
    target_labels : array of shape (n,)
        True labels.
    prototype_labels : array of shape (p,)
        Prototype labels.

    Returns
    -------
    dp : array of shape (n,)
        Distance to closest same-class prototype.
    dm : array of shape (n,)
        Distance to closest different-class prototype.
    """
    # same_class[i, j] = True if sample i and prototype j share label
    same_class = (target_labels[:, None] == prototype_labels[None, :])
    diff_class = ~same_class

    INF = jnp.finfo(distances.dtype).max
    dp = jnp.min(jnp.where(same_class, distances, INF), axis=1)
    dm = jnp.min(jnp.where(diff_class, distances, INF), axis=1)
    return dp, dm


@jit
def _get_dp_dm_with_indices(distances, target_labels, prototype_labels):
    """Extract d+/d- with winner indices.

    Returns
    -------
    dp, dm : arrays of shape (n,)
    wp, wm : arrays of shape (n,) — indices of winning prototypes
    """
    same_class = (target_labels[:, None] == prototype_labels[None, :])
    diff_class = ~same_class

    INF = jnp.finfo(distances.dtype).max
    d_same = jnp.where(same_class, distances, INF)
    d_diff = jnp.where(diff_class, distances, INF)

    dp = jnp.min(d_same, axis=1)
    dm = jnp.min(d_diff, axis=1)
    wp = jnp.argmin(d_same, axis=1)
    wm = jnp.argmin(d_diff, axis=1)
    return dp, dm, wp, wm


# --- GLVQ Loss ---

@jit
def glvq_loss(distances, target_labels, prototype_labels, margin=0.0):
    r"""Generalized LVQ loss.

    For every sample the relative distance difference

    :math:`\mu_i = \frac{d^+_i - d^-_i}{d^+_i + d^-_i}`

    is computed, where :math:`d^+_i` and :math:`d^-_i` are the distances to
    the closest prototype with the same and with a different label.

    Parameters
    ----------
    distances : array of shape (n, p)
    target_labels : array of shape (n,)
    prototype_labels : array of shape (p,)
    margin : float
        Constant added to every :math:`\mu_i` before averaging.

    Returns
    -------
    scalar
        Mean of ``mu + margin`` over samples.
    """
    # NOTE: the docstring is a raw string so the LaTeX backslashes (\frac,
    # \mu) are not interpreted as Python escape sequences.
    dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
    # The 1e-10 keeps the ratio finite when dp + dm is zero.
    mu = (dp - dm) / (dp + dm + 1e-10)
    return jnp.mean(mu + margin)
def glvq_loss_with_transfer(distances, target_labels, prototype_labels,
                            transfer_fn=identity, margin=0.0, beta=10.0):
    r"""GLVQ loss with configurable transfer function.

    :math:`\text{loss} = \text{mean}(f(\mu + \text{margin}, \beta))`

    with :math:`\mu = (d^+ - d^-) / (d^+ + d^-)` computed per sample.

    Parameters
    ----------
    distances : array of shape (n, p)
    target_labels : array of shape (n,)
    prototype_labels : array of shape (p,)
    transfer_fn : callable
        Activation function (identity, sigmoid_beta, swish_beta). It is
        called as ``transfer_fn(mu + margin, beta)``.
    margin : float
        Constant added to ``mu`` before the transfer function is applied.
    beta : float
        Transfer function parameter, passed as the second positional
        argument of ``transfer_fn``.

    Returns
    -------
    scalar
        Mean of the transferred values over samples.
    """
    # NOTE: the docstring is a raw string so the LaTeX backslashes (\text,
    # \beta, \mu) are not interpreted as Python escape sequences.
    dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
    # The 1e-10 keeps the ratio finite when dp + dm is zero.
    mu = (dp - dm) / (dp + dm + 1e-10)
    return jnp.mean(transfer_fn(mu + margin, beta))
# --- LVQ1 / LVQ2.1 Losses ---
@jit
def lvq1_loss(distances, target_labels, prototype_labels):
    """LVQ1 loss: ``d+`` for correctly classified samples, ``-d-`` otherwise.

    Parameters
    ----------
    distances : array of shape (n, p)
    target_labels : array of shape (n,)
    prototype_labels : array of shape (p,)

    Returns
    -------
    scalar
        Mean of the per-sample terms.
    """
    closest_same, closest_other = _get_dp_dm(
        distances, target_labels, prototype_labels
    )
    # A sample counts as correct when its nearest same-class prototype is at
    # least as close as the nearest other-class one (ties count as correct).
    is_correct = jnp.less_equal(closest_same, closest_other)
    # Correct: minimise d+ (pull the winner closer).
    # Wrong: minimise -d- (push the offending prototype away).
    per_sample = jnp.where(is_correct, closest_same, -closest_other)
    return per_sample.mean()
@jit
def lvq21_loss(distances, target_labels, prototype_labels):
    """LVQ2.1 loss: the unnormalized difference ``d+ - d-``.

    Parameters
    ----------
    distances : array of shape (n, p)
    target_labels : array of shape (n,)
    prototype_labels : array of shape (p,)

    Returns
    -------
    scalar
        Mean of ``d+ - d-`` over samples.
    """
    closest_same, closest_other = _get_dp_dm(
        distances, target_labels, prototype_labels
    )
    gap = jnp.subtract(closest_same, closest_other)
    return gap.mean()
# --- Probabilistic Losses --- @jit def _class_probabilities(distances, target_labels, prototype_labels, sigma): """Compute Gaussian mixture class probabilities. .. math:: p(k|x) = \frac{\exp(-d^2 / 2\sigma^2)}{\sum \exp(-d^2 / 2\sigma^2)}, \quad P(\text{class}|x) = \sum_{k \in \text{class}} p(k|x) Returns ------- whole : array (n,) — total probability correct : array (n,) — probability of correct class wrong : array (n,) — probability of wrong classes """ # Compute conditional probabilities log_probs = -distances / (2.0 * sigma ** 2) # Normalize via log-sum-exp log_norm = jnp.max(log_probs, axis=1, keepdims=True) probs = jnp.exp(log_probs - log_norm) probs = probs / jnp.sum(probs, axis=1, keepdims=True) # Sum probabilities per class same_class = (target_labels[:, None] == prototype_labels[None, :]) correct = jnp.sum(probs * same_class, axis=1) whole = jnp.sum(probs, axis=1) # should be 1.0 wrong = whole - correct return whole, correct, wrong
@jit
def nllr_loss(distances, target_labels, prototype_labels, sigma=1.0):
    r"""Negative Log-Likelihood Ratio loss (for SLVQ).

    :math:`\text{loss} = -\log(P(\text{correct}) / P(\text{wrong}))`

    Parameters
    ----------
    distances : array of shape (n, p)
    target_labels : array of shape (n,)
    prototype_labels : array of shape (p,)
    sigma : float
        Bandwidth of Gaussian mixture.

    Returns
    -------
    scalar
        Mean loss over samples.
    """
    # NOTE: the docstring is a raw string so the LaTeX backslashes (\text,
    # \log) are not interpreted as Python escape sequences.
    _, correct, wrong = _class_probabilities(
        distances, target_labels, prototype_labels, sigma
    )
    # 1e-10 guards the division when no mass falls on wrong classes, and the
    # logarithm when no mass falls on the correct class.
    likelihood = correct / (wrong + 1e-10)
    return jnp.mean(-jnp.log(likelihood + 1e-10))
@jit
def rslvq_loss(distances, target_labels, prototype_labels, sigma=1.0):
    r"""Robust Soft LVQ loss (for RSLVQ).

    :math:`\text{loss} = -\log(P(\text{correct}) / P(\text{all}))`

    Parameters
    ----------
    distances : array of shape (n, p)
    target_labels : array of shape (n,)
    prototype_labels : array of shape (p,)
    sigma : float
        Bandwidth of Gaussian mixture.

    Returns
    -------
    scalar
        Mean loss over samples.
    """
    # NOTE: the docstring is a raw string so the LaTeX backslashes (\text,
    # \log) are not interpreted as Python escape sequences.
    whole, correct, _ = _class_probabilities(
        distances, target_labels, prototype_labels, sigma
    )
    # 1e-10 guards the division and the logarithm against zero arguments.
    likelihood = correct / (whole + 1e-10)
    return jnp.mean(-jnp.log(likelihood + 1e-10))
@jit
def ng_rslvq_loss(distances, target_labels, prototype_labels,
                  sigma=1.0, gamma=1.0):
    r"""RSLVQ loss with Neural Gas rank-based neighborhood cooperation.

    Combines Gaussian mixture prototype probabilities with NG rank weights
    to create a neighborhood-cooperative probabilistic assignment.

    Gaussian:
    :math:`p(k|x) = \exp(-d_k / 2\sigma^2) / \sum_j \exp(-d_j / 2\sigma^2)`

    NG weights:
    :math:`h_k = \exp(-\text{rank}_k / \gamma) / \sum_j \exp(-\text{rank}_j / \gamma)`

    Combined:
    :math:`w_k = p(k|x) \cdot h_k / \sum_j p(j|x) \cdot h_j`

    The loss is :math:`-\log(\sum_{k \in \text{correct}} w_k)`.

    Parameters
    ----------
    distances : array of shape (n, p)
        Squared distances from samples to prototypes.
    target_labels : array of shape (n,)
        True class labels for samples.
    prototype_labels : array of shape (p,)
        Class labels assigned to prototypes.
    sigma : float
        Bandwidth of Gaussian mixture.
    gamma : float
        Neural Gas neighborhood range.

    Returns
    -------
    scalar
        Mean negative log-likelihood with NG cooperation.
    """
    # NOTE: the docstring is a raw string so the LaTeX backslashes (\exp,
    # \text, \cdot, ...) are not interpreted as Python escape sequences.

    # 1. Gaussian mixture probabilities; the row maximum is subtracted
    #    before exp() so it cannot overflow.
    log_probs = -distances / (2.0 * sigma ** 2)
    log_norm = jnp.max(log_probs, axis=1, keepdims=True)
    probs = jnp.exp(log_probs - log_norm)
    probs = probs / (jnp.sum(probs, axis=1, keepdims=True) + 1e-10)

    # 2. NG rank-based weighting. argsort of the argsort yields, for every
    #    prototype, its rank (0 = closest) within the sample's row.
    order = jnp.argsort(distances, axis=1)
    ranks = jnp.argsort(order, axis=1).astype(jnp.float32)
    h = jnp.exp(-ranks / (gamma + 1e-10))
    h = h / (jnp.sum(h, axis=1, keepdims=True) + 1e-10)

    # 3. Combined weights (Gaussian * NG), renormalized per sample.
    weighted_probs = h * probs
    weighted_probs = weighted_probs / (
        jnp.sum(weighted_probs, axis=1, keepdims=True) + 1e-10
    )

    # 4. Sum the weights of prototypes that carry the sample's own label.
    same_class = (target_labels[:, None] == prototype_labels[None, :])
    correct = jnp.sum(weighted_probs * same_class, axis=1)

    # 5. RSLVQ objective: -log(P_correct); 1e-10 keeps log() finite.
    return jnp.mean(-jnp.log(correct + 1e-10))
# --- Cross-Entropy LVQ Loss ---
@partial(jit, static_argnums=(3,))
def cross_entropy_lvq_loss(distances, target_labels, prototype_labels, n_classes):
    """Cross-entropy LVQ loss (for CELVQ).

    Steps:

    1. Pool ``distances`` per class with ``stratified_min_pooling``
       (presumably the minimum distance to each class's prototypes, shape
       ``(n, n_classes)`` -- confirm against ``.pooling``).
    2. Negate the pooled distances to obtain logits (closer = higher).
    3. Take the cross-entropy of the softmax of those logits against the
       true labels.

    Parameters
    ----------
    distances : array of shape (n, p)
    target_labels : array of shape (n,)
    prototype_labels : array of shape (p,)
    n_classes : int
        Number of classes; treated as a static argument by ``jit``.

    Returns
    -------
    scalar
        Mean cross-entropy over samples.
    """
    from .pooling import stratified_min_pooling

    per_class = stratified_min_pooling(distances, prototype_labels, n_classes)

    # Smaller distance -> larger logit; log_softmax is the numerically
    # stable way to get log-probabilities from them.
    log_posteriors = jax.nn.log_softmax(-per_class, axis=1)

    # The one-hot mask keeps only the log-probability of the true class.
    picked = jax.nn.one_hot(target_labels, n_classes) * log_posteriors
    return -jnp.mean(jnp.sum(picked, axis=1))
# --- Margin Loss (for CBC) ---
@jit
def margin_loss(y_pred, y_true_one_hot, margin=0.3):
    r"""Margin loss for CBC.

    :math:`\text{loss} = \text{ReLU}(\max(\text{wrong}) - \text{correct} + \text{margin})`

    averaged over samples.

    Parameters
    ----------
    y_pred : array of shape (n, n_classes)
        Predicted class probabilities.
    y_true_one_hot : array of shape (n, n_classes)
        One-hot encoded true labels.
    margin : float
        Required gap between the correct-class score and the best
        wrong-class score.

    Returns
    -------
    scalar
        Mean loss over samples.
    """
    # NOTE: the docstring is a raw string so the LaTeX backslashes (\text,
    # \max) are not interpreted as Python escape sequences.

    # Score assigned to the true class of every sample.
    correct = jnp.sum(y_true_one_hot * y_pred, axis=-1)
    # Subtracting 1e9 at the true-class position removes it from the max,
    # leaving the highest score among the wrong classes.
    wrong_max = jnp.max(y_pred - y_true_one_hot * 1e9, axis=-1)
    return jnp.mean(jax.nn.relu(wrong_max - correct + margin))
# --- Neural Gas Energy ---
@jit
def neural_gas_energy(distances, lam):
    r"""Neural Gas energy function.

    .. math::

        E = \sum_k h(\text{rank}_k, \lambda) \cdot d(x, w_k),
        \quad
        h(\text{rank}, \lambda) = \exp(-\text{rank} / \lambda)

    The sum runs over all samples and all prototypes (no averaging).

    Parameters
    ----------
    distances : array of shape (n, p)
    lam : float
        Neighborhood range parameter.

    Returns
    -------
    scalar
        Total rank-weighted distance.
    """
    # NOTE: the docstring is a raw string so the LaTeX backslashes (\sum,
    # \text, \lambda, ...) are not interpreted as Python escape sequences.

    # Rank prototypes by distance for each sample: argsort of the argsort
    # gives every prototype its rank (0 = closest) within the row.
    order = jnp.argsort(distances, axis=1)
    ranks = jnp.argsort(order, axis=1).astype(jnp.float32)
    # Neighborhood weight decays exponentially with rank.
    h = jnp.exp(-ranks / lam)
    return jnp.sum(h * distances)
# Need jax import for cross_entropy_lvq_loss and margin_loss (jax.nn.*)
import jax