# Source code for prosemble.models.riemannian_srng

"""
Supervised Riemannian Neural Gas (RiemannianSRNG).

Extends GLVQ-style supervised classification to Riemannian manifolds
using geodesic distances and Neural Gas neighborhood cooperation.
Prototypes live on the manifold and are updated via projected gradient
descent (Euclidean gradient + manifold projection).

Supports SO(n), SPD(n), Grassmannian(n,k), and HyperbolicPoincare(d) manifolds.
"""

import jax
import jax.numpy as jnp
import numpy as np

from prosemble.models.prototype_base import SupervisedPrototypeModel, SupervisedState
from prosemble.core.activations import sigmoid_beta
from prosemble.core.manifolds import SO, SPD, Grassmannian, HyperbolicPoincare
from prosemble.core.protocols import Manifold


# ---------------------------------------------------------------------------
# Differentiable manifold operations for autodiff-based training
# ---------------------------------------------------------------------------

# Eigenvalues are floored at this value before taking the logarithm.
_LOGM_EIG_FLOOR = 1e-10
# Relative gap below which two eigenvalues are treated as coincident when
# building the divided-difference matrix of the derivative.
_LOGM_EIG_GAP_RTOL = 1e-6


@jax.custom_jvp
def _logm_spd_diff(M):
    """Differentiable matrix log for SPD matrices via eigendecomposition.

    Unlike ``jsl.funm(A, jnp.log)`` (which uses Schur decomposition with
    no JAX differentiation rule), this uses ``jnp.linalg.eigh``.

    Differentiating straight through ``eigh`` is not sufficient: JAX's
    eigenvector derivative contains ``1 / (lambda_i - lambda_j)`` and
    produces NaN/inf (or huge, inaccurate values) when ``M`` has repeated
    or nearly repeated eigenvalues -- e.g. ``M = I``, which is what
    ``A^{-1/2} B A^{-1/2}`` evaluates to when ``B == A``. A custom JVP based
    on the Daleckii-Krein formula keeps the derivative finite and exact in
    that case. It is registered below via ``defjvp``.

    Parameters
    ----------
    M : array of shape (n, n)
        Symmetric positive-definite matrix. ``eigh`` symmetrizes its input
        by default; eigenvalues below ``1e-10`` are floored before the log.

    Returns
    -------
    array of shape (n, n)
        ``V diag(log(max(w, 1e-10))) V^T`` where ``M = V diag(w) V^T``.
    """
    eigvals, eigvecs = jnp.linalg.eigh(M)
    eigvals = jnp.maximum(eigvals, _LOGM_EIG_FLOOR)
    # Column scaling is equivalent to eigvecs @ diag(log w) without the n x n diag.
    return (eigvecs * jnp.log(eigvals)) @ eigvecs.T


def _log_divided_differences(raw, w, log_w):
    """Matrix of first divided differences of ``f(x) = log(max(x, floor))``.

    ``raw`` are the unfloored eigenvalues, ``w`` the floored ones and
    ``log_w = log(w)``. Entry ``(i, j)`` is ``(f(raw_i) - f(raw_j)) /
    (raw_i - raw_j)``, or ``f'`` averaged over the pair when the two
    eigenvalues coincide (relative gap <= ``_LOGM_EIG_GAP_RTOL``).
    """
    w_i, w_j = w[:, None], w[None, :]
    gap = raw[:, None] - raw[None, :]
    abs_raw = jnp.abs(raw)
    scale = jnp.maximum(
        jnp.maximum(abs_raw[:, None], abs_raw[None, :]), _LOGM_EIG_FLOOR)
    coincident = jnp.abs(gap) <= _LOGM_EIG_GAP_RTOL * scale

    # f(w_i) - f(w_j): log1p(w_i / w_j - 1) is accurate when the ratio is
    # near 1 (no cancellation); the plain log difference is accurate when
    # the ratio is far from 1. The clip only keeps the unused branch finite.
    rel = (w_i - w_j) / w_j
    near_one = jnp.abs(rel) < 0.5
    numer = jnp.where(
        near_one,
        jnp.log1p(jnp.clip(rel, -0.5, 0.5)),
        log_w[:, None] - log_w[None, :],
    )
    divided = numer / jnp.where(coincident, 1.0, gap)

    # Coincident eigenvalues: the divided difference tends to f'(x), where
    # f'(x) = 1/x above the floor and 0 where the floor is active.
    dlog = jnp.where(raw > _LOGM_EIG_FLOOR, 1.0 / w, 0.0)
    slope = 0.5 * (dlog[:, None] + dlog[None, :])
    return jnp.where(coincident, slope, divided)


