Source code for prosemble.core.geodesic

"""Geodesic interpolation and boundary visualization for Riemannian models.

Computes geodesic paths between prototypes on Riemannian manifolds for
decision boundary visualization and model interpretation.

On curved manifolds, the decision boundary (the locus where d(x, w_i) = d(x, w_j))
is in general not a hyperplane but a curved hypersurface. Geodesic interpolation
allows us to:
1. Visualize the path between prototypes along the manifold.
2. Find the approximate decision boundary point along geodesics.
3. Compute prototype midpoints respecting manifold geometry.

References
----------
.. [1] Absil, P.-A., Mahony, R., & Sepulchre, R. (2008). Optimization
       Algorithms on Matrix Manifolds. Princeton University Press.
"""

import jax.numpy as jnp


def geodesic_interpolation(manifold, point_a, point_b, n_points=50):
    """Sample the geodesic running from ``point_a`` to ``point_b``.

    The curve is parametrized through the manifold's exponential and
    logarithmic maps and evaluated on ``n_points`` evenly spaced values of t:

    .. math::

        \\gamma(t) = \\text{Exp}_{w_a}(t \\cdot \\text{Log}_{w_a}(w_b)),
        \\quad t \\in [0, 1]

    Parameters
    ----------
    manifold : SO, SPD, or Grassmannian
        Riemannian manifold instance providing ``exp_map`` and ``log_map``.
    point_a : array
        Starting point on the manifold (t = 0).
    point_b : array
        End point on the manifold (t = 1).
    n_points : int
        Number of samples along the curve, endpoints included. Default: 50.

    Returns
    -------
    path : array of shape (n_points, ...)
        The sampled points stacked along a new leading axis, in order of
        increasing t. Each sample is ``manifold.exp_map`` applied at
        ``point_a``.
    """
    # Tangent vector at point_a that points toward point_b.
    direction = manifold.log_map(point_a, point_b)
    steps = jnp.linspace(0.0, 1.0, n_points)
    samples = [manifold.exp_map(point_a, step * direction) for step in steps]
    return jnp.stack(samples, axis=0)
def geodesic_midpoint(manifold, point_a, point_b):
    """Return the point halfway along the geodesic from ``point_a`` to ``point_b``.

    The midpoint is the geodesic evaluated at t = 0.5:

    .. math::

        m = \\text{Exp}_{w_a}(0.5 \\cdot \\text{Log}_{w_a}(w_b))

    Parameters
    ----------
    manifold : SO, SPD, or Grassmannian
        Riemannian manifold instance providing ``exp_map`` and ``log_map``.
    point_a : array
        First manifold point (base point of the exponential map).
    point_b : array
        Second manifold point.

    Returns
    -------
    midpoint : array
        The result of ``manifold.exp_map`` at ``point_a``. It lies on the
        manifold as long as ``exp_map`` maps tangent vectors to manifold
        points.
    """
    half_step = 0.5 * manifold.log_map(point_a, point_b)
    return manifold.exp_map(point_a, half_step)
def decision_boundary_point(manifold, proto_a, proto_b, n_search=100):
    """Find the decision boundary point along the geodesic between prototypes.

    Searches the curve
    gamma(t) = Exp_{proto_a}(t * Log_{proto_a}(proto_b)), t in [0, 1],
    for the point where d(gamma(t), proto_a) = d(gamma(t), proto_b).

    The signed difference
    f(t) = d(gamma(t), proto_a) - d(gamma(t), proto_b)
    is evaluated on a uniform grid of ``n_search`` values of t. If f changes
    sign in a grid cell adjacent to the grid point with the smallest |f|,
    the zero crossing is located by linear interpolation inside that cell.
    The result is therefore not limited to the grid spacing; note that with
    an even ``n_search`` the grid never contains t = 0.5. Otherwise the best
    grid point itself is returned.

    If ``manifold.distance`` equals the arc length along this curve and the
    curve is a minimizing geodesic, then f(t) = (2t - 1) * d(proto_a, proto_b)
    and the crossing is exactly t = 0.5. It can deviate when ``distance``
    uses a different metric than ``exp_map``/``log_map``, or when the
    geodesic is not minimizing. The search handles those cases.

    Parameters
    ----------
    manifold : SO, SPD, or Grassmannian
        Riemannian manifold instance providing ``exp_map``, ``log_map`` and
        ``distance``.
    proto_a : array
        First prototype (on manifold).
    proto_b : array
        Second prototype (on manifold).
    n_search : int
        Number of grid points evaluated along the geodesic. Default: 100.

    Returns
    -------
    boundary_point : array
        Point on the geodesic at ``t_boundary``.
    t_boundary : float
        Parameter value t in [0, 1] of the boundary point.

    Raises
    ------
    ValueError
        If ``n_search < 1``.
    """
    if n_search < 1:
        raise ValueError(f"n_search must be a positive integer, got {n_search}")

    tangent = manifold.log_map(proto_a, proto_b)
    t_values = jnp.linspace(0.0, 1.0, n_search)

    def signed_diff(t):
        # f(t) = d(gamma(t), proto_a) - d(gamma(t), proto_b); the sign tells
        # on which side of the boundary gamma(t) lies.
        point = manifold.exp_map(proto_a, t * tangent)
        return manifold.distance(point, proto_a) - manifold.distance(point, proto_b)

    diffs = jnp.array([signed_diff(t) for t in t_values])
    best_idx = int(jnp.argmin(jnp.abs(diffs)))
    t_boundary = float(t_values[best_idx])

    # Refine beyond the grid: look for a sign change of f in the cells on
    # either side of best_idx and place the zero of the linear interpolant.
    for lo in (best_idx - 1, best_idx):
        hi = lo + 1
        if lo < 0 or hi >= n_search:
            continue
        f_lo, f_hi = float(diffs[lo]), float(diffs[hi])
        if min(f_lo, f_hi) <= 0.0 <= max(f_lo, f_hi):
            # f_lo == f_hi only when both are zero (e.g. identical prototypes);
            # keep the grid point in that case.
            if f_hi != f_lo:
                t_lo, t_hi = float(t_values[lo]), float(t_values[hi])
                t_boundary = t_lo + (t_hi - t_lo) * f_lo / (f_lo - f_hi)
            break

    boundary_point = manifold.exp_map(proto_a, t_boundary * tangent)
    return boundary_point, t_boundary
