Riemannian Models

Prosemble provides prototype-based learning on Riemannian manifolds. Prototypes live directly on the manifold and distances are computed via the intrinsic geodesic metric, preserving the geometric structure of the data.

Supported Manifolds

| Manifold | Points | Applications |
|---|---|---|
| SO(n) | \(n \times n\) rotation matrices | Robotics (grasp/pose), 3D vision, structural biology |
| SPD(n) | \(n \times n\) symmetric positive definite matrices | EEG/BCI (covariance matrices), diffusion tensor imaging |
| Gr(n, k) | \(k\)-dimensional subspaces of \(\mathbb{R}^n\) | Hyperspectral imaging, video analysis, subspace tracking |
| HyperbolicPoincare(d) | Vectors in \(\mathbb{B}^d = \{x \in \mathbb{R}^d : \lVert x \rVert < 1\}\) | Hierarchical data, taxonomy embeddings, NLP trees |

All four are available from prosemble.core.manifolds:

from prosemble.core.manifolds import SO, SPD, Grassmannian, HyperbolicPoincare

so3 = SO(3)           # 3x3 rotation matrices
spd4 = SPD(4)         # 4x4 SPD matrices
gr5_2 = Grassmannian(5, 2)  # 2D subspaces of R^5
hyp8 = HyperbolicPoincare(8)  # 8D Poincare ball

HyperbolicPoincare

The Poincare ball model of hyperbolic geometry. Points live in the open unit ball \(\mathbb{B}^d = \{x \in \mathbb{R}^d : \|x\| < 1\}\) with the Riemannian metric tensor \(g_x = \lambda_x^2 I\) where \(\lambda_x = 2/(1 - \|x\|^2)\) is the conformal factor.

The geodesic distance is:

\[d(x, y) = \text{arcosh}\!\left(1 + \frac{2\|x - y\|^2}{(1 - \|x\|^2)(1 - \|y\|^2)}\right)\]
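As a concrete illustration, here is a minimal standalone JAX sketch of this distance for curvature \(-1\), matching the conformal factor above. `poincare_distance` is a hypothetical helper written for this example, not part of prosemble's API. The two point pairs have the same Euclidean gap but sit at different distances from the boundary.

```python
import jax.numpy as jnp

def poincare_distance(x, y):
    """Geodesic distance between two points of the Poincare ball (curvature -1)."""
    sq_diff = jnp.sum((x - y) ** 2)
    denom = (1.0 - jnp.sum(x ** 2)) * (1.0 - jnp.sum(y ** 2))
    return jnp.arccosh(1.0 + 2.0 * sq_diff / denom)

# Two pairs with the same Euclidean gap of 0.09
d_centre = poincare_distance(jnp.array([0.0, 0.0]), jnp.array([0.09, 0.0]))
d_edge = poincare_distance(jnp.array([0.90, 0.0]), jnp.array([0.99, 0.0]))
```

Near the centre the geodesic distance is close to twice the Euclidean gap (about 0.18 here, since \(\lambda_0 = 2\)). The same gap near the boundary gives a distance more than ten times larger.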

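Prototype updates use the exponential and logarithmic maps below, which are built on Möbius addition \(x \oplus y = \frac{(1 + 2\langle x, y\rangle + \|y\|^2)\,x + (1 - \|x\|^2)\,y}{1 + 2\langle x, y\rangle + \|x\|^2 \|y\|^2}\):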

\[\text{Exp}_x(v) = x \oplus \left(\tanh\!\left(\frac{\lambda_x \|v\|}{2}\right) \cdot \frac{v}{\|v\|}\right)\]
\[\text{Log}_x(y) = \frac{2}{\lambda_x} \text{arctanh}(\|-x \oplus y\|) \cdot \frac{-x \oplus y}{\|-x \oplus y\|}\]
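The following sketch implements Möbius addition and the two maps exactly as written above, for curvature \(-1\). The function names (`mobius_add`, `conformal_factor`, `exp_map`, `log_map`) are illustrative and are not claimed to match prosemble's internals. The small norm floor `EPS` is an addition of this sketch that only avoids a 0/0 when a vector is zero.

```python
import jax.numpy as jnp

EPS = 1e-15  # norm floor so zero vectors do not cause 0/0

def mobius_add(x, y):
    """Mobius addition x (+) y in the Poincare ball (curvature -1)."""
    xy, x2, y2 = jnp.dot(x, y), jnp.dot(x, x), jnp.dot(y, y)
    num = (1.0 + 2.0 * xy + y2) * x + (1.0 - x2) * y
    return num / (1.0 + 2.0 * xy + x2 * y2)

def conformal_factor(x):
    """lambda_x = 2 / (1 - ||x||^2)."""
    return 2.0 / (1.0 - jnp.dot(x, x))

def exp_map(x, v):
    """Exp_x(v) = x (+) tanh(lambda_x ||v|| / 2) v / ||v||."""
    v_norm = jnp.maximum(jnp.linalg.norm(v), EPS)
    return mobius_add(x, jnp.tanh(conformal_factor(x) * v_norm / 2.0) * v / v_norm)

def log_map(x, y):
    """Log_x(y) = (2 / lambda_x) artanh(||-x (+) y||) (-x (+) y) / ||-x (+) y||."""
    u = mobius_add(-x, y)
    u_norm = jnp.maximum(jnp.linalg.norm(u), EPS)
    return (2.0 / conformal_factor(x)) * jnp.arctanh(u_norm) * u / u_norm

x = jnp.array([0.3, -0.2])
y = jnp.array([-0.1, 0.5])
v = log_map(x, y)       # tangent vector at x pointing toward y
y_back = exp_map(x, v)  # maps the tangent vector back onto the ball
```

Mapping \(y\) to the tangent space at \(x\) and back recovers \(y\) up to floating-point error, and the Riemannian length \(\lambda_x \|\text{Log}_x(y)\|\) equals the geodesic distance \(d(x, y)\).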

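Hyperbolic space is particularly suited to data with hierarchical or tree-like structure. The volume of a hyperbolic ball grows exponentially with its radius, just as the number of nodes in a tree grows exponentially with depth. In the Poincare ball this shows up as geodesic distances that grow without bound toward the boundary, so points that are close in Euclidean terms near the edge are far apart geodesically.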

from prosemble.models import RiemannianSRNG
from prosemble.core.manifolds import HyperbolicPoincare
import jax
import jax.numpy as jnp

manifold = HyperbolicPoincare(4)

# Generate data inside the Poincare ball
key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (40, 4)) * 0.3
X = jax.vmap(manifold.project)(X)  # ensure points lie in ball
X = X.reshape(40, -1)
y = jnp.array([0] * 20 + [1] * 20)

model = RiemannianSRNG(
    manifold=manifold,
    n_prototypes_per_class=2,
    max_iter=50,
    lr=0.01,
)
model.fit(X, y)
labels = model.predict(X)

All supervised Riemannian models (RiemannianSRNG, RiemannianSMNG, RiemannianSLNG, RiemannianSTNG) and the unsupervised RiemannianNeuralGas work with HyperbolicPoincare without any code changes.

RiemannianSRNG

Supervised Riemannian Neural Gas. Combines GLVQ-style margin-based classification with Neural Gas neighbourhood cooperation on the manifold.

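The loss for each sample uses all same-class prototypes. Each same-class prototype \(w_j\) is weighted by its rank \(k_j\) among them (\(k_j = 0\) for the closest) through \(h(k) = \exp(-k / \gamma)\):

\[\mathcal{L} = \sum_i \sum_{j:\, c(w_j) = y_i} h(k_j) \cdot \sigma\!\left(\frac{d_j(x_i) - d^-(x_i)}{d_j(x_i) + d^-(x_i)}\right)\]

where \(d_j(x_i) = d(x_i, w_j)\) is the geodesic distance to the same-class prototype \(w_j\), \(d^-(x_i)\) is the geodesic distance to the nearest prototype of a different class, and \(\sigma\) is a monotonically increasing transfer function such as the logistic sigmoid.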
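A minimal JAX sketch of the per-sample term, assuming the geodesic distances from one sample to all \(K\) prototypes are already computed. It ranks only the same-class prototypes (0 = closest), compares each of their distances with the nearest wrong-class distance, and uses the logistic sigmoid as \(\sigma\). `srng_sample_loss` is a hypothetical helper; prosemble may normalize the weights or schedule \(\gamma\) differently.

```python
import jax
import jax.numpy as jnp

def srng_sample_loss(dists, proto_labels, label, gamma):
    """Rank-weighted SRNG cost for one sample (illustrative sketch).

    dists:        (K,) geodesic distances from the sample to all K prototypes
    proto_labels: (K,) class label of each prototype
    """
    same = proto_labels == label
    # Nearest prototype of a different class
    d_minus = jnp.min(jnp.where(same, jnp.inf, dists))
    # Rank among same-class prototypes only; other prototypes are pushed to the end
    ranks = jnp.argsort(jnp.argsort(jnp.where(same, dists, jnp.inf)))
    h = jnp.exp(-ranks / gamma)
    mu = (dists - d_minus) / (dists + d_minus)  # relative distance difference
    return jnp.sum(jnp.where(same, h * jax.nn.sigmoid(mu), 0.0))

# Two prototypes per class; the sample belongs to class 0
loss = srng_sample_loss(jnp.array([0.2, 0.5, 0.4, 0.9]),
                        jnp.array([0, 0, 1, 1]), 0, gamma=1.0)
```

As \(\gamma\) shrinks, every weight except \(h(0) = 1\) vanishes, so the cost reduces to the GLVQ term of the single closest correct prototype. The `gamma_final` argument in the example below suggests \(\gamma\) is annealed toward a small value during training.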

from prosemble.models import RiemannianSRNG
from prosemble.core.manifolds import SO
import jax
import jax.numpy as jnp

# Generate synthetic rotation data
key = jax.random.PRNGKey(0)
manifold = SO(3)
X = jnp.stack([manifold.random_point(jax.random.fold_in(key, i))
                for i in range(40)])
X = X.reshape(40, -1)  # flatten to (n_samples, 9)
y = jnp.array([0] * 20 + [1] * 20)

model = RiemannianSRNG(
    manifold=manifold,
    n_prototypes_per_class=2,
    max_iter=50,
    lr=0.01,
    gamma_final=0.01,
)
model.fit(X, y)
labels = model.predict(X)

RiemannianSMNG

Supervised Matrix Neural Gas on manifolds. Adds a single metric adaptation matrix \(\Omega\), shared by all prototypes, that is applied to tangent vectors at each prototype:

\[d(x, w_k) = \|\Omega \cdot \text{Log}_{w_k}(x)\|^2\]

where \(\text{Log}_{w_k}\) is the Riemannian logarithmic map (tangent vector from \(w_k\) to \(x\)).
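The distance is a squared norm after a linear map, so it can be sketched in a few lines of JAX. The shapes are assumptions for illustration: \(\Omega\) is taken as `latent_dim` by \(D\), with \(D = 9\) for a flattened SPD(3) tangent vector, and a random vector stands in for \(\text{Log}_{w_k}(x)\). Under the usual matrix-LVQ convention the same distance can be written as \(v^T \Lambda v\) with \(\Lambda = \Omega^T \Omega\). Whether `relevance_matrix()` returns exactly this \(\Lambda\) is not confirmed here.

```python
import jax
import jax.numpy as jnp

def omega_distance(omega, tangent):
    """||Omega v||^2 for a flattened tangent vector v = Log_{w_k}(x)."""
    mapped = omega @ tangent
    return jnp.dot(mapped, mapped)

key_omega, key_vec = jax.random.split(jax.random.PRNGKey(0))
omega = jax.random.normal(key_omega, (4, 9))  # assumed shape: latent_dim x D
tangent = jax.random.normal(key_vec, (9,))    # random stand-in for Log_{w_k}(x)

d = omega_distance(omega, tangent)
relevance = omega.T @ omega                    # Lambda = Omega^T Omega, shape (9, 9)
d_quadratic = tangent @ relevance @ tangent    # same distance as a quadratic form
```

Since \(\Lambda\) is positive semi-definite with rank at most `latent_dim`, choosing `latent_dim` smaller than \(D\) restricts the metric to a low-dimensional subspace of the tangent space.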

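from prosemble.models import RiemannianSMNG
from prosemble.core.manifolds import SPD
import jax
import jax.numpy as jnp

manifold = SPD(3)

# Synthetic SPD data: A A^T + 0.1 I is symmetric positive definite
key = jax.random.PRNGKey(0)
A = jax.random.normal(key, (40, 3, 3))
X = A @ jnp.swapaxes(A, -1, -2) + 0.1 * jnp.eye(3)
X = X.reshape(40, -1)  # flatten to (n_samples, 9)
y = jnp.array([0] * 20 + [1] * 20)

model = RiemannianSMNG(
    manifold=manifold,
    n_prototypes_per_class=2,
    latent_dim=4,
    max_iter=50,
    lr=0.01,
)
model.fit(X, y)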

# Learned relevance matrix
print(model.relevance_matrix().shape)

RiemannianSLNG

Supervised Localized Matrix Neural Gas. Each prototype \(w_k\) has its own metric matrix \(\Omega_k\):

\[d(x, w_k) = \|\Omega_k \cdot \text{Log}_{w_k}(x)\|^2\]

This allows different prototypes to focus on different tangent directions, useful when the discriminative structure varies across the manifold.
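With one matrix per prototype, the distances to all prototypes can be computed at once with `jnp.einsum`. In this sketch row \(k\) of `tangents` is assumed to be \(\text{Log}_{w_k}(x)\) already flattened to length \(D = 10\), which presumes a 5 by 2 basis representation for Gr(5, 2). `localized_distances` is a hypothetical helper, and random arrays stand in for learned matrices and real tangent vectors.

```python
import jax
import jax.numpy as jnp

def localized_distances(omegas, tangents):
    """d_k = ||Omega_k v_k||^2 for every prototype k.

    omegas:   (K, latent_dim, D), one metric matrix per prototype
    tangents: (K, D), row k is Log_{w_k}(x) flattened
    """
    mapped = jnp.einsum("kld,kd->kl", omegas, tangents)
    return jnp.sum(mapped ** 2, axis=-1)

K, latent_dim, D = 4, 3, 10  # 2 classes x 2 prototypes; D assumes 5x2 bases for Gr(5, 2)
key_omega, key_vec = jax.random.split(jax.random.PRNGKey(1))
omegas = jax.random.normal(key_omega, (K, latent_dim, D))
tangents = jax.random.normal(key_vec, (K, D))
dists = localized_distances(omegas, tangents)  # shape (K,)
```

Compared with the global \(\Omega\), the parameter count grows linearly with the number of prototypes, which is the price for letting each region of the manifold use its own tangent metric.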

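from prosemble.models import RiemannianSLNG
from prosemble.core.manifolds import Grassmannian
import jax
import jax.numpy as jnp

manifold = Grassmannian(5, 2)

# Synthetic data; assumes Grassmannian provides random_point like SO(3) above
key = jax.random.PRNGKey(0)
X = jnp.stack([manifold.random_point(jax.random.fold_in(key, i))
               for i in range(40)])
X = X.reshape(40, -1)  # flatten each point to one row
y = jnp.array([0] * 20 + [1] * 20)

model = RiemannianSLNG(
    manifold=manifold,
    n_prototypes_per_class=2,
    latent_dim=3,
    max_iter=50,
    lr=0.01,
)
model.fit(X, y)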

RiemannianSTNG

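Supervised Tangent Neural Gas. Each prototype \(w_k\) has a tangent subspace spanned by the orthonormal columns of \(\Omega_k\), so \(\Omega_k \Omega_k^T\) is the orthogonal projection onto it. The distance is measured in the orthogonal complement of that subspace, i.e. on the residual left after projection: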

\[d(x, w_k) = \|\text{Log}_{w_k}(x) - \Omega_k \Omega_k^T \text{Log}_{w_k}(x)\|^2\]

This captures invariance structure: variation along the directions spanned by \(\Omega_k\) does not change the distance.
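A minimal sketch of the tangent distance, assuming the columns of \(\Omega_k\) are orthonormal. Here that is enforced with a QR decomposition, which may differ from how prosemble parametrizes \(\Omega_k\). Random vectors of length 9 stand in for flattened SO(3) tangent vectors. The point is only to show that components inside the subspace are ignored while the orthogonal remainder is measured in full.

```python
import jax
import jax.numpy as jnp

def tangent_distance(basis, tangent):
    """Squared norm of the component of `tangent` orthogonal to span(basis).

    Assumes `basis` has orthonormal columns, so basis @ basis.T is a projector.
    """
    residual = tangent - basis @ (basis.T @ tangent)
    return jnp.dot(residual, residual)

key_basis, key_vec = jax.random.split(jax.random.PRNGKey(2))
basis, _ = jnp.linalg.qr(jax.random.normal(key_basis, (9, 2)))  # (9, 2), orthonormal columns

in_span = basis @ jnp.array([1.5, -0.7])             # lies entirely inside the subspace
v = jax.random.normal(key_vec, (9,))
v_perp = v - basis @ (basis.T @ v)                   # component orthogonal to the subspace

d_in_span = tangent_distance(basis, in_span)         # approximately 0: invariant direction
d_mixed = tangent_distance(basis, in_span + v_perp)  # equals ||v_perp||^2
```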

from prosemble.models import RiemannianSTNG
from prosemble.core.manifolds import SO

manifold = SO(3)
model = RiemannianSTNG(
    manifold=manifold,
    n_prototypes_per_class=2,
    subspace_dim=2,
    max_iter=50,
    lr=0.01,
)
model.fit(X, y)  # X, y: flattened SO(3) samples and labels, e.g. from the RiemannianSRNG example

RiemannianNeuralGas

Unsupervised Neural Gas on Riemannian manifolds. Distributes prototypes to match the data density using rank-based cooperation with geodesic distances and exponential map updates.
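One online Riemannian Neural Gas step can be sketched by replacing the Euclidean update \(w_k \leftarrow w_k + \varepsilon\, h_\lambda(k)\,(x - w_k)\) with \(w_k \leftarrow \text{Exp}_{w_k}(\varepsilon\, h_\lambda(k)\, \text{Log}_{w_k}(x))\). This is the standard way to carry Neural Gas onto a manifold, but it is an assumption that prosemble's update has exactly this form. The sketch uses the closed-form Poincare ball maps from the HyperbolicPoincare section. `riemannian_ng_step` and the other helpers are hypothetical.

```python
import jax
import jax.numpy as jnp

EPS = 1e-15  # norm floor to avoid dividing by zero

def mobius_add(x, y):
    xy, x2, y2 = jnp.dot(x, y), jnp.dot(x, x), jnp.dot(y, y)
    num = (1.0 + 2.0 * xy + y2) * x + (1.0 - x2) * y
    return num / (1.0 + 2.0 * xy + x2 * y2)

def conformal_factor(x):
    return 2.0 / (1.0 - jnp.dot(x, x))

def exp_map(x, v):
    v_norm = jnp.maximum(jnp.linalg.norm(v), EPS)
    return mobius_add(x, jnp.tanh(conformal_factor(x) * v_norm / 2.0) * v / v_norm)

def log_map(x, y):
    u = mobius_add(-x, y)
    u_norm = jnp.maximum(jnp.linalg.norm(u), EPS)
    return (2.0 / conformal_factor(x)) * jnp.arctanh(u_norm) * u / u_norm

def poincare_distance(x, y):
    ratio = jnp.sum((x - y) ** 2) / ((1.0 - jnp.dot(x, x)) * (1.0 - jnp.dot(y, y)))
    return jnp.arccosh(1.0 + 2.0 * ratio)

def riemannian_ng_step(prototypes, x, lr, lam):
    """Move every prototype toward x along its geodesic, weighted by distance rank."""
    dists = jax.vmap(poincare_distance, in_axes=(0, None))(prototypes, x)
    ranks = jnp.argsort(jnp.argsort(dists))   # 0 = closest prototype
    h = jnp.exp(-ranks / lam)                 # Neural Gas neighbourhood weights
    tangents = jax.vmap(log_map, in_axes=(0, None))(prototypes, x)
    steps = (lr * h)[:, None] * tangents      # scaled tangent vectors
    return jax.vmap(exp_map)(prototypes, steps)

prototypes = jnp.array([[0.1, 0.2], [-0.5, 0.3], [0.6, -0.4], [0.0, -0.7], [-0.2, -0.1]])
x = jnp.array([0.3, 0.25])
new_prototypes = riemannian_ng_step(prototypes, x, lr=0.3, lam=1.0)

old_d = jax.vmap(poincare_distance, in_axes=(0, None))(prototypes, x)
new_d = jax.vmap(poincare_distance, in_axes=(0, None))(new_prototypes, x)
```

Because \(\text{Exp}_{w}(t\,\text{Log}_{w}(x))\) moves a fraction \(t\) of the way along the geodesic, each prototype's distance to \(x\) shrinks by the factor \(1 - \varepsilon h_\lambda(k)\). The closest prototype (rank 0, \(h = 1\)) moves the most. In a full run \(\varepsilon\) and \(\lambda\) would be decayed from their initial to their final values (for example `lr_init` to `lr_final`), typically with the exponential schedule of the original Neural Gas.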

from prosemble.models import RiemannianNeuralGas
from prosemble.core.manifolds import SO

manifold = SO(3)
model = RiemannianNeuralGas(
    manifold=manifold,
    n_prototypes=5,
    max_iter=50,
    lr_init=0.3,
    lr_final=0.01,
    lambda_final=0.01,
)
model.fit(X)  # X: flattened SO(3) samples, e.g. from the RiemannianSRNG example

labels = model.predict(X)       # nearest prototype assignment
distances = model.transform(X)  # geodesic distance matrix

Note

All Riemannian models use Python loops (not lax.scan) because manifold projection after each gradient step is not compatible with JAX’s functional loop primitives.

Choosing a Model

| Model | Metric | Best For |
|---|---|---|
| RiemannianSRNG | Geodesic distance | General manifold classification |
| RiemannianSMNG | Global \(\Omega\) in tangent space | Feature selection on manifolds |
| RiemannianSLNG | Per-prototype \(\Omega_k\) | Heterogeneous tangent structure |
| RiemannianSTNG | Tangent subspace projection | Invariance learning on manifolds |
| RiemannianNeuralGas | Geodesic distance | Unsupervised manifold clustering |

Choosing a Manifold

| Manifold | Geometry | Use When |
|---|---|---|
| SO(n) | Rotation matrices | Pose estimation, robotics, structural biology |
| SPD(n) | Positive definite matrices | Covariance matrices (EEG/BCI, DTI) |
| Grassmannian(n, k) | Subspaces of \(\mathbb{R}^n\) | Video/image subspace analysis |
| HyperbolicPoincare(d) | Poincare ball \(\mathbb{B}^d\) | Hierarchical/tree-structured data |