Source code for prosemble.models.ipcm

"""
JAX-based Improved Possibilistic C-Means (IPCM) clustering implementation.

This module provides a GPU-accelerated implementation of IPCM using JAX
with JIT compilation for high performance.
"""

from typing import NamedTuple, Self
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import chex
from jax import jit, lax

from prosemble.models.base import FuzzyClusteringBase
from prosemble.models.fcm import FCM


class IPCMState(NamedTuple):
    """Immutable snapshot of the IPCM optimisation after one iteration.

    A new instance is built for every iteration rather than mutating the
    previous one, so a history of states can be kept safely.

    Attributes:
        centroids: Cluster centroids, shape (n_clusters, n_features)
        U: Fuzzy membership matrix, shape (n_samples, n_clusters)
        T: Typicality matrix, shape (n_samples, n_clusters)
        gamma: Scale parameters, shape (n_clusters,). Carried over
            unchanged by the iteration step; it is only recomputed
            between phases.
        objective: Objective function value for this state
        iteration: Iteration counter; it is restarted at 0 at the
            beginning of each phase
        converged: Whether the centroid change of the last step was at
            or below the convergence threshold
        phase: Current phase (0 or 1)
    """
    centroids: chex.Array
    U: chex.Array
    T: chex.Array
    gamma: chex.Array
    objective: chex.Array
    iteration: int
    converged: bool
    phase: int


