# Source code for prosemble.models.fpcm

"""
JAX-based Fuzzy Possibilistic C-Means (FPCM) clustering implementation.

This module provides a GPU-accelerated implementation of FPCM using JAX
with JIT compilation for high performance.
"""

from typing import NamedTuple, Self
from functools import partial

import jax
import jax.numpy as jnp
import chex
from jax import jit

from prosemble.models.base import FuzzyClusteringBase, ScanFitMixin
from prosemble.models.fcm import FCM


class FPCMState(NamedTuple):
    """Immutable snapshot of one FPCM iteration.

    A new instance is built for every iteration (see
    ``FPCM._iteration_step``) instead of mutating the previous one.

    Attributes:
        centroids: Cluster centroids, shape (n_clusters, n_features)
        U: Fuzzy membership matrix, shape (n_samples, n_clusters)
        T: Possibilistic typicality matrix, shape (n_samples, n_clusters)
        objective: Objective function value for (U, T, centroids)
        iteration: Number of iterations performed so far (0 for the
            initial state built in ``FPCM.fit``)
        converged: Whether the convergence test passed on the last step
    """
    centroids: chex.Array
    U: chex.Array
    T: chex.Array
    objective: chex.Array
    # Starts as the Python int 0 in ``FPCM.fit``; ``_iteration_step`` stores
    # ``state.iteration + 1``, which is a traced value when run under jit.
    iteration: int
    # Starts as the Python bool False in ``FPCM.fit``; ``_iteration_step``
    # stores the result of a JAX comparison, i.e. a JAX boolean scalar.
    converged: bool