@_logm_spd_diff.defjvp
def _logm_spd_diff_jvp(primals, tangents):
    """Forward-mode rule for ``_logm_spd_diff`` (Daleckii-Krein formula).

    With ``M = V diag(w) V^T``, the derivative is
    ``V (K * (V^T sym(dM) V)) V^T`` where ``K`` holds the divided
    differences of the scalar log over the eigenvalues. ``K`` and ``V``
    depend on the primal only, so the tangent map is linear in ``dM`` and
    JAX can transpose it for reverse mode (``jax.grad``).
    """
    (M,), (dM,) = primals, tangents
    raw, V = jnp.linalg.eigh(M)
    w = jnp.maximum(raw, _LOGM_EIG_FLOOR)
    log_w = jnp.log(w)
    K = _log_divided_differences(raw, w, log_w)
    # eigh only sees the symmetric part of its input, so only sym(dM) matters.
    dM_sym = (dM + dM.T) / 2.0
    out = (V * log_w) @ V.T
    out_dot = V @ (K * (V.T @ dM_sym @ V)) @ V.T
    return out, out_dot


def _so_chordal_distance_squared(R, S):
    """Squared chordal (Frobenius) distance between two rotations.

    .. math::

        d^2(R, S) = \\|R - S\\|_F^2

    A smooth, autodiff-friendly stand-in for the geodesic distance on
    SO(n), commonly used in rotation averaging (Hartley et al., 2013).
    It is monotonically related to the geodesic distance for small angles.

    Parameters
    ----------
    R, S : array of shape (n, n)
        Rotation matrices.

    Returns
    -------
    scalar array
        Sum of squared entry-wise differences.
    """
    return jnp.square(R - S).sum()


def _spd_distance_squared_diff(A, B):
    """Squared affine-invariant geodesic distance on SPD(n), autodiff-ready.

    .. math::

        d^2(A, B) = \\|\\log(A^{-1/2} B A^{-1/2})\\|_F^2

    ``B`` is whitened by ``A^{-1/2}`` and the matrix logarithm is taken with
    the eigendecomposition-based ``_logm_spd_diff`` rather than
    ``jsl.funm`` (which has no differentiation rule).

    Parameters
    ----------
    A, B : array of shape (n, n)
        SPD matrices.

    Returns
    -------
    scalar array
        Squared Frobenius norm of the whitened matrix logarithm.
    """
    from prosemble.core.manifolds import inv_sqrt_spd

    whitener = inv_sqrt_spd(A)
    log_whitened = _logm_spd_diff(whitener @ B @ whitener)
    return jnp.sum(jnp.square(log_whitened))


def _grassmannian_distance_squared_diff(Q1, Q2):
    """Squared geodesic distance on Gr(n, k) from principal angles.

    The singular values of ``Q1^T Q2`` are the cosines of the principal
    angles between the two column spans. SVD (values only) and ``arccos``
    both carry JAX differentiation rules. The cosines are clipped
    slightly inside [-1, 1] to keep ``arccos``'s derivative finite.
    """
    cosines = jnp.linalg.svd(Q1.T @ Q2, compute_uv=False)
    lower, upper = -1.0 + 1e-7, 1.0 - 1e-7
    principal_angles = jnp.arccos(jnp.clip(cosines, lower, upper))
    return jnp.sum(jnp.square(principal_angles))


def _so_log_map_diff(R, S):
    """First-order, differentiable approximation of the SO(n) log map.

    Takes the skew-symmetric part of the relative rotation ``R^T S`` and
    transports it back to the tangent space at ``R``. The approximation is
    exact to first order as ``S`` approaches ``R``.

    .. math::

        \\text{Log}_R(S) \\approx R \\cdot \\text{skew}(R^T S)
        = R \\cdot \\frac{R^T S - S^T R}{2}
    """
    relative = R.T @ S
    skew_part = (relative - relative.T) / 2.0
    return R @ skew_part