class IPCM(FuzzyClusteringBase):
    """
    Improved Possibilistic C-Means clustering with JAX.

    IPCM uses a two-phase approach to improve clustering performance:

    - Phase 0: Initialize :math:`\\gamma` using fuzzy membership only
    - Phase 1: Refine :math:`\\gamma` using both membership and typicality

    Key differences from PCM:

    - Uses product of :math:`U^{m_f}` and :math:`T^{m_p}` in centroid computation
    - Modified :math:`U` update that depends on :math:`T`
    - Two-phase :math:`\\gamma` computation

    Algorithm (Phase 0):

    1. Initialize :math:`U` using FCM, :math:`T = 0`
    2. Compute :math:`\\gamma` parameters from fuzzy membership
    3. Update typicality matrix :math:`T`
    4. Update membership matrix :math:`U`
    5. Update centroids using combined U and T weights
    6. Repeat until convergence

    Algorithm (Phase 1):

    7. Recompute :math:`\\gamma` using both :math:`U` and :math:`T`
    8. Continue iterations with new gamma

    Objective function:

    .. math::

        J = \\sum_i \\sum_j u_{ij}^{m_f} \\cdot t_{ij}^{m_p} \\cdot d_{ij}^2
            + \\sum_j \\gamma_j \\sum_i (1 - t_{ij})^{m_p} \\cdot u_{ij}^{m_f}

    Parameters
    ----------
    fuzzifier : float, default=2.0
        Fuzziness parameter for :math:`U` matrix (:math:`m_f`, must be > 1.0).
    tipifier : float, default=2.0
        Possibilistic parameter for :math:`T` matrix (:math:`m_p`, must be > 1.0).
    k : float, default=1.0
        Scaling parameter for :math:`\\gamma` in phase 1 (must be > 0).
    init_method : {'fcm'}, default='fcm'
        Method for initializing :math:`U` matrix (must use FCM).
    n_clusters : int
        Number of clusters (must be >= 2).
    max_iter : int
        Maximum number of iterations. The limit applies to each phase
        separately.
    epsilon : float
        Convergence threshold on the Frobenius norm of the centroid change.
    random_seed : int
        Random seed for reproducibility.
    distance_fn : callable, optional
        Pairwise distance function. Default: squared Euclidean.
    patience : int, optional
        Epochs with no improvement before early stopping. Default: None.
    restore_best : bool
        If True, restore centroids from the lowest-objective epoch.
        Default: False.
    plot_steps : bool
        Whether to visualize clustering progress. Default: False.
    show_confidence : bool
        Whether to show confidence in visualization. Default: True.
    show_pca_variance : bool
        Whether to show PCA variance in visualization. Default: True.
    save_plot_path : str, optional
        Path to save final plot.
    callbacks : list, optional
        List of Callback objects for monitoring/visualization.

    Attributes
    ----------
    centroids_ : array, shape (n_clusters, n_features)
        Final cluster centroids
    U_ : array, shape (n_samples, n_clusters)
        Final fuzzy membership matrix
    T_ : array, shape (n_samples, n_clusters)
        Final typicality matrix
    gamma_ : array, shape (n_clusters,)
        Final scale parameters
    n_iter_ : int
        Total number of iterations across both phases. After
        ``fit(..., resume=True)`` only phase 1 is run, so this is the
        number of iterations of that resumed run.
    objective_ : float
        Final objective function value
    objective_history_ : array
        Objective value at each iteration

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from prosemble.models import IPCM
    >>> X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8]])
    >>> model = IPCM(n_clusters=2, fuzzifier=2.0, tipifier=2.0, k=1.0, random_seed=42)
    >>> model.fit(X)
    >>> labels = model.predict(X)
    """

    _hyperparams = ('fuzzifier', 'tipifier', 'k', 'init_method')
    _fitted_array_names = ('U_', 'T_', 'gamma_')

    def __init__(
        self,
        n_clusters: int,
        fuzzifier: float = 2.0,
        tipifier: float = 2.0,
        k: float = 1.0,
        max_iter: int = 100,
        epsilon: float = 1e-5,
        init_method: str = 'fcm',
        random_seed: int = 42,
        distance_fn=None,
        patience: int | None = None,
        restore_best: bool = False,
        plot_steps: bool = False,
        show_confidence: bool = True,
        show_pca_variance: bool = True,
        save_plot_path: str | None = None,
        callbacks=None,
    ):
        """Validate the IPCM-specific hyperparameters and store them.

        Raises:
            ValueError: If ``fuzzifier <= 1.0``, ``tipifier <= 1.0``,
                ``k <= 0`` or ``init_method`` is not ``'fcm'``.
        """
        # Validate model-specific parameters
        if fuzzifier <= 1.0:
            raise ValueError("fuzzifier must be > 1.0")
        if tipifier <= 1.0:
            raise ValueError("tipifier must be > 1.0")
        if k <= 0:
            raise ValueError("k must be > 0")
        if init_method != 'fcm':
            raise ValueError("init_method must be 'fcm' for IPCM")

        super().__init__(
            n_clusters=n_clusters,
            max_iter=max_iter,
            epsilon=epsilon,
            random_seed=random_seed,
            distance_fn=distance_fn,
            patience=patience,
            restore_best=restore_best,
            plot_steps=plot_steps,
            show_confidence=show_confidence,
            show_pca_variance=show_pca_variance,
            save_plot_path=save_plot_path,
            callbacks=callbacks,
        )

        self.fuzzifier = fuzzifier
        self.tipifier = tipifier
        self.k = k
        self.init_method = init_method

        # Model-specific fitted attributes
        self.U_ = None
        self.T_ = None
        self.gamma_ = None

    def _initialize_phase0(self, X: chex.Array):
        """Initialize for phase 0 using FCM.

        Args:
            X: Input data, shape (n_samples, n_features)

        Returns:
            Tuple of (U, T, centroids)
        """
        n_samples = X.shape[0]

        # Initialize using FCM
        fcm = FCM(
            n_clusters=self.n_clusters,
            fuzzifier=self.fuzzifier,
            max_iter=self.max_iter,
            epsilon=self.epsilon,
            random_seed=self.random_seed,
            distance_fn=self.distance_fn,
            plot_steps=False
        )
        fcm.fit(X)

        U = fcm.U_
        centroids = fcm.centroids_

        # Initialize T as zeros
        T = jnp.zeros((n_samples, self.n_clusters))

        return U, T, centroids

    @partial(jit, static_argnums=(0,))
    def _compute_gamma_phase0(
        self,
        X: chex.Array,
        U: chex.Array,
        centroids: chex.Array
    ) -> chex.Array:
        """Compute :math:`\\gamma` for phase 0.

        .. math::

            \\gamma_j = \\frac{\\sum_i u_{ij}^{m_f} \\cdot d_{ij}^2}{\\sum_i u_{ij}^{m_f}}

        Args:
            X: Input data, shape (n_samples, n_features)
            U: Fuzzy membership matrix, shape (n_samples, n_clusters)
            centroids: Current centroids, shape (n_clusters, n_features)

        Returns:
            gamma: shape (n_clusters,)
        """
        # Compute squared distances
        D_sq = self.distance_fn(X, centroids)  # (n_samples, n_clusters)

        # Fuzzify U
        U_fuzz = jnp.power(U, self.fuzzifier)  # (n_samples, n_clusters)

        # Compute gamma for each cluster; the floor on the denominator
        # keeps a cluster with (numerically) no membership from dividing by 0
        numerator = jnp.sum(U_fuzz * D_sq, axis=0)  # (n_clusters,)
        denominator = jnp.maximum(jnp.sum(U_fuzz, axis=0), 1e-10)  # (n_clusters,)

        gamma = numerator / denominator

        return gamma

    @partial(jit, static_argnums=(0,))
    def _compute_gamma_phase1(
        self,
        X: chex.Array,
        U: chex.Array,
        T: chex.Array,
        centroids: chex.Array
    ) -> chex.Array:
        """Compute :math:`\\gamma` for phase 1.

        .. math::

            \\gamma_j = k \\cdot \\frac{\\sum_i u_{ij}^{m_f} \\cdot t_{ij}^{m_p} \\cdot d_{ij}^2}{\\sum_i u_{ij}^{m_f} \\cdot t_{ij}^{m_p}}

        Args:
            X: Input data, shape (n_samples, n_features)
            U: Fuzzy membership matrix, shape (n_samples, n_clusters)
            T: Typicality matrix, shape (n_samples, n_clusters)
            centroids: Current centroids, shape (n_clusters, n_features)

        Returns:
            gamma: shape (n_clusters,)
        """
        # Compute squared distances
        D_sq = self.distance_fn(X, centroids)  # (n_samples, n_clusters)

        # Fuzzify U and T
        U_fuzz = jnp.power(U, self.fuzzifier)
        T_fuzz = jnp.power(T, self.tipifier)

        # Product of memberships
        prod = U_fuzz * T_fuzz  # (n_samples, n_clusters)

        # Compute gamma for each cluster
        numerator = jnp.sum(prod * D_sq, axis=0)  # (n_clusters,)
        denominator = jnp.maximum(jnp.sum(prod, axis=0), 1e-10)  # (n_clusters,)

        gamma = self.k * numerator / denominator

        return gamma

    @partial(jit, static_argnums=(0,))
    def _update_T(
        self,
        X: chex.Array,
        centroids: chex.Array,
        gamma: chex.Array
    ) -> chex.Array:
        """Update typicality matrix.

        .. math::

            t_{ij} = \\frac{1}{1 + \\left(\\frac{d_{ij}^2}{\\gamma_j}\\right)^{1/(m_p-1)}}

        Args:
            X: Input data, shape (n_samples, n_features)
            centroids: Current centroids, shape (n_clusters, n_features)
            gamma: Scale parameters, shape (n_clusters,)

        Returns:
            T: Updated typicality matrix, shape (n_samples, n_clusters)
        """
        # Compute squared distances
        D_sq = self.distance_fn(X, centroids)  # (n_samples, n_clusters)
        D_sq = jnp.maximum(D_sq, 1e-10)  # Avoid division by zero

        # Compute power
        power = 1.0 / (self.tipifier - 1.0)

        # Compute ratio and typicality
        ratio = D_sq / gamma[None, :]  # (n_samples, n_clusters)
        T = 1.0 / (1.0 + jnp.power(ratio, power))

        return T

    @partial(jit, static_argnums=(0,))
    def _update_U(
        self,
        X: chex.Array,
        T: chex.Array,
        centroids: chex.Array
    ) -> chex.Array:
        """Update fuzzy membership matrix (IPCM-specific).

        .. math::

            u_{ij} = \\frac{\\left(\\frac{1}{d_{ij}^2} \\cdot t_{ij}^{m_p-1}\\right)^{1/(m_f-1)}}{\\sum_k \\left(\\frac{1}{d_{ik}^2} \\cdot t_{ik}^{m_p-1}\\right)^{1/(m_f-1)}}

        Args:
            X: Input data, shape (n_samples, n_features)
            T: Typicality matrix, shape (n_samples, n_clusters)
            centroids: Current centroids, shape (n_clusters, n_features)

        Returns:
            U: Updated membership matrix, shape (n_samples, n_clusters)
        """
        # Compute squared distances
        D_sq = self.distance_fn(X, centroids)  # (n_samples, n_clusters)
        D_sq = jnp.maximum(D_sq, 1e-10)  # Avoid division by zero

        # Compute T^(m_p-1)
        T_pow = jnp.power(T, self.tipifier - 1.0)  # (n_samples, n_clusters)

        # Compute base values: (1/d^2_ij * t_ij^(m_p-1))
        base_values = (1.0 / D_sq) * T_pow  # (n_samples, n_clusters)

        # Compute power
        power = 1.0 / (self.fuzzifier - 1.0)

        # Raise to power
        powered_values = jnp.power(base_values, power)  # (n_samples, n_clusters)

        # Normalize so that every row of U sums to 1
        denominators = jnp.sum(powered_values, axis=1, keepdims=True)  # (n_samples, 1)
        U = powered_values / denominators

        return U

    @partial(jit, static_argnums=(0,))
    def _compute_centroids(
        self,
        X: chex.Array,
        U: chex.Array,
        T: chex.Array
    ) -> chex.Array:
        """Compute cluster centroids.

        .. math::

            v_j = \\frac{\\sum_i u_{ij}^{m_f} \\cdot t_{ij}^{m_p} \\cdot x_i}{\\sum_i u_{ij}^{m_f} \\cdot t_{ij}^{m_p}}

        Args:
            X: Input data, shape (n_samples, n_features)
            U: Fuzzy membership matrix, shape (n_samples, n_clusters)
            T: Typicality matrix, shape (n_samples, n_clusters)

        Returns:
            centroids: shape (n_clusters, n_features)
        """
        # Fuzzify U and T
        U_fuzz = jnp.power(U, self.fuzzifier)
        T_fuzz = jnp.power(T, self.tipifier)

        # Product of memberships
        weights = U_fuzz * T_fuzz  # (n_samples, n_clusters)

        # Compute centroids
        numerator = weights.T @ X  # (n_clusters, n_features)
        denominator = jnp.sum(weights, axis=0, keepdims=True).T  # (n_clusters, 1)

        # Handle empty clusters
        denominator = jnp.maximum(denominator, 1e-10)

        centroids = numerator / denominator

        return centroids

    @partial(jit, static_argnums=(0,))
    def _compute_objective(
        self,
        X: chex.Array,
        U: chex.Array,
        T: chex.Array,
        centroids: chex.Array,
        gamma: chex.Array
    ) -> chex.Array:
        """Compute IPCM objective function.

        .. math::

            J = \\sum_i \\sum_j u_{ij}^{m_f} \\cdot t_{ij}^{m_p} \\cdot d_{ij}^2
                + \\sum_j \\gamma_j \\sum_i (1 - t_{ij})^{m_p} \\cdot u_{ij}^{m_f}

        Args:
            X: Input data
            U: Fuzzy membership matrix
            T: Typicality matrix
            centroids: Current centroids
            gamma: Scale parameters

        Returns:
            objective: Scalar objective value
        """
        # Compute squared distances
        D_sq = self.distance_fn(X, centroids)  # (n_samples, n_clusters)

        # Fuzzify U and T
        U_fuzz = jnp.power(U, self.fuzzifier)
        T_fuzz = jnp.power(T, self.tipifier)

        # First term: sum_i sum_j [u_ij^m_f * t_ij^m_p * d^2_ij]
        term1 = jnp.sum(U_fuzz * T_fuzz * D_sq)

        # Second term: sum_j[gamma_j * sum_i((1-t_ij)^m_p * u_ij^m_f)]
        one_minus_T = 1.0 - T
        one_minus_T_fuzz = jnp.power(one_minus_T, self.tipifier)
        inner_sum = jnp.sum(one_minus_T_fuzz * U_fuzz, axis=0)  # (n_clusters,)
        term2 = jnp.sum(gamma * inner_sum)

        objective = term1 + term2

        return objective

    @partial(jit, static_argnums=(0,))
    def _iteration_step(self, state: IPCMState, X: chex.Array) -> IPCMState:
        """Single IPCM iteration step.

        Updates T, then U, then the centroids. ``gamma`` and ``phase``
        are carried over unchanged.

        Args:
            state: Current IPCM state
            X: Input data

        Returns:
            new_state: Updated IPCM state
        """
        # Update T
        T_new = self._update_T(X, state.centroids, state.gamma)

        # Update U
        U_new = self._update_U(X, T_new, state.centroids)

        # Update centroids
        centroids_new = self._compute_centroids(X, U_new, T_new)

        # Compute objective
        objective = self._compute_objective(X, U_new, T_new, centroids_new, state.gamma)

        # Check convergence
        centroid_change = jnp.linalg.norm(centroids_new - state.centroids, ord='fro')
        converged = centroid_change <= self.epsilon

        new_state = IPCMState(
            centroids=centroids_new,
            U=U_new,
            T=T_new,
            gamma=state.gamma,
            objective=objective,
            iteration=state.iteration + 1,
            converged=converged,
            phase=state.phase
        )

        return new_state

    def _run_phase(
        self,
        X: chex.Array,
        U_init: chex.Array,
        T_init: chex.Array,
        centroids_init: chex.Array,
        gamma_init: chex.Array,
        phase: int
    ) -> tuple[IPCMState, list[IPCMState]]:
        """Run one phase of IPCM.

        Iterates at most ``max_iter`` times and stops early when the
        state reports convergence or the patience check fires.

        Args:
            X: Input data
            U_init: Initial U matrix
            T_init: Initial T matrix
            centroids_init: Initial centroids
            gamma_init: Initial gamma
            phase: Phase number (0 or 1)

        Returns:
            Tuple ``(final_state, states_history)``: the state the phase
            ended on (or the lowest-objective state when
            ``restore_best`` is set) and every state visited, starting
            with the initial one.
        """
        # Initial objective
        initial_objective = self._compute_objective(X, U_init, T_init, centroids_init, gamma_init)

        # Initial state; the iteration counter restarts at 0 for every phase
        state = IPCMState(
            centroids=centroids_init,
            U=U_init,
            T=T_init,
            gamma=gamma_init,
            objective=initial_objective,
            iteration=0,
            converged=False,
            phase=phase
        )

        states_history = [state]
        objectives = [float(state.objective)]
        best_state = None
        best_obj = float('inf')

        for _ in range(self.max_iter):
            # Callbacks are notified with the state *before* the step
            self._notify_iteration(self._build_info(state, state.iteration))

            state = self._iteration_step(state, X)
            states_history.append(state)

            obj = float(state.objective)
            objectives.append(obj)

            if self.restore_best and obj < best_obj:
                best_obj = obj
                best_state = state

            if state.converged:
                break

            if self.patience is not None and self._check_patience(objectives, self.patience):
                break

        if self.restore_best and best_state is not None:
            state = best_state
            self.best_loss_ = best_obj

        return state, states_history

    def _build_info(self, state, iteration):
        """Build the info dictionary passed to the notification hooks.

        Args:
            state: IPCM state to summarise
            iteration: Iteration number to report

        Returns:
            dict with the centroids, hard labels (argmax of U), per-sample
            weights (row-wise max of ``U * T``), iteration, objective and
            ``max_iter``.
        """
        labels = jnp.argmax(state.U, axis=1)
        weights = jnp.max(state.U * state.T, axis=1)
        return {
            'centroids': state.centroids,
            'labels': labels,
            'weights': weights,
            'iteration': iteration,
            'objective': float(state.objective),
            'max_iter': self.max_iter,
        }

    def fit(self, X: chex.Array, initial_centroids=None, resume=False) -> Self:
        """Fit IPCM model to data.

        Args:
            X: Input data, shape (n_samples, n_features)
            initial_centroids: Optional starting centroids. When given,
                the FCM initialization is skipped and the initial T and U
                are derived from these centroids.
            resume: If True, skip phase 0 and run only phase 1 starting
                from the already fitted ``U_``, ``T_`` and ``centroids_``.

        Returns:
            self

        Raises:
            ValueError: If both ``resume=True`` and ``initial_centroids``
                are given.
        """
        if resume and initial_centroids is not None:
            raise ValueError("Cannot use both resume=True and initial_centroids")

        X = self._validate_input(X)
        self._notify_fit_start(X)

        if resume:
            # Skip phase 0, run only phase 1 with fitted state
            self._check_fitted()
            gamma_1 = self._compute_gamma_phase1(
                X, self.U_, self.T_, self.centroids_
            )
            state_final, history_phase1 = self._run_phase(
                X, self.U_, self.T_, self.centroids_, gamma_1, phase=1
            )
            self._notify_fit_end(self._build_info(state_final, state_final.iteration))
            all_objectives = [s.objective for s in history_phase1]
            self.objective_history_ = jnp.array(all_objectives)
            total_iterations = int(state_final.iteration)
        else:
            if initial_centroids is not None:
                centroids_init = self._validate_initial_centroids(X, initial_centroids)
                # No membership is known yet: seed gamma from a uniform U,
                # then derive T and U from the supplied centroids
                uniform_U = jnp.ones((X.shape[0], self.n_clusters)) / self.n_clusters
                gamma_seed = self._compute_gamma_phase0(X, uniform_U, centroids_init)
                T_init = self._update_T(X, centroids_init, gamma_seed)
                U_init = self._update_U(X, T_init, centroids_init)
            else:
                U_init, T_init, centroids_init = self._initialize_phase0(X)

            # Phase 0
            gamma_0 = self._compute_gamma_phase0(X, U_init, centroids_init)
            state_phase0, history_phase0 = self._run_phase(
                X, U_init, T_init, centroids_init, gamma_0, phase=0
            )

            # Phase 1
            gamma_1 = self._compute_gamma_phase1(
                X, state_phase0.U, state_phase0.T, state_phase0.centroids
            )
            state_final, history_phase1 = self._run_phase(
                X, state_phase0.U, state_phase0.T, state_phase0.centroids,
                gamma_1, phase=1
            )

            self._notify_fit_end(self._build_info(state_final, state_final.iteration))
            all_objectives = (
                [s.objective for s in history_phase0]
                + [s.objective for s in history_phase1]
            )
            self.objective_history_ = jnp.array(all_objectives)
            # Each phase restarts its counter at 0, so the documented
            # "total across both phases" is the sum of the two counters
            total_iterations = int(state_phase0.iteration) + int(state_final.iteration)

        # Store results
        self.centroids_ = state_final.centroids
        self.U_ = state_final.U
        self.T_ = state_final.T
        self.gamma_ = state_final.gamma
        self.n_iter_ = total_iterations
        self.objective_ = float(state_final.objective)

        return self

    def predict(self, X: chex.Array) -> chex.Array:
        """Predict cluster labels for new data.

        Args:
            X: Input data, shape (n_samples, n_features)

        Returns:
            labels: Cluster labels, shape (n_samples,)

        Raises:
            ValueError: If model has not been fitted
        """
        self._check_fitted()
        # Same conversion as predict_proba / get_typicality, so all three
        # inference methods accept the same kinds of input
        X = jnp.asarray(X)

        # Compute T for new data
        T = self._update_T(X, self.centroids_, self.gamma_)

        # Compute U for new data
        U = self._update_U(X, T, self.centroids_)

        # Assign to cluster with highest membership
        labels = jnp.argmax(U, axis=1)

        return labels

    def predict_proba(self, X: chex.Array) -> chex.Array:
        """Predict fuzzy membership probabilities (U matrix).

        Args:
            X: Input data, shape (n_samples, n_features)

        Returns:
            U: Fuzzy membership matrix, shape (n_samples, n_clusters)

        Raises:
            ValueError: If model has not been fitted
        """
        self._check_fitted()
        X = jnp.asarray(X)

        # Compute T for new data
        T = self._update_T(X, self.centroids_, self.gamma_)

        # Compute U for new data
        U = self._update_U(X, T, self.centroids_)

        return U

    def get_typicality(self, X: chex.Array) -> chex.Array:
        """Compute typicality values (T matrix).

        Args:
            X: Input data, shape (n_samples, n_features)

        Returns:
            T: Typicality matrix, shape (n_samples, n_clusters)

        Raises:
            ValueError: If model has not been fitted
        """
        self._check_fitted()
        X = jnp.asarray(X)

        T = self._update_T(X, self.centroids_, self.gamma_)

        return T