def prototype_geodesic_distances(manifold, prototypes, proto_labels):
    """Compute pairwise geodesic distances between prototypes.

    Each unordered pair (i, j) with i < j is evaluated once with
    ``manifold.distance`` and mirrored to (j, i). The result is therefore
    symmetric, and its diagonal is left at zero (never evaluated).

    Parameters
    ----------
    manifold : SO, SPD, or Grassmannian
        Riemannian manifold instance providing ``distance``.
    prototypes : array of shape (n_prototypes, ...)
        Prototype points on the manifold, indexed along the first axis.
    proto_labels : array of shape (n_prototypes,)
        Class labels for prototypes. Not used by this function; kept in the
        signature for compatibility.

    Returns
    -------
    distances : array of shape (n_prototypes, n_prototypes)
        Pairwise geodesic distance matrix. Its dtype is that of
        ``jnp.zeros``; computed distances are cast into it.
    """
    n = prototypes.shape[0]
    distances = jnp.zeros((n, n))
    if n < 2:
        return distances

    # Evaluate the upper triangle in row-major (i < j) order, then write it
    # with a single scatter per triangle. A per-pair ``.at[].set`` outside
    # jit would copy the whole n x n array on every call.
    rows, cols = jnp.triu_indices(n, k=1)
    values = jnp.asarray([
        manifold.distance(prototypes[i], prototypes[j])
        for i, j in zip(rows.tolist(), cols.tolist())
    ])
    distances = distances.at[rows, cols].set(values)
    return distances.at[cols, rows].set(values)
def inter_class_geodesics(manifold, prototypes, proto_labels, n_points=50):
    """Trace geodesics between every pair of prototypes with different labels.

    Pairs (i, j) with i < j and ``proto_labels[i] != proto_labels[j]`` are
    visited in ascending order. For each pair, the geodesic path is sampled
    with ``geodesic_interpolation``. The boundary parameter is then obtained
    from ``decision_boundary_point`` with its default search resolution.
    Useful for visualizing decision boundaries between classes.

    Parameters
    ----------
    manifold : SO, SPD, or Grassmannian
        Riemannian manifold instance.
    prototypes : array of shape (n_prototypes, ...)
        Prototype positions on the manifold, indexed along the first axis.
    proto_labels : array of shape (n_prototypes,)
        Class labels; converted with ``numpy.asarray``.
    n_points : int
        Number of samples per geodesic path. Default: 50.

    Returns
    -------
    geodesics : list of dict
        One dict per inter-class pair, with keys:

        - 'path': array of shape (n_points, ...), the sampled geodesic
        - 'proto_a_idx': int, index of the first prototype
        - 'proto_b_idx': int, index of the second prototype
        - 'class_a': int, label of the first prototype
        - 'class_b': int, label of the second prototype
        - 'boundary_t': float, parameter t of the approximate boundary point
    """
    import numpy as np

    labels = np.asarray(proto_labels)
    count = prototypes.shape[0]

    # Index pairs (first < second) whose prototypes carry different labels.
    cross_pairs = [
        (first, second)
        for first in range(count)
        for second in range(first + 1, count)
        if labels[first] != labels[second]
    ]

    records = []
    for first, second in cross_pairs:
        start, end = prototypes[first], prototypes[second]
        sampled = geodesic_interpolation(manifold, start, end, n_points)
        boundary_t = decision_boundary_point(manifold, start, end)[1]
        records.append(
            {
                'path': sampled,
                'proto_a_idx': first,
                'proto_b_idx': second,
                'class_a': int(labels[first]),
                'class_b': int(labels[second]),
                'boundary_t': boundary_t,
            }
        )
    return records