def _spd_log_map_diff(A, B):
    """Riemannian log map on SPD(n), differentiable via eigendecomposition.

    .. math::

        \\text{Log}_A(B) = A^{1/2} \\log(A^{-1/2} B A^{-1/2}) A^{1/2}

    ``B`` is whitened by ``A^{-1/2}``, its matrix logarithm is taken with
    ``_logm_spd_diff``, and the result is mapped back with ``A^{1/2}``.
    """
    from prosemble.core.manifolds import sqrt_spd, inv_sqrt_spd

    root = sqrt_spd(A)
    inv_root = inv_sqrt_spd(A)
    log_whitened = _logm_spd_diff(inv_root @ B @ inv_root)
    return root @ log_whitened @ root


def _grassmannian_log_map_diff(Q1, Q2):
    """Smooth tangent-space approximation of the Grassmannian log map.

    The exact log map involves ``1 / sin(theta)`` and has a gradient
    singularity when ``Q1`` and ``Q2`` span (nearly) the same subspace.
    This version instead removes from ``Q2`` its component inside
    ``span(Q1)``:

    .. math::

        \\text{Log}_{Q_1}(Q_2) \\approx Q_2 - Q_1 (Q_1^T Q_2)

    This agrees with the exact log map to first order and is smooth
    everywhere.
    """
    overlap = Q1.T @ Q2
    return Q2 - Q1 @ overlap


def _hyperbolic_distance_squared_diff(x, y, eps=1e-5):
    """Squared Poincare-ball geodesic distance, safe for autodiff.

    .. math::

        d^2(x, y) = \\left(\\text{arcosh}\\left(1 + \\frac{2\\|x - y\\|^2}
        {(1 - \\|x\\|^2)(1 - \\|y\\|^2)}\\right)\\right)^2

    Parameters
    ----------
    x : array of shape (d,)
        Point in the Poincare ball.
    y : array of shape (d,)
        Point in the Poincare ball.
    eps : float
        Added to the conformal denominator for numerical stability.

    Returns
    -------
    float
        Squared geodesic distance. When ``x == y`` it is not exactly 0: the
        arcosh argument is floored at ``1 + 1e-6``.
    """
    sq_gap = jnp.sum(jnp.square(x - y))
    conformal_denom = (1.0 - jnp.sum(jnp.square(x))) * (
        1.0 - jnp.sum(jnp.square(y)))
    cosh_dist = 1.0 + 2.0 * sq_gap / (conformal_denom + eps)
    # d/dz arcosh(z) = 1/sqrt(z^2 - 1) diverges at z = 1 (coincident
    # points), so the argument is floored at 1 + 1e-6.
    cosh_dist = jnp.maximum(cosh_dist, 1.0 + 1e-6)
    return jnp.arccosh(cosh_dist) ** 2


def _hyperbolic_log_map_diff(x, y, eps=1e-5):
    """Log map on the Poincare ball with gradient-safe norm handling.

    .. math::

        \\text{Log}_x(y) = \\frac{2}{\\lambda_x} \\text{arctanh}(\\|-x \\oplus y\\|)
        \\cdot \\frac{-x \\oplus y}{\\|-x \\oplus y\\|}

    where :math:`\\lambda_x = 2 / (1 - \\|x\\|^2)` is the conformal factor
    and :math:`\\oplus` is Mobius addition.

    Parameters
    ----------
    x : array of shape (d,)
        Base point in the Poincare ball.
    y : array of shape (d,)
        Target point in the Poincare ball.
    eps : float
        Numerical stability constant forwarded to the Mobius addition and
        conformal-factor helpers.

    Returns
    -------
    array of shape (d,)
        Tangent vector at x pointing toward y.
    """
    from prosemble.core.manifolds import _mobius_add, _conformal_factor

    direction = _mobius_add(-x, y, eps=eps)
    # sqrt(|v|^2 + 1e-16) never reaches 0, so its gradient stays defined at
    # v == 0 (jnp.linalg.norm's gradient is 0/0 there).
    radius = jnp.sqrt(jnp.sum(direction ** 2) + 1e-16)
    # Keep the arctanh argument strictly inside its domain (-1, 1).
    radius_in_domain = jnp.minimum(radius, 1.0 - 1e-7)
    lam_x = _conformal_factor(x, eps=eps)
    # For tiny radii use the limit arctanh(t)/t -> 1 instead of the ratio.
    scale = jnp.where(
        radius > 1e-7,
        jnp.arctanh(radius_in_domain) / radius,
        1.0,
    )
    return (2.0 / lam_x) * scale * direction