class FPCM(ScanFitMixin, FuzzyClusteringBase):
    """
    Fuzzy Possibilistic C-Means clustering with JAX.

    FPCM maintains TWO matrices: :math:`U` (fuzzy membership) and :math:`T`
    (typicality). :math:`U` has row-sum-to-1 constraint (standard FCM), while
    :math:`T` has column-sum-to-1 constraint per the original Pal, Pal &
    Bezdek (1997) formulation.

    Algorithm:

    1. Initialize :math:`U` and :math:`T` (randomly or using FCM)
    2. Update centroids using combined fuzzy and typicality weights
    3. Update :math:`U` using FCM rule with fuzzifier :math:`m` (row-normalized)
    4. Update :math:`T` with column-normalization
    5. Repeat until convergence

    Objective function:

    .. math::

        J = \\sum_i \\sum_j \\left[u_{ij}^m + t_{ij}^\\eta\\right] \\|x_i - v_j\\|^2

    Reference:
        Pal, N. R., Pal, K., & Bezdek, J. C. (1997).
        A mixed c-means clustering model. FUZZ-IEEE.

    Parameters
    ----------
    fuzzifier : float, default=2.0
        Fuzziness parameter for :math:`U` matrix (must be > 1.0).
    eta : float, default=2.0
        Fuzziness parameter for :math:`T` matrix (must be > 1.0).
    init_method : {'random', 'fcm'}, default='fcm'
        Method for initializing :math:`U` and :math:`T` matrices.
    n_clusters : int
        Number of clusters (must be >= 2).
    max_iter : int
        Maximum number of iterations.
    epsilon : float
        Convergence threshold.
    random_seed : int
        Random seed for reproducibility.
    distance_fn : callable, optional
        Pairwise distance function. Default: squared Euclidean.
        Its output is treated as a *squared* distance everywhere in this
        class (objective, membership update and typicality update).
    patience : int, optional
        Epochs with no improvement before early stopping. Default: None.
    restore_best : bool
        If True, restore centroids from the lowest-objective epoch.
        Default: False.
    plot_steps : bool
        Whether to visualize clustering progress. Default: False.
    show_confidence : bool
        Whether to show confidence in visualization. Default: True.
    show_pca_variance : bool
        Whether to show PCA variance in visualization. Default: True.
    save_plot_path : str, optional
        Path to save final plot.
    callbacks : list, optional
        List of Callback objects for monitoring/visualization.

    Attributes
    ----------
    centroids_ : array, shape (n_clusters, n_features)
        Final cluster centroids
    U_ : array, shape (n_samples, n_clusters)
        Final fuzzy membership matrix
    T_ : array, shape (n_samples, n_clusters)
        Final possibilistic typicality matrix
    n_iter_ : int
        Number of iterations until convergence
    objective_ : float
        Final objective function value
    objective_history_ : array
        Objective value at each iteration

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from prosemble.models import FPCM
    >>> X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8]])
    >>> model = FPCM(n_clusters=2, fuzzifier=2.0, eta=2.0, random_seed=42)
    >>> model.fit(X)
    >>> labels = model.predict(X)
    >>> U = model.predict_proba(X)
    >>> T = model.get_typicality(X)
    """

    _hyperparams = ('fuzzifier', 'eta', 'init_method')
    _fitted_array_names = ('U_', 'T_')

    def __init__(
        self,
        n_clusters: int,
        fuzzifier: float = 2.0,
        eta: float = 2.0,
        max_iter: int = 100,
        epsilon: float = 1e-5,
        init_method: str = 'fcm',
        random_seed: int = 42,
        distance_fn=None,
        patience: int | None = None,
        restore_best: bool = False,
        plot_steps: bool = False,
        show_confidence: bool = True,
        show_pca_variance: bool = True,
        save_plot_path: str | None = None,
        callbacks=None,
    ):
        """Validate FPCM-specific hyperparameters and set up the model.

        All remaining arguments are forwarded unchanged to the base-class
        constructor; see the class docstring for their meaning.

        Raises:
            ValueError: If ``fuzzifier <= 1.0``, ``eta <= 1.0`` or
                ``init_method`` is not ``'random'`` or ``'fcm'``.
        """
        # Validate model-specific parameters before touching the base class.
        # Both exponents 1/(m-1) and 1/(eta-1) are undefined at 1.0.
        if fuzzifier <= 1.0:
            raise ValueError("fuzzifier must be > 1.0")
        if eta <= 1.0:
            raise ValueError("eta must be > 1.0")
        if init_method not in ('random', 'fcm'):
            raise ValueError("init_method must be 'random' or 'fcm'")

        super().__init__(
            n_clusters=n_clusters,
            max_iter=max_iter,
            epsilon=epsilon,
            random_seed=random_seed,
            distance_fn=distance_fn,
            patience=patience,
            restore_best=restore_best,
            plot_steps=plot_steps,
            show_confidence=show_confidence,
            show_pca_variance=show_pca_variance,
            save_plot_path=save_plot_path,
            callbacks=callbacks,
        )

        self.fuzzifier = fuzzifier
        self.eta = eta
        self.init_method = init_method

        # Model-specific fitted attributes (populated by fit()).
        self.U_ = None
        self.T_ = None

    def _initialize_matrices(self, X: chex.Array):
        """Initialize :math:`U` and :math:`T` matrices.

        Args:
            X: Input data, shape (n_samples, n_features)

        Returns:
            Tuple of (:math:`U`, :math:`T`, centroids)

        Raises:
            ValueError: If ``self.init_method`` is neither ``'random'`` nor
                ``'fcm'``.
        """
        n_samples = X.shape[0]
        # Flat Dirichlet: every row is a uniform draw from the simplex,
        # so row sums are 1.
        alpha = jnp.ones(self.n_clusters)

        # NOTE(review): self.key is not defined in this class; presumably the
        # base class derives it from random_seed -- confirm.
        if self.init_method == 'random':
            # Use independent subkeys: drawing U and T from the same key
            # would make the two matrices identical.
            key_u, key_t = jax.random.split(self.key)
            U = jax.random.dirichlet(key_u, alpha, shape=(n_samples,))
            T = jax.random.dirichlet(key_t, alpha, shape=(n_samples,))

            # Compute initial centroids from the random weights
            centroids = self._compute_centroids(X, U, T)

        elif self.init_method == 'fcm':
            # Run a full FCM fit and reuse its memberships and centroids.
            fcm = FCM(
                n_clusters=self.n_clusters,
                fuzzifier=self.fuzzifier,
                max_iter=self.max_iter,
                epsilon=self.epsilon,
                random_seed=self.random_seed,
                distance_fn=self.distance_fn,
                plot_steps=False,
            )
            fcm.fit(X)

            U = fcm.U_
            centroids = fcm.centroids_

            # T is only needed for the initial objective here; it is
            # recomputed from the centroids on the first iteration.
            T = jax.random.dirichlet(self.key, alpha, shape=(n_samples,))

        else:
            raise ValueError(f"Unknown init_method: {self.init_method}")

        return U, T, centroids

    @partial(jit, static_argnums=(0,))
    def _compute_centroids(
        self,
        X: chex.Array,
        U: chex.Array,
        T: chex.Array
    ) -> chex.Array:
        """Compute cluster centroids.

        .. math::

            v_j = \\frac{\\sum_i \\left[u_{ij}^m + t_{ij}^\\eta\\right] x_i}{\\sum_i \\left[u_{ij}^m + t_{ij}^\\eta\\right]}

        Args:
            X: Input data, shape (n_samples, n_features)
            U: Fuzzy membership matrix, shape (n_samples, n_clusters)
            T: Typicality matrix, shape (n_samples, n_clusters)

        Returns:
            centroids: shape (n_clusters, n_features)
        """
        # Combined per-sample, per-cluster weights: u^m + t^eta
        weights = jnp.power(U, self.fuzzifier) + jnp.power(T, self.eta)

        # Weighted mean of the samples for every cluster
        numerator = weights.T @ X  # (n_clusters, n_features)
        denominator = jnp.sum(weights, axis=0, keepdims=True).T  # (n_clusters, 1)

        return numerator / denominator

    @partial(jit, static_argnums=(0,))
    def _update_fuzzy_matrix(
        self,
        X: chex.Array,
        centroids: chex.Array,
        fuzzifier: float
    ) -> chex.Array:
        """Update fuzzy membership matrix using FCM rule.

        Standard FCM update, written for the squared distances
        :math:`D_{ij} = d_{ij}^2` returned by ``distance_fn``:

        .. math::

            u_{ij} = \\frac{D_{ij}^{-1/(m-1)}}{\\sum_k D_{ik}^{-1/(m-1)}}
                   = \\frac{1}{\\sum_k \\left(\\frac{d_{ij}}{d_{ik}}\\right)^{2/(m-1)}}

        Args:
            X: Input data, shape (n_samples, n_features)
            centroids: Current centroids, shape (n_clusters, n_features)
            fuzzifier: Fuzziness parameter

        Returns:
            U: Updated membership matrix, shape (n_samples, n_clusters);
            every row sums to 1.
        """
        # Squared distances, floored to avoid division by zero when a
        # sample coincides with a centroid.
        d_sq = jnp.maximum(self.distance_fn(X, centroids), 1e-10)

        # The distances are already squared, so the exponent is 1/(m-1);
        # using 2/(m-1) here would yield (d_ij/d_ik)^(4/(m-1)).
        exponent = 1.0 / (fuzzifier - 1.0)

        # Divide each row by its smallest distance first: the ratios are
        # then >= 1 and their negative powers stay in (0, 1], so nothing
        # overflows even for fuzzifiers close to 1. The scaling cancels in
        # the normalisation below.
        scaled = d_sq / jnp.min(d_sq, axis=1, keepdims=True)
        inv_powered = jnp.power(scaled, -exponent)  # (n_samples, n_clusters)

        # Row-normalise; the row sum is >= 1 because the nearest cluster
        # contributes exactly 1.
        return inv_powered / jnp.sum(inv_powered, axis=1, keepdims=True)

    @partial(jit, static_argnums=(0,))
    def _update_typicality_matrix(
        self,
        X: chex.Array,
        centroids: chex.Array
    ) -> chex.Array:
        """Update typicality matrix with column-sum-to-1 constraint.

        Per Pal, Pal & Bezdek (1997):

        .. math::

            t_{ij} = \\frac{(1/d_{ij}^2)^{1/(\\eta-1)}}{\\sum_i (1/d_{ij}^2)^{1/(\\eta-1)}}

        Each column j sums to 1 across samples (:math:`\\sum_i t_{ij} = 1`).

        Args:
            X: Input data, shape (n_samples, n_features)
            centroids: Current centroids, shape (n_clusters, n_features)

        Returns:
            T: Typicality matrix, shape (n_samples, n_clusters)
        """
        D_sq = self.distance_fn(X, centroids)  # (n_samples, n_clusters)
        # Floor the distances so the reciprocal below is finite.
        D_sq = jnp.maximum(D_sq, 1e-10)

        power = 1.0 / (self.eta - 1.0)
        inv_dist_powered = jnp.power(1.0 / D_sq, power)  # (n_samples, n_clusters)

        # Normalize over samples (axis=0) so each column sums to 1
        col_sums = jnp.maximum(
            jnp.sum(inv_dist_powered, axis=0, keepdims=True), 1e-10
        )
        return inv_dist_powered / col_sums

    @partial(jit, static_argnums=(0,))
    def _compute_objective(
        self,
        X: chex.Array,
        U: chex.Array,
        T: chex.Array,
        centroids: chex.Array
    ) -> chex.Array:
        """Compute FPCM objective function.

        .. math::

            J = \\sum_i \\sum_j \\left[u_{ij}^m + t_{ij}^\\eta\\right] \\|x_i - v_j\\|^2

        Args:
            X: Input data, shape (n_samples, n_features)
            U: Fuzzy membership matrix, shape (n_samples, n_clusters)
            T: Typicality matrix, shape (n_samples, n_clusters)
            centroids: Current centroids, shape (n_clusters, n_features)

        Returns:
            objective: Scalar objective value
        """
        # Squared distances between every sample and every centroid
        D_sq = self.distance_fn(X, centroids)  # (n_samples, n_clusters)

        # Combined weights u^m + t^eta
        weights = jnp.power(U, self.fuzzifier) + jnp.power(T, self.eta)

        # Sum of weighted squared distances over all (sample, cluster) pairs
        return jnp.sum(weights * D_sq)

    @partial(jit, static_argnums=(0,))
    def _iteration_step(self, state: FPCMState, X: chex.Array) -> tuple[FPCMState, dict]:
        """Single FPCM iteration step.

        Args:
            state: Current FPCM state
            X: Input data, shape (n_samples, n_features)

        Returns:
            new_state: Updated FPCM state
            metrics: Dictionary of iteration metrics with keys
                ``'objective'``, ``'centroid_change'`` and ``'converged'``
        """
        # Update U with fuzzifier m (row-normalized)
        U_new = self._update_fuzzy_matrix(X, state.centroids, self.fuzzifier)

        # Update T with column-normalization (Pal et al. 1997)
        T_new = self._update_typicality_matrix(X, state.centroids)

        # Update centroids from the fresh U and T
        centroids_new = self._compute_centroids(X, U_new, T_new)

        # Objective of the fully updated state
        objective = self._compute_objective(X, U_new, T_new, centroids_new)

        # Convergence test: Frobenius norm of the centroid displacement
        centroid_change = jnp.linalg.norm(centroids_new - state.centroids, ord='fro')
        converged = centroid_change <= self.epsilon

        new_state = FPCMState(
            centroids=centroids_new,
            U=U_new,
            T=T_new,
            objective=objective,
            iteration=state.iteration + 1,
            converged=converged
        )

        metrics = {
            'objective': new_state.objective,
            'centroid_change': centroid_change,
            'converged': new_state.converged,
        }

        return new_state, metrics

    def _build_info(self, state, iteration):
        """Assemble a summary dictionary for a given state.

        NOTE(review): the caller is not visible in this file; presumably
        the base class passes this to callbacks/plotting -- confirm.

        Args:
            state: FPCM state to summarise
            iteration: Iteration number to report

        Returns:
            dict with keys ``'centroids'``, ``'labels'``, ``'weights'``,
            ``'iteration'``, ``'objective'`` and ``'max_iter'``.
        """
        # Hard labels from the fuzzy memberships
        labels = jnp.argmax(state.U, axis=1)
        # Per-sample weight: mean of the largest membership and the
        # largest typicality of that sample.
        weights = (jnp.max(state.U, axis=1) + jnp.max(state.T, axis=1)) / 2.0
        return {
            'centroids': state.centroids,
            'labels': labels,
            'weights': weights,
            'iteration': iteration,
            'objective': float(state.objective),
            'max_iter': self.max_iter,
        }

    def fit(self, X: chex.Array, initial_centroids=None, resume=False) -> Self:
        """Fit FPCM model to data.

        Args:
            X: Input data, shape (n_samples, n_features)
            initial_centroids: Optional initial centroids for warm starting
            resume: If True, resume from fitted state

        Returns:
            self: Fitted model

        Raises:
            ValueError: If both ``resume=True`` and ``initial_centroids``
                are given. Input validation performed by the base-class
                helpers may raise as well (e.g. presumably when
                n_samples < n_clusters -- not visible in this file).
        """
        if resume and initial_centroids is not None:
            raise ValueError("Cannot use both resume=True and initial_centroids")

        X = self._validate_input(X)

        if resume:
            # Continue from the centroids of the previous fit
            self._check_fitted()
            centroids_init = self.centroids_
            U_init = self._update_fuzzy_matrix(X, centroids_init, self.fuzzifier)
            T_init = self._update_typicality_matrix(X, centroids_init)
        elif initial_centroids is not None:
            # Warm start from user-supplied centroids
            centroids_init = self._validate_initial_centroids(X, initial_centroids)
            U_init = self._update_fuzzy_matrix(X, centroids_init, self.fuzzifier)
            T_init = self._update_typicality_matrix(X, centroids_init)
        else:
            # Cold start according to init_method
            U_init, T_init, centroids_init = self._initialize_matrices(X)

        initial_objective = self._compute_objective(X, U_init, T_init, centroids_init)

        initial_state = FPCMState(
            centroids=centroids_init,
            U=U_init,
            T=T_init,
            objective=initial_objective,
            iteration=0,
            converged=False
        )

        final_state, self.history_ = self._run_training(X, initial_state)

        # Store results
        self.centroids_ = final_state.centroids
        self.U_ = final_state.U
        self.T_ = final_state.T
        self.n_iter_ = int(final_state.iteration)
        self.objective_ = float(final_state.objective)
        self.objective_history_ = self.history_['objective']

        return self

    def predict(self, X: chex.Array) -> chex.Array:
        """Predict cluster labels for new data.

        Args:
            X: Input data, shape (n_samples, n_features)

        Returns:
            labels: Cluster labels, shape (n_samples,)

        Raises:
            ValueError: If model has not been fitted
        """
        self._check_fitted()
        # Coerce array-likes the same way predict_proba/get_typicality do.
        X = jnp.asarray(X)

        # Compute U for new data
        U = self._update_fuzzy_matrix(X, self.centroids_, self.fuzzifier)

        # Assign to cluster with highest membership
        return jnp.argmax(U, axis=1)

    def predict_proba(self, X: chex.Array) -> chex.Array:
        """Predict fuzzy membership probabilities (U matrix).

        Args:
            X: Input data, shape (n_samples, n_features)

        Returns:
            U: Fuzzy membership matrix, shape (n_samples, n_clusters)

        Raises:
            ValueError: If model has not been fitted
        """
        self._check_fitted()
        X = jnp.asarray(X)
        return self._update_fuzzy_matrix(X, self.centroids_, self.fuzzifier)

    def get_typicality(self, X: chex.Array) -> chex.Array:
        """Compute typicality values (T matrix).

        The values are column-normalised over the samples in ``X``, so
        they depend on which samples are passed together.

        Args:
            X: Input data, shape (n_samples, n_features)

        Returns:
            T: Typicality matrix, shape (n_samples, n_clusters)

        Raises:
            ValueError: If model has not been fitted
        """
        self._check_fitted()
        X = jnp.asarray(X)
        return self._update_typicality_matrix(X, self.centroids_)