# Riemannian Models
Prosemble provides prototype-based learning on Riemannian manifolds. Prototypes live directly on the manifold, and distances are computed from its intrinsic geometry (geodesic distances, or learned metrics on tangent vectors obtained with the logarithmic map), preserving the geometric structure of the data.
## Supported Manifolds
| Manifold | Points | Applications |
|---|---|---|
| SO(n) | \(n \times n\) rotation matrices | Robotics (grasp/pose), 3D vision, structural biology |
| SPD(n) | \(n \times n\) symmetric positive definite matrices | EEG/BCI (covariance matrices), diffusion tensor imaging |
| Gr(n, k) | \(k\)-dimensional subspaces of \(\mathbb{R}^n\) | Hyperspectral imaging, video analysis, subspace tracking |
| HyperbolicPoincare(d) | Vectors in \(\mathbb{B}^d = \{x \in \mathbb{R}^d : \|x\| < 1\}\) | Hierarchical data, taxonomy embeddings, NLP trees |
All four are available from `prosemble.core.manifolds`:

```python
from prosemble.core.manifolds import SO, SPD, Grassmannian, HyperbolicPoincare

so3 = SO(3)                   # 3x3 rotation matrices
spd4 = SPD(4)                 # 4x4 SPD matrices
gr5_2 = Grassmannian(5, 2)    # 2D subspaces of R^5
hyp8 = HyperbolicPoincare(8)  # 8D Poincare ball
```
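To make the Points column concrete, here is a minimal NumPy sketch of membership checks for each manifold. It does not use prosemble's API, and the function names are hypothetical. The Grassmannian check assumes points are stored as \(n \times k\) orthonormal bases, which may differ from prosemble's internal representation.

```python
import numpy as np


def is_so(R, tol=1e-8):
    """SO(n): orthogonal (R^T R = I) with determinant +1."""
    n = R.shape[0]
    orthogonal = np.allclose(R.T @ R, np.eye(n), atol=tol)
    return bool(orthogonal and np.isclose(np.linalg.det(R), 1.0, atol=tol))


def is_spd(S, tol=1e-10):
    """SPD(n): symmetric with strictly positive eigenvalues."""
    symmetric = np.allclose(S, S.T, atol=tol)
    return bool(symmetric and np.linalg.eigvalsh(S).min() > tol)


def is_grassmann_basis(Y, tol=1e-8):
    """Gr(n, k), ASSUMING points are n x k orthonormal bases (Y^T Y = I_k)."""
    k = Y.shape[1]
    return bool(np.allclose(Y.T @ Y, np.eye(k), atol=tol))


def is_in_poincare_ball(x):
    """Poincare ball B^d: Euclidean norm strictly below 1."""
    return bool(np.linalg.norm(x) < 1.0)


rng = np.random.default_rng(0)

# Rotation: orthonormalise a random matrix, flip the sign if det = -1 (odd n)
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
R = Q if np.linalg.det(Q) > 0 else -Q

# SPD: A A^T plus a small ridge is symmetric positive definite
A = rng.normal(size=(4, 4))
S = A @ A.T + 0.1 * np.eye(4)

# Grassmannian point: reduced QR gives a 5 x 2 orthonormal basis
Y, _ = np.linalg.qr(rng.normal(size=(5, 2)))

# Poincare ball point with norm about 0.28
x = np.full(8, 0.1)

print(is_so(R), is_spd(S), is_grassmann_basis(Y), is_in_poincare_ball(x))
```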
### HyperbolicPoincare
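The Poincaré ball model of hyperbolic geometry. Points live in the open unit ball \(\mathbb{B}^d = \{x \in \mathbb{R}^d : \|x\| < 1\}\) with the Riemannian metric tensor \(g_x = \lambda_x^2 I\), where \(\lambda_x = 2/(1 - \|x\|^2)\) is the conformal factor.

The geodesic distance is:

\[
d(x, y) = \operatorname{arcosh}\!\left(1 + \frac{2\,\|x - y\|^2}{(1 - \|x\|^2)(1 - \|y\|^2)}\right)
\]

Prototype updates use the exponential and logarithmic maps, which are built on Möbius addition:

\[
x \oplus y = \frac{(1 + 2\langle x, y\rangle + \|y\|^2)\,x + (1 - \|x\|^2)\,y}{1 + 2\langle x, y\rangle + \|x\|^2\|y\|^2}
\]

\[
\text{Exp}_x(v) = x \oplus \left(\tanh\!\left(\frac{\lambda_x \|v\|}{2}\right)\frac{v}{\|v\|}\right),
\qquad
\text{Log}_x(y) = \frac{2}{\lambda_x}\,\operatorname{artanh}\!\big(\|(-x) \oplus y\|\big)\,\frac{(-x) \oplus y}{\|(-x) \oplus y\|}
\]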
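Hyperbolic space is particularly well suited to data with hierarchical or tree-like structure: the volume of a hyperbolic ball grows exponentially with its radius, much as the number of nodes in a tree grows exponentially with depth, so trees can be embedded with low distortion. In the Poincaré ball this shows up as distances growing rapidly toward the boundary.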
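The Poincaré-ball operations can be checked numerically. The following is a minimal NumPy sketch at curvature \(-1\) (matching \(\lambda_x = 2/(1 - \|x\|^2)\)); it is not prosemble's implementation, and the helper names are hypothetical.

```python
import numpy as np


def mobius_add(x, y):
    """Mobius addition x (+) y on the unit Poincare ball (curvature -1)."""
    xy = np.dot(x, y)
    x2, y2 = np.dot(x, x), np.dot(y, y)
    num = (1 + 2 * xy + y2) * x + (1 - x2) * y
    den = 1 + 2 * xy + x2 * y2
    return num / den


def conformal_factor(x):
    """lambda_x = 2 / (1 - |x|^2)."""
    return 2.0 / (1.0 - np.dot(x, x))


def poincare_dist(x, y):
    """Geodesic distance arcosh(1 + 2|x - y|^2 / ((1 - |x|^2)(1 - |y|^2)))."""
    diff2 = np.dot(x - y, x - y)
    denom = (1 - np.dot(x, x)) * (1 - np.dot(y, y))
    return np.arccosh(1 + 2 * diff2 / denom)


def exp_map(x, v, eps=1e-12):
    """Exp_x(v) = x (+) (tanh(lambda_x |v| / 2) v / |v|)."""
    nv = np.linalg.norm(v)
    if nv < eps:
        return x.copy()
    return mobius_add(x, np.tanh(conformal_factor(x) * nv / 2) * v / nv)


def log_map(x, y, eps=1e-12):
    """Log_x(y) = (2 / lambda_x) artanh(|u|) u / |u| with u = (-x) (+) y."""
    u = mobius_add(-x, y)
    nu = np.linalg.norm(u)
    if nu < eps:
        return np.zeros_like(x)
    return (2.0 / conformal_factor(x)) * np.arctanh(nu) * u / nu


x = np.array([0.1, 0.2, 0.0, -0.1])
y = np.array([-0.3, 0.1, 0.4, 0.2])
v = log_map(x, y)  # tangent vector at x pointing toward y
print(poincare_dist(x, y), np.allclose(exp_map(x, v), y))
```

Since \(\lambda_x \|\text{Log}_x(y)\| = 2\,\operatorname{artanh}\|(-x) \oplus y\|\), the Riemannian length of the log-map vector equals the geodesic distance, which makes a convenient sanity check for any implementation.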
```python
from prosemble.models import RiemannianSRNG
from prosemble.core.manifolds import HyperbolicPoincare
import jax
import jax.numpy as jnp

manifold = HyperbolicPoincare(4)

# Generate data inside the Poincare ball
key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (40, 4)) * 0.3
X = jax.vmap(manifold.project)(X)  # ensure points lie in the ball
X = X.reshape(40, -1)              # (n_samples, 4); points are already vectors
y = jnp.array([0] * 20 + [1] * 20)

model = RiemannianSRNG(
    manifold=manifold,
    n_prototypes_per_class=2,
    max_iter=50,
    lr=0.01,
)
model.fit(X, y)
labels = model.predict(X)
```
All supervised Riemannian models (`RiemannianSRNG`, `RiemannianSMNG`, `RiemannianSLNG`, `RiemannianSTNG`) and the unsupervised `RiemannianNeuralGas` work with `HyperbolicPoincare` without any code changes.
## RiemannianSRNG
Supervised Riemannian Neural Gas. Combines GLVQ-style margin-based classification with Neural Gas neighbourhood cooperation on the manifold.
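The loss for each sample uses all same-class prototypes, rank-weighted by \(h(k) = \exp(-k / \gamma)\):

\[
\ell(x) = \frac{1}{C(\gamma)} \sum_{k=0}^{K^+ - 1} h(k)\, \frac{d^+_k - d^-}{d^+_k + d^-},
\qquad
C(\gamma) = \sum_{k=0}^{K^+ - 1} h(k)
\]

where \(d^+_k\) is the geodesic distance from \(x\) to its same-class prototype of rank \(k\) (so \(d^+_0\) is the nearest same-class distance), \(K^+\) is the number of same-class prototypes, and \(d^-\) is the geodesic distance to the nearest different-class prototype.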
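A minimal NumPy sketch of the rank-weighted per-sample cost, assuming the weights are normalised by their sum and the GLVQ transfer function is the identity (prosemble may apply a sigmoid or another transfer function). The function name is hypothetical.

```python
import numpy as np


def srng_sample_cost(d_same, d_diff, gamma):
    """Rank-weighted SRNG cost for one sample (identity transfer function ASSUMED).

    d_same: geodesic distances to all same-class prototypes
    d_diff: geodesic distances to all different-class prototypes
    """
    d_plus = np.sort(np.asarray(d_same, dtype=float))  # rank k = 0 is the nearest
    d_minus = np.min(d_diff)                            # nearest wrong-class distance
    k = np.arange(len(d_plus))
    h = np.exp(-k / gamma)                              # neighbourhood weights h(k)
    mu = (d_plus - d_minus) / (d_plus + d_minus)        # GLVQ relative distances in [-1, 1]
    return float(np.sum(h * mu) / np.sum(h))            # normalised by the sum of weights


# Two same-class prototypes at distances 0.5 and 1.5, wrong-class ones at 1.0 and 2.0
print(srng_sample_cost([0.5, 1.5], [1.0, 2.0], gamma=1.0))
```

As \(\gamma\) shrinks, every weight except \(h(0)\) vanishes and the cost reduces to the plain GLVQ term for the nearest same-class prototype; a large \(\gamma\) lets all same-class prototypes share the gradient, which is the Neural Gas cooperation.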
```python
from prosemble.models import RiemannianSRNG
from prosemble.core.manifolds import SO
import jax
import jax.numpy as jnp

# Generate synthetic rotation data
key = jax.random.PRNGKey(0)
manifold = SO(3)
X = jnp.stack([manifold.random_point(jax.random.fold_in(key, i))
               for i in range(40)])
X = X.reshape(40, -1)  # flatten to (n_samples, 9)
y = jnp.array([0] * 20 + [1] * 20)

model = RiemannianSRNG(
    manifold=manifold,
    n_prototypes_per_class=2,
    max_iter=50,
    lr=0.01,
    gamma_final=0.01,
)
model.fit(X, y)
labels = model.predict(X)
```
## RiemannianSMNG
Supervised Matrix Neural Gas on manifolds. Adds a global metric adaptation matrix \(\Omega\) that operates in the tangent space at each prototype:

\[
d_\Omega(x, w_k) = \big\|\Omega\, \text{Log}_{w_k}(x)\big\|^2
\]

where \(\text{Log}_{w_k}\) is the Riemannian logarithmic map (the tangent vector at \(w_k\) pointing toward \(x\)).
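As an illustration of the tangent-space metric, the sketch below computes \(\|\Omega\,\text{Log}_{W}(X)\|^2\) on SPD matrices with NumPy. It assumes the affine-invariant metric for the log map and a flattened \(n^2\)-dimensional tangent vector; prosemble's `SPD` manifold may use a different metric or coordinates, and all function names here are hypothetical.

```python
import numpy as np


def _sym_fn(S, fn):
    """Apply a scalar function to a symmetric matrix via its eigendecomposition."""
    w, V = np.linalg.eigh(S)
    return (V * fn(w)) @ V.T  # V diag(fn(w)) V^T


def spd_log(W, X):
    """Log_W(X) under the affine-invariant metric (ASSUMED):
    W^{1/2} logm(W^{-1/2} X W^{-1/2}) W^{1/2}."""
    W_half = _sym_fn(W, np.sqrt)
    W_ihalf = _sym_fn(W, lambda w: 1.0 / np.sqrt(w))
    inner = W_ihalf @ X @ W_ihalf
    inner = (inner + inner.T) / 2  # symmetrise against round-off
    return W_half @ _sym_fn(inner, np.log) @ W_half


def omega_distance(Omega, W, X):
    """||Omega vec(Log_W(X))||^2 with one global Omega (SMNG-style)."""
    v = spd_log(W, X).ravel()
    z = Omega @ v
    return float(z @ z)


rng = np.random.default_rng(0)
A, B = rng.normal(size=(2, 3, 3))
W = A @ A.T + 0.5 * np.eye(3)    # prototype (SPD)
X = B @ B.T + 0.5 * np.eye(3)    # sample (SPD)
Omega = rng.normal(size=(4, 9))  # hypothetical: 4 latent rows x 9 flattened tangent entries

print(omega_distance(Omega, W, X))
```

Writing the same quantity as \(v^\top \Omega^\top \Omega\, v\) with \(v = \text{vec}(\text{Log}_W(X))\) shows why \(\Omega^\top \Omega\) acts as a relevance matrix: it reweights and couples tangent directions, and a rectangular \(\Omega\) projects them onto a lower-dimensional space.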
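```python
from prosemble.models import RiemannianSMNG
from prosemble.core.manifolds import SPD
import jax
import jax.numpy as jnp

manifold = SPD(3)

# Synthetic SPD data: A A^T + 0.1 I is symmetric positive definite
key = jax.random.PRNGKey(0)
A = jax.random.normal(key, (40, 3, 3))
X = A @ jnp.swapaxes(A, -1, -2) + 0.1 * jnp.eye(3)
X = X.reshape(40, -1)  # flatten to (n_samples, 9)
y = jnp.array([0] * 20 + [1] * 20)

model = RiemannianSMNG(
    manifold=manifold,
    n_prototypes_per_class=2,
    latent_dim=4,
    max_iter=50,
    lr=0.01,
)
model.fit(X, y)

# Learned relevance matrix
print(model.relevance_matrix().shape)
```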
## RiemannianSLNG
Supervised Localized Matrix Neural Gas. Each prototype \(w_k\) has its own metric matrix \(\Omega_k\):

\[
d_k(x) = \big\|\Omega_k\, \text{Log}_{w_k}(x)\big\|^2
\]
This allows different prototypes to focus on different tangent directions, useful when the discriminative structure varies across the manifold.
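A minimal NumPy sketch of per-prototype metrics on the Grassmannian. It assumes points are stored as \(n \times k\) orthonormal bases and uses the standard SVD-based log map for that representation; the helper names are hypothetical and this is not prosemble's code.

```python
import numpy as np


def grassmann_log(Y, X):
    """Log_Y(X) for points given as n x k orthonormal bases (representation ASSUMED).

    M = (I - Y Y^T) X (Y^T X)^{-1},  M = U S V^T,  Log = U arctan(S) V^T.
    Requires Y^T X to be invertible (no principal angle equal to pi/2).
    """
    n = Y.shape[0]
    M = (np.eye(n) - Y @ Y.T) @ X @ np.linalg.inv(Y.T @ X)
    U, S, Vt = np.linalg.svd(M, full_matrices=False)
    return U @ np.diag(np.arctan(S)) @ Vt


def local_distances(Omegas, prototypes, X):
    """d_k = ||Omega_k vec(Log_{w_k}(X))||^2, one matrix Omega_k per prototype."""
    return np.array([
        np.sum((Om @ grassmann_log(W, X).ravel()) ** 2)
        for Om, W in zip(Omegas, prototypes)
    ])


# Lines in R^3 (n = 3, k = 1) separated by a principal angle theta
theta = 0.3
Y = np.array([[1.0], [0.0], [0.0]])
X = np.array([[np.cos(theta)], [np.sin(theta)], [0.0]])

print(grassmann_log(Y, X).ravel())  # tangent vector of length theta

# Two prototypes at the same point: the first metric ignores the 2nd tangent
# coordinate, the second uses all of them, so only the second "sees" X
print(local_distances([np.diag([1.0, 0.0, 1.0]), np.eye(3)], [Y, Y], X))
```

The last call illustrates the point of local metrics: two prototypes at the same location can disagree about how far a sample is, because each \(\Omega_k\) weights tangent directions differently.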
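```python
from prosemble.models import RiemannianSLNG
from prosemble.core.manifolds import Grassmannian
import jax
import jax.numpy as jnp

manifold = Grassmannian(5, 2)
# Synthetic subspace data; assumes Grassmannian provides random_point like SO(3)
key = jax.random.PRNGKey(0)
X = jnp.stack([manifold.random_point(jax.random.fold_in(key, i))
               for i in range(40)])
X = X.reshape(40, -1)  # flatten each point to one row
y = jnp.array([0] * 20 + [1] * 20)

model = RiemannianSLNG(
    manifold=manifold,
    n_prototypes_per_class=2,
    latent_dim=3,
    max_iter=50,
    lr=0.01,
)
model.fit(X, y)
```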
## RiemannianSTNG
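Supervised Tangent Neural Gas. Each prototype has a tangent subspace \(\Omega_k\), and distance is measured in the complement of that subspace (the residual after projection):

\[
d_k(x) = \big\|(I - \Omega_k \Omega_k^\top)\, \text{Log}_{w_k}(x)\big\|^2
\]

where the columns of \(\Omega_k\) form an orthonormal basis of the tangent subspace. This captures invariance structure: directions spanned by \(\Omega_k\) are ignored, so variation of \(x\) along them does not change the distance.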
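The residual distance can be sketched on SO(3) with NumPy. Tangent vectors are taken as axis-angle vectors of \(W^\top X\) (body coordinates), which is an assumption; prosemble may instead use flattened \(3 \times 3\) tangent matrices. Function names are hypothetical.

```python
import numpy as np


def so3_log_vec(W, X, eps=1e-10):
    """Axis-angle vector of W^T X, used here as the tangent vector at W (ASSUMED coordinates).

    Rodrigues formula; valid for rotation angles below pi.
    """
    R = W.T @ X
    cos_t = np.clip((np.trace(R) - 1.0) / 2.0, -1.0, 1.0)
    theta = np.arccos(cos_t)
    # log(R) = theta / (2 sin theta) (R - R^T); the factor tends to 1/2 as theta -> 0
    coef = 0.5 if theta < eps else theta / (2.0 * np.sin(theta))
    K = coef * (R - R.T)  # skew-symmetric generator
    return np.array([K[2, 1], K[0, 2], K[1, 0]])


def tangent_residual_distance(B, v):
    """||(I - B B^T) v||^2 for an orthonormal basis B (columns) of the invariant subspace."""
    r = v - B @ (B.T @ v)
    return float(r @ r)


def rot_z(a):
    """Rotation by angle a about the z-axis."""
    c, s = np.cos(a), np.sin(a)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


W = np.eye(3)          # prototype
X = rot_z(0.5)         # sample: W rotated by 0.5 rad about z
v = so3_log_vec(W, X)  # approximately (0, 0, 0.5)

# A 2-D subspace (subspace_dim = 2) orthonormalised with QR; it contains the z-axis
B, _ = np.linalg.qr(np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]))

print(tangent_residual_distance(np.eye(3)[:, :2], v))  # x/y subspace: z-rotation is not ignored
print(tangent_residual_distance(B, v))                 # subspace contains z: rotation is ignored
```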
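```python
from prosemble.models import RiemannianSTNG
from prosemble.core.manifolds import SO
import jax
import jax.numpy as jnp

manifold = SO(3)
# Synthetic rotation data, generated as in the RiemannianSRNG example
key = jax.random.PRNGKey(0)
X = jnp.stack([manifold.random_point(jax.random.fold_in(key, i))
               for i in range(40)])
X = X.reshape(40, -1)  # flatten to (n_samples, 9)
y = jnp.array([0] * 20 + [1] * 20)

model = RiemannianSTNG(
    manifold=manifold,
    n_prototypes_per_class=2,
    subspace_dim=2,
    max_iter=50,
    lr=0.01,
)
model.fit(X, y)
```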
## RiemannianNeuralGas
Unsupervised Neural Gas on Riemannian manifolds. Distributes prototypes to match the data density using rank-based cooperation with geodesic distances and exponential map updates.
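One update of this scheme can be sketched for SO(3) in NumPy: rank the prototypes by geodesic distance to a sample, weight each by \(\exp(-\text{rank}/\lambda)\), and move it toward the sample with the exponential map. This is an illustration, not prosemble's training loop; `lam` is a hypothetical stand-in for the neighbourhood range, and all names are hypothetical.

```python
import numpy as np


def hat(w):
    """Skew-symmetric matrix of a 3-vector."""
    return np.array([[0.0, -w[2], w[1]], [w[2], 0.0, -w[0]], [-w[1], w[0], 0.0]])


def so3_exp(w):
    """Rodrigues formula: exp(hat(w)) for an axis-angle vector w."""
    theta = np.linalg.norm(w)
    if theta < 1e-10:
        return np.eye(3) + hat(w)
    K = hat(w / theta)
    return np.eye(3) + np.sin(theta) * K + (1 - np.cos(theta)) * K @ K


def so3_log(R):
    """Inverse of so3_exp for rotation angles below pi."""
    cos_t = np.clip((np.trace(R) - 1) / 2, -1.0, 1.0)
    theta = np.arccos(cos_t)
    coef = 0.5 if theta < 1e-10 else theta / (2 * np.sin(theta))
    K = coef * (R - R.T)
    return np.array([K[2, 1], K[0, 2], K[1, 0]])


def ng_step(prototypes, x, lr, lam):
    """One Riemannian Neural Gas update for a single sample x.

    W_k <- W_k exp(lr * h_k * log(W_k^T x)),  h_k = exp(-rank_k / lam)
    """
    logs = [so3_log(W.T @ x) for W in prototypes]
    dists = np.array([np.linalg.norm(v) for v in logs])  # rotation angles (geodesic distances)
    ranks = np.argsort(np.argsort(dists))                # rank 0 = closest prototype
    h = np.exp(-ranks / lam)                             # neighbourhood weights
    return [W @ so3_exp(lr * hk * v) for W, hk, v in zip(prototypes, h, logs)]


ws = np.array([[0.3, 0, 0], [0, 0.3, 0], [0, 0, 0.3], [0.2, 0.2, 0], [-0.2, 0, 0.1]])
prototypes = [so3_exp(w) for w in ws]
x = so3_exp(np.array([0.2, -0.1, 0.3]))

new_protos = ng_step(prototypes, x, lr=0.3, lam=1.0)
old_d = np.array([np.linalg.norm(so3_log(W.T @ x)) for W in prototypes])
new_d = np.array([np.linalg.norm(so3_log(W.T @ x)) for W in new_protos])
print(old_d, new_d)  # every prototype moves closer; the closest moves most
```

Because each step multiplies a rotation by another rotation, the updated prototypes stay on SO(3) without a separate projection.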
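```python
from prosemble.models import RiemannianNeuralGas
from prosemble.core.manifolds import SO
import jax
import jax.numpy as jnp

manifold = SO(3)
# Synthetic rotation data, generated as in the RiemannianSRNG example
key = jax.random.PRNGKey(0)
X = jnp.stack([manifold.random_point(jax.random.fold_in(key, i))
               for i in range(40)])
X = X.reshape(40, -1)  # flatten to (n_samples, 9)

model = RiemannianNeuralGas(
    manifold=manifold,
    n_prototypes=5,
    max_iter=50,
    lr_init=0.3,
    lr_final=0.01,
    lambda_final=0.01,
)
model.fit(X)
labels = model.predict(X)       # nearest prototype assignment
distances = model.transform(X)  # geodesic distance matrix
```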
Note: All Riemannian models use Python loops (not `lax.scan`) because manifold projection after each gradient step is not compatible with JAX's functional loop primitives.
## Choosing a Model
| Model | Metric | Best For |
|---|---|---|
| RiemannianSRNG | Geodesic distance | General manifold classification |
| RiemannianSMNG | Global \(\Omega\) in tangent space | Feature selection on manifolds |
| RiemannianSLNG | Per-prototype \(\Omega_k\) | Heterogeneous tangent structure |
| RiemannianSTNG | Tangent subspace projection | Invariance learning on manifolds |
| RiemannianNeuralGas | Geodesic distance | Unsupervised manifold clustering |
| Manifold | Geometry | Use When |
|---|---|---|
| SO(n) | Rotation matrices | Pose estimation, robotics, structural biology |
| SPD(n) | Positive definite matrices | Covariance matrices (EEG/BCI, DTI) |
| Grassmannian(n, k) | Subspaces of \(\mathbb{R}^n\) | Video/image subspace analysis |
| HyperbolicPoincare(d) | Poincaré ball \(\mathbb{B}^d\) | Hierarchical/tree-structured data |