class RiemannianSRNG(SupervisedPrototypeModel):
    """Supervised Riemannian Neural Gas.

    Combines three key ideas:

    - GLVQ loss: :math:`(d^+ - d^-) / (d^+ + d^-)` for margin-based
      classification
    - Neural Gas cooperation: all same-class prototypes participate in
      the loss, weighted by rank via :math:`\\exp(-\\text{rank} / \\gamma)`
    - Manifold distance: :math:`d^2(x, w)` computed with a differentiable
      per-manifold formula (chordal/Frobenius distance on SO(n), the
      matrix-logarithm geodesic distance on SPD(n), principal angles on
      Gr(n, k), the arcosh formula on the Poincare ball; any other
      manifold uses its own ``distance_squared``)

    Prototypes live on the manifold and are updated via projected gradient
    descent: optax computes Euclidean gradients, then
    :meth:`manifold.project` maps prototypes back to the manifold after
    each update step.

    The neighborhood range :math:`\\gamma` decays during training from
    :math:`\\gamma_{\\text{init}}` to :math:`\\gamma_{\\text{final}}`.
    When :math:`\\gamma \\to 0`, RiemannianSRNG recovers a Riemannian GLVQ.

    Parameters
    ----------
    manifold : SO, SPD, Grassmannian, or HyperbolicPoincare
        Riemannian manifold instance defining the geometry.
    beta : float
        Transfer function steepness parameter for sigmoid shaping.
        Default: 10.0.
    gamma_init : float, optional
        Initial neighborhood range for NG cooperation.
        Default: max prototypes per class / 2. It is always raised to at
        least ``gamma_final + 1e-6``.
    gamma_final : float
        Final (minimum) neighborhood range. Default: 0.01.
    gamma_decay : float, optional
        Multiplicative decay factor applied to gamma on every
        ``_post_update`` call. Default: computed from max_iter so that
        gamma reaches gamma_final after max_iter such calls.
    lr_ratio : float
        Gradient scale applied through the closest wrong-class distance,
        relative to the correct-class distances (Hammer et al. 2003:
        ``epsilon^- = lr_ratio * epsilon^+``). The forward loss value is
        unchanged. Default: 0.5.
    tau : float
        Injectivity radius safety factor. Default: 0.95.
        NOTE(review): stored and saved with the hyperparameters, but not
        read by any method of this class -- confirm where it should be used.
    n_prototypes_per_class : int, dict or sequence of int
        Number of prototypes per class.
    max_iter : int
        Maximum training iterations.
    lr : float
        Learning rate.
    epsilon : float
        Convergence threshold on loss change.
    random_seed : int
        Random seed for reproducibility.
    optimizer : str or optax optimizer, optional
        Optimizer name ('adam', 'sgd') or optax GradientTransformation.
        Default: 'adam'.
    transfer_fn : callable, optional
        Transfer function ``f(mu, beta)`` for loss shaping. When
        ``self.transfer_fn`` is falsy, the loss falls back to
        ``sigmoid_beta``. NOTE(review): confirm whether the base class
        substitutes its own default for ``None``.
    margin : float
        Margin for loss computation.
    callbacks : list, optional
        List of Callback objects.
    use_scan : bool
        If True, use jax.lax.scan for training (faster, JIT-compiled).
        If False (default), use a Python for-loop with true early stopping.
    batch_size : int, optional
        Mini-batch size. If None (default), use full-batch training.
    lr_scheduler : str or optax.Schedule, optional
        Learning rate schedule. Default: None.
    lr_scheduler_kwargs : dict, optional
        Keyword arguments for the learning rate scheduler. Default: None.
    prototypes_initializer : str or callable, optional
        How to initialize prototypes. Default: 'stratified_random'.
    patience : int, optional
        Number of consecutive epochs with no improvement before stopping.
        Default: None.
    restore_best : bool
        If True, restore parameters that achieved the lowest loss.
        Default: False.
    class_weight : dict or 'balanced', optional
        Weights for each class. Default: None (uniform).
    gradient_accumulation_steps : int, optional
        Accumulate gradients over this many steps. Default: None.
    ema_decay : float, optional
        Exponential moving average decay for parameters. Default: None.
    freeze_params : list of str, optional
        List of parameter group names to freeze. ``'gamma'`` is always
        added, because gamma is scheduled by decay rather than trained.
        Default: None.
    lookahead : dict, optional
        Enable lookahead optimizer wrapper. Default: None.
    mixed_precision : str or None, optional
        Compute dtype for mixed precision training. Default: None.

    Attributes
    ----------
    gamma_ : float or None
        Neighborhood range at the end of training (None before fitting).
    """

    def __init__(self, manifold: Manifold, beta=10.0, gamma_init=None,
                 gamma_final=0.01, gamma_decay=None, lr_ratio=0.5, tau=0.95,
                 n_prototypes_per_class=1, max_iter=100, lr=0.01,
                 epsilon=1e-6, random_seed=42, optimizer='adam',
                 transfer_fn=None, margin=0.0, callbacks=None,
                 use_scan=False, batch_size=None, lr_scheduler=None,
                 lr_scheduler_kwargs=None, prototypes_initializer=None,
                 patience=None, restore_best=False, class_weight=None,
                 gradient_accumulation_steps=None, ema_decay=None,
                 freeze_params=None, lookahead=None, mixed_precision=None):
        super().__init__(
            n_prototypes_per_class=n_prototypes_per_class,
            max_iter=max_iter,
            lr=lr,
            epsilon=epsilon,
            random_seed=random_seed,
            distance_fn=None,
            optimizer=optimizer,
            transfer_fn=transfer_fn,
            margin=margin,
            callbacks=callbacks,
            use_scan=use_scan,
            batch_size=batch_size,
            lr_scheduler=lr_scheduler,
            lr_scheduler_kwargs=lr_scheduler_kwargs,
            prototypes_initializer=prototypes_initializer,
            patience=patience,
            restore_best=restore_best,
            class_weight=class_weight,
            gradient_accumulation_steps=gradient_accumulation_steps,
            ema_decay=ema_decay,
            freeze_params=freeze_params,
            lookahead=lookahead,
            mixed_precision=mixed_precision,
        )
        self.manifold = manifold
        self.beta = beta
        self.gamma_init = gamma_init
        self.gamma_final = gamma_final
        self.gamma_decay = gamma_decay
        self.lr_ratio = lr_ratio
        self.tau = tau
        self.gamma_ = None

        # Ensure gamma is frozen from optimizer (not trainable)
        if self.freeze_params is None:
            self.freeze_params = ['gamma']
        elif 'gamma' not in self.freeze_params:
            self.freeze_params = list(self.freeze_params) + ['gamma']

    def _reshape_to_manifold(self, flat, n_points):
        """Reshape flat array to manifold point shape.

        Parameters
        ----------
        flat : array of shape (n_points, d_flat)
        n_points : int

        Returns
        -------
        array of shape (n_points, *point_shape)
        """
        return flat.reshape(n_points, *self.manifold.point_shape)

    def _project_flat_prototypes(self, prototypes):
        """Project flattened prototypes onto the manifold, keeping shape (p, d_flat)."""
        n_protos = prototypes.shape[0]
        protos_manifold = self._reshape_to_manifold(prototypes, n_protos)
        protos_manifold = jax.vmap(self.manifold.project)(protos_manifold)
        return protos_manifold.reshape(n_protos, -1)

    def _diff_distance_squared(self, x, w):
        """Differentiable squared distance for a single pair of points.

        Dispatches to the differentiable distance matching the manifold
        type; unknown manifolds fall back to ``manifold.distance_squared``.
        Used both during training (autodiff) and by :meth:`predict`.
        """
        if isinstance(self.manifold, SO):
            return _so_chordal_distance_squared(x, w)
        elif isinstance(self.manifold, SPD):
            return _spd_distance_squared_diff(x, w)
        elif isinstance(self.manifold, Grassmannian):
            return _grassmannian_distance_squared_diff(x, w)
        elif isinstance(self.manifold, HyperbolicPoincare):
            return _hyperbolic_distance_squared_diff(
                x, w, eps=self.manifold.eps)
        else:
            return self.manifold.distance_squared(x, w)

    def _geodesic_distances(self, X_manifold, W_manifold):
        """Compute pairwise squared distance matrix (differentiable).

        Parameters
        ----------
        X_manifold : array of shape (n_samples, *point_shape)
        W_manifold : array of shape (n_prototypes, *point_shape)

        Returns
        -------
        distances : array of shape (n_samples, n_prototypes)
        """
        dist_to_all = jax.vmap(self._diff_distance_squared, in_axes=(None, 0))
        dist_matrix = jax.vmap(dist_to_all, in_axes=(0, None))
        return dist_matrix(X_manifold, W_manifold)

    def _get_resume_params(self, params):
        """Return the trainable parameter dict used to resume training."""
        gamma = params.get('gamma', jnp.array(self.gamma_final))
        return {
            'prototypes': params['prototypes'],
            'gamma': gamma,
        }

    def _init_state(self, X, y, key):
        """Initialise prototypes (projected onto the manifold), gamma and optimizer state."""
        init_key, _ = jax.random.split(key)
        prototypes, proto_labels = self._init_prototypes(
            X, y, self.n_prototypes_per_class, init_key
        )
        # Project initial prototypes to manifold
        prototypes = self._project_flat_prototypes(prototypes)

        # Compute gamma_init from prototype count if not set.
        # np.integer is accepted like int (max() would fail on a numpy scalar).
        if isinstance(self.n_prototypes_per_class, (int, np.integer)):
            max_per_class = self.n_prototypes_per_class
        elif isinstance(self.n_prototypes_per_class, dict):
            max_per_class = max(self.n_prototypes_per_class.values())
        else:
            max_per_class = max(self.n_prototypes_per_class)
        gamma_init = (self.gamma_init if self.gamma_init is not None
                      else max_per_class / 2.0)
        gamma_init = max(gamma_init, self.gamma_final + 1e-6)
        self._gamma_init_actual = gamma_init

        # Compute decay factor so gamma reaches gamma_final after max_iter
        # _post_update calls; max(..., 1) avoids ZeroDivisionError for max_iter=0.
        if self.gamma_decay is not None:
            self._gamma_decay = self.gamma_decay
        else:
            n_decay_steps = max(self.max_iter, 1)
            self._gamma_decay = (
                (self.gamma_final / gamma_init) ** (1.0 / n_decay_steps))

        params = {
            'prototypes': prototypes,
            'gamma': jnp.array(gamma_init, dtype=jnp.float32),
        }
        opt_state = self._optimizer.init(params)
        state = SupervisedState(
            prototypes=prototypes,
            opt_state=opt_state,
            loss=jnp.array(float('inf')),
            iteration=0,
            converged=False,
        )
        return state, params, proto_labels

    def _compute_loss(self, params, X, y, proto_labels):
        """Rank-weighted GLVQ loss over same-class prototypes (mean over samples)."""
        prototypes = params['prototypes']
        gamma = params['gamma']

        # Reshape to manifold
        n = X.shape[0]
        p = prototypes.shape[0]
        X_m = self._reshape_to_manifold(X, n)
        W_m = self._reshape_to_manifold(prototypes, p)

        # 1. Geodesic distance matrix
        distances = self._geodesic_distances(X_m, W_m)  # (n, p)

        # 2. Compute ranks within same-class prototypes
        #    (other-class entries are pushed to the end with INF).
        same_class = (y[:, None] == proto_labels[None, :])  # (n, p)
        INF = jnp.finfo(distances.dtype).max
        d_same = jnp.where(same_class, distances, INF)
        order = jnp.argsort(d_same, axis=1)
        ranks = jnp.argsort(order, axis=1).astype(jnp.float32)

        # 3. Neighborhood function h = exp(-rank / gamma)
        h = jnp.exp(-ranks / (gamma + 1e-10))
        h = jnp.where(same_class, h, 0.0)

        # 4. Normalize per sample
        C = jnp.sum(h, axis=1, keepdims=True)
        h_normalized = h / (C + 1e-10)

        # 5. Closest different-class prototype distance
        d_diff = jnp.where(~same_class, distances, INF)
        dm = jnp.min(d_diff, axis=1)

        # Separate learning rates (Hammer et al. 2003: epsilon^- = lr_ratio * epsilon^+)
        # Scale gradient through dm by lr_ratio; forward pass unchanged.
        dm = jax.lax.stop_gradient(dm) + self.lr_ratio * (
            dm - jax.lax.stop_gradient(dm))

        # 6. GLVQ mu
        mu = (distances - dm[:, None]) / (distances + dm[:, None] + 1e-10)

        # 7. Transfer function
        transfer = self.transfer_fn or sigmoid_beta
        cost = transfer(mu + self.margin, self.beta)

        # 8. Rank-weighted sum
        weighted_cost = jnp.sum(h_normalized * cost, axis=1)
        return jnp.mean(weighted_cost)

    def _post_update(self, params):
        """Decay gamma (floored at gamma_final) and re-project prototypes."""
        # Decay gamma
        new_gamma = params['gamma'] * self._gamma_decay
        new_gamma = jnp.maximum(new_gamma, self.gamma_final)

        # Project prototypes back to manifold
        prototypes = self._project_flat_prototypes(params['prototypes'])

        return {**params, 'gamma': new_gamma, 'prototypes': prototypes}

    def _extract_results(self, params, proto_labels, loss_history, n_iter,
                         **kwargs):
        """Store fitted results; additionally records the final gamma in ``gamma_``."""
        super()._extract_results(params, proto_labels, loss_history, n_iter,
                                 **kwargs)
        self.gamma_ = float(params['gamma'])

    def predict(self, X):
        """Predict class labels from the nearest prototype.

        Uses the same differentiable per-manifold distance as training
        (e.g. the chordal distance on SO(n)).

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features_flat)

        Returns
        -------
        labels : array of shape (n_samples,)
        """
        self._check_fitted()
        X = jnp.asarray(X, dtype=jnp.float32)
        n = X.shape[0]
        p = self.prototypes_.shape[0]
        X_m = self._reshape_to_manifold(X, n)
        W_m = self._reshape_to_manifold(self.prototypes_, p)
        distances = self._geodesic_distances(X_m, W_m)
        from prosemble.core.competitions import wtac
        return wtac(distances, self.prototype_labels_)

    def _get_quantizable_attrs(self):
        return {'prototypes_': self.prototypes_}

    def _get_hyperparams(self):
        """Return hyperparameters, including the manifold type and size.

        NOTE(review): only ``n``, ``k`` and ``d`` of the manifold are saved;
        other settings (e.g. ``HyperbolicPoincare.eps``, read by
        ``_diff_distance_squared``) take constructor defaults on reload.
        """
        hp = super()._get_hyperparams()
        hp['beta'] = self.beta
        hp['gamma_init'] = self.gamma_init
        hp['gamma_final'] = self.gamma_final
        hp['gamma_decay'] = self.gamma_decay
        hp['lr_ratio'] = self.lr_ratio
        hp['tau'] = self.tau
        # Store manifold type and params for reconstruction
        manifold = self.manifold
        hp['manifold_type'] = type(manifold).__name__
        if hasattr(manifold, 'n'):
            hp['manifold_n'] = manifold.n
        if hasattr(manifold, 'k'):
            hp['manifold_k'] = manifold.k
        if hasattr(manifold, 'd'):
            hp['manifold_d'] = manifold.d
        return hp

    @classmethod
    def _reconstruct_manifold(cls, hp):
        """Reconstruct manifold from saved hyperparameters.

        Raises
        ------
        ValueError
            If ``hp['manifold_type']`` is missing or not a supported type.
        """
        from prosemble.core.manifolds import SO, SPD, Grassmannian, HyperbolicPoincare
        mtype = hp.get('manifold_type', '')
        if mtype == 'SO':
            return SO(int(hp['manifold_n']))
        elif mtype == 'SPD':
            return SPD(int(hp['manifold_n']))
        elif mtype == 'Grassmannian':
            return Grassmannian(int(hp['manifold_n']), int(hp['manifold_k']))
        elif mtype == 'HyperbolicPoincare':
            return HyperbolicPoincare(int(hp['manifold_d']))
        else:
            raise ValueError(f"Unknown manifold type: {mtype}")

    @classmethod
    def _pre_load_construct(cls, hyperparams, metadata):
        """Swap the saved ``manifold_*`` entries for a rebuilt ``manifold`` (mutates and returns ``hyperparams``)."""
        manifold = cls._reconstruct_manifold(hyperparams)
        hyperparams.pop('manifold_type', None)
        hyperparams.pop('manifold_n', None)
        hyperparams.pop('manifold_k', None)
        hyperparams.pop('manifold_d', None)
        hyperparams['manifold'] = manifold
        return hyperparams