"""
Structural typing protocols for prosemble interfaces.

Defines ``typing.Protocol`` contracts for the duck-typed interfaces used
across the library. These enable static type checking (mypy / pyright)
and IDE auto-completion without requiring inheritance.

.. note::

   Runtime-checkable protocols (``Manifold``, ``CallbackLike``) support
   ``isinstance()`` checks; such checks only verify that the required
   members are present, not their signatures. Type aliases
   (``DistanceMatrixFn``, etc.) are for annotation only.
"""
from __future__ import annotations
from typing import (
Any,
Callable,
Protocol,
Tuple,
runtime_checkable,
)
import jax
import jax.numpy as jnp
# ---------------------------------------------------------------------------
# Manifold protocol
# ---------------------------------------------------------------------------
@runtime_checkable
class Manifold(Protocol):
    """Protocol for Riemannian manifold implementations.

    Any object exposing the members below can be used wherever a manifold
    is expected (e.g. ``RiemannianNeuralGas``, ``RiemannianSRNG``).
    The concrete implementations :class:`~prosemble.core.manifolds.SO`,
    :class:`~prosemble.core.manifolds.SPD`, and
    :class:`~prosemble.core.manifolds.Grassmannian` all satisfy this
    protocol structurally — no explicit subclassing is required.

    Because the protocol is decorated with ``runtime_checkable``,
    ``isinstance(obj, Manifold)`` works at runtime. However, it only checks
    that every member below exists on ``obj``. Argument lists and return
    types are verified by static type checkers alone.
    """

    @property
    def point_shape(self) -> Tuple[int, ...]:
        """Shape of a single point on the manifold."""
        ...

    def distance(self, p: jnp.ndarray, q: jnp.ndarray) -> jnp.ndarray:
        """Geodesic distance between two points.

        Parameters
        ----------
        p, q : arrays of shape ``point_shape``

        Returns
        -------
        scalar
        """
        ...

    def distance_squared(self, p: jnp.ndarray, q: jnp.ndarray) -> jnp.ndarray:
        """Squared geodesic distance between two points.

        Parameters
        ----------
        p, q : arrays of shape ``point_shape``

        Returns
        -------
        scalar
        """
        ...

    def log_map(self, base: jnp.ndarray, target: jnp.ndarray) -> jnp.ndarray:
        """Logarithmic map: tangent vector at *base* pointing toward *target*.

        Parameters
        ----------
        base : array of shape ``point_shape``
            Base point on the manifold.
        target : array of shape ``point_shape``
            Target point on the manifold.

        Returns
        -------
        tangent : array of shape ``point_shape``
            Tangent vector in :math:`T_{\\text{base}} M`.
        """
        ...

    def exp_map(self, base: jnp.ndarray, tangent: jnp.ndarray) -> jnp.ndarray:
        """Exponential map: move along *tangent* from *base* back to the manifold.

        Parameters
        ----------
        base : array of shape ``point_shape``
        tangent : array of shape ``point_shape``

        Returns
        -------
        point : array of shape ``point_shape``
        """
        ...

    def random_point(self, key: jax.Array) -> jnp.ndarray:
        """Sample a random point on the manifold.

        Parameters
        ----------
        key : JAX PRNG key

        Returns
        -------
        point : array of shape ``point_shape``
        """
        ...

    def belongs(self, point: jnp.ndarray) -> jnp.ndarray:
        """Check whether *point* lies on the manifold.

        Parameters
        ----------
        point : array of shape ``point_shape``

        Returns
        -------
        bool or bool-valued array
        """
        ...

    def project(self, point: jnp.ndarray) -> jnp.ndarray:
        """Project an off-manifold point to the nearest point on the manifold.

        Parameters
        ----------
        point : array of shape ``point_shape``

        Returns
        -------
        projected : array of shape ``point_shape``
        """
        ...

    def injectivity_radius(self, point: jnp.ndarray) -> float:
        """Injectivity radius at *point*.

        The maximum geodesic distance for which the logarithmic map
        is injective.

        Parameters
        ----------
        point : array of shape ``point_shape``

        Returns
        -------
        radius : float or scalar array
        """
        ...
# ---------------------------------------------------------------------------
# Callback protocol
# ---------------------------------------------------------------------------
@runtime_checkable
class CallbackLike(Protocol):
    """Protocol for training callbacks.

    Any object with the three hook methods below can be passed in the
    ``callbacks`` list of a model's constructor. The existing
    :class:`~prosemble.core.callbacks.Callback` base class already
    satisfies this protocol.

    Because the protocol is ``runtime_checkable``,
    ``isinstance(obj, CallbackLike)`` can be used at runtime. It only
    checks that all three hook methods exist, not their signatures.
    """

    def on_fit_start(self, model: Any, X: jnp.ndarray) -> None:
        """Called once before training begins.

        Parameters
        ----------
        model : Any
            The model invoking the hook.
        X : array
            Input data handed to the hook at the start of fitting.
        """
        ...

    def on_iteration_end(self, model: Any, info: dict) -> None:
        """Called after each training iteration / epoch.

        Parameters
        ----------
        model : Any
            The model invoking the hook.
        info : dict
            Training state supplied by the model; the available keys
            depend on the calling model and are not specified here.
        """
        ...

    def on_fit_end(self, model: Any, info: dict) -> None:
        """Called once after training ends.

        Parameters
        ----------
        model : Any
            The model invoking the hook.
        info : dict
            Final training state supplied by the model; the available
            keys depend on the calling model and are not specified here.
        """
        ...
# ---------------------------------------------------------------------------
# Type aliases for callable interfaces
# ---------------------------------------------------------------------------
# Every alias below is phrased in terms of JAX arrays; bind that type once
# under a short private name so the signatures stay readable. The helper is
# deleted at the end so it never becomes part of the module namespace.
_Array = jnp.ndarray

#: Distance-matrix function: ``(X, Y) -> distances``.
#: ``X`` has shape ``(n_samples, n_features)``,
#: ``Y`` has shape ``(n_prototypes, n_features)``,
#: result has shape ``(n_samples, n_prototypes)``.
DistanceMatrixFn = Callable[
    [_Array, _Array],
    _Array,
]

#: Pairwise distance function: ``(x, y) -> scalar``.
#: The scalar is returned as a JAX array (see the return type).
DistancePairwiseFn = Callable[
    [_Array, _Array],
    _Array,
]

#: Supervised prototype initializer:
#: ``(X, y, n_per_class, key) -> (prototypes, prototype_labels)``.
SupervisedInitFn = Callable[
    [_Array, _Array, int, jax.Array],
    Tuple[_Array, _Array],
]

#: Unsupervised prototype initializer:
#: ``(X, n_prototypes, key) -> prototypes``.
UnsupervisedInitFn = Callable[[_Array, int, jax.Array], _Array]

del _Array