# Source code for prosemble.models.ipcm2

"""
JAX-based Improved Possibilistic C-Means 2 (IPCM2) clustering implementation.

This module provides a GPU-accelerated implementation of IPCM2 using JAX
with JIT compilation for high performance.
"""

from typing import NamedTuple, Self
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import chex
from jax import jit, lax

from prosemble.models.fcm import FCM
from prosemble.models.base import FuzzyClusteringBase


class IPCM2State(NamedTuple):
    """Immutable state for IPCM2 iteration.

    The state is threaded through ``IPCM2._run_phase`` and the jitted
    ``IPCM2._iteration_step``; each step returns a new instance rather
    than mutating the previous one. Field order defines the tuple layout
    (and therefore positional indexing), so do not reorder fields.

    Attributes:
        centroids: Cluster centroids, shape (n_clusters, n_features)
        U: Fuzzy membership matrix, shape (n_samples, n_clusters)
        T: Typicality matrix, shape (n_samples, n_clusters)
        gamma: Scale parameters, shape (n_clusters,)
        objective: Current objective function value
        iteration: Current iteration number
        converged: Whether algorithm has converged
        phase: Current phase (0 or 1)

    Note:
        ``iteration``, ``converged`` and ``phase`` are annotated as Python
        scalars, but states returned by the jitted ``IPCM2._iteration_step``
        carry them as JAX scalar arrays (``iteration + 1`` and the
        ``centroid_change <= epsilon`` comparison are traced). Only the
        initial state built in ``IPCM2._run_phase`` holds plain Python values.
    """
    centroids: chex.Array
    U: chex.Array
    T: chex.Array
    gamma: chex.Array
    objective: chex.Array
    iteration: int
    # True once the Frobenius-norm centroid shift is <= epsilon.
    converged: bool
    # 0: gamma from U only; 1: gamma from U and T (see IPCM2.fit).
    phase: int


class IPCM2(FuzzyClusteringBase):
    """
    Improved Possibilistic C-Means 2 clustering with JAX.

    IPCM2 is a variant of IPCM with key differences:

    - Uses exponential :math:`T` update: :math:`t_{ij} = \\exp(-d_{ij}^2 / \\gamma_j)`
    - Centroids use :math:`U^{m_f} \\cdot T` (:math:`T` without power!)
    - Modified :math:`U` update with exponential distance
    - Different objective function

    Algorithm (Phase 0):

    1. Initialize :math:`U` using FCM, :math:`T = 0` (when
       ``initial_centroids`` is passed to :meth:`fit`, :math:`U` and
       :math:`T` are instead derived from those centroids)
    2. Compute :math:`\\gamma` parameters from fuzzy membership
    3. Update :math:`T` using exponential update
    4. Update :math:`U` with modified distance
    5. Update centroids using combined U and T weights
    6. Repeat until convergence

    Algorithm (Phase 1):

    7. Recompute :math:`\\gamma` using both :math:`U` and :math:`T`
    8. Continue iterations with new gamma

    Objective function:

    .. math::

        J = \\sum_i \\sum_j u_{ij}^{m_f} \\cdot t_{ij} \\cdot d_{ij}^2
            + \\sum_j \\gamma_j \\sum_i (t_{ij} \\log t_{ij} - t_{ij} + 1) \\cdot u_{ij}^{m_f}

    Parameters
    ----------
    fuzzifier : float, default=2.0
        Fuzziness parameter for :math:`U` matrix (:math:`m_f`, must be > 1.0).
    tipifier : float, default=2.0
        Possibilistic parameter for :math:`T` matrix (:math:`m_p`, must be > 1.0).
        In IPCM2 it only enters the phase-1 :math:`\\gamma` estimate; the
        typicality update itself is exponential and does not use it.
    init_method : {'fcm'}, default='fcm'
        Method for initializing :math:`U` matrix.
    n_clusters : int
        Number of clusters (must be >= 2).
    max_iter : int
        Maximum number of iterations (applied to each phase separately).
    epsilon : float
        Convergence threshold on the Frobenius norm of the centroid shift.
    random_seed : int
        Random seed for reproducibility.
    distance_fn : callable, optional
        Pairwise distance function. Default: squared Euclidean.
    patience : int, optional
        Epochs with no improvement before early stopping. Default: None.
    restore_best : bool
        If True, restore centroids from the lowest-objective epoch.
        Default: False.
    plot_steps : bool
        Whether to visualize clustering progress. Default: False.
    show_confidence : bool
        Whether to show confidence in visualization. Default: True.
    show_pca_variance : bool
        Whether to show PCA variance in visualization. Default: True.
    save_plot_path : str, optional
        Path to save final plot.
    callbacks : list, optional
        List of Callback objects for monitoring/visualization.

    Attributes
    ----------
    centroids_ : array, shape (n_clusters, n_features)
        Final cluster centroids
    U_ : array, shape (n_samples, n_clusters)
        Final fuzzy membership matrix
    T_ : array, shape (n_samples, n_clusters)
        Final typicality matrix
    gamma_ : array, shape (n_clusters,)
        Final scale parameters
    n_iter_ : int
        Number of iterations performed in the last phase run (phase 1);
        phase-0 iterations are not included.
    objective_ : float
        Final objective function value
    objective_history_ : array
        Objective values of every recorded state, including the initial
        state of each phase run (phase 0 followed by phase 1 for a fresh fit).

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from prosemble.models import IPCM2
    >>> X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8]])
    >>> model = IPCM2(n_clusters=2, random_seed=42)
    >>> model.fit(X)
    >>> labels = model.predict(X)
    """

    _hyperparams = ('fuzzifier', 'tipifier', 'init_method')
    _fitted_array_names = ('U_', 'T_', 'gamma_')

    def __init__(
        self,
        n_clusters: int,
        fuzzifier: float = 2.0,
        tipifier: float = 2.0,
        max_iter: int = 100,
        epsilon: float = 1e-5,
        init_method: str = 'fcm',
        random_seed: int = 42,
        distance_fn=None,
        patience: int | None = None,
        restore_best: bool = False,
        plot_steps: bool = False,
        show_confidence: bool = True,
        show_pca_variance: bool = True,
        save_plot_path: str | None = None,
        callbacks=None,
    ):
        # Validate model-specific parameters
        if fuzzifier <= 1.0:
            raise ValueError("fuzzifier must be > 1.0")
        if tipifier <= 1.0:
            raise ValueError("tipifier must be > 1.0")
        if init_method != 'fcm':
            raise ValueError("init_method must be 'fcm' for IPCM2")

        super().__init__(
            n_clusters=n_clusters,
            max_iter=max_iter,
            epsilon=epsilon,
            random_seed=random_seed,
            distance_fn=distance_fn,
            patience=patience,
            restore_best=restore_best,
            plot_steps=plot_steps,
            show_confidence=show_confidence,
            show_pca_variance=show_pca_variance,
            save_plot_path=save_plot_path,
            callbacks=callbacks,
        )

        self.fuzzifier = fuzzifier
        self.tipifier = tipifier
        self.init_method = init_method

        # Fitted attributes
        self.U_ = None
        self.T_ = None
        self.gamma_ = None

    def _initialize_phase0(self, X: chex.Array):
        """Initialize for phase 0 using FCM.

        Returns
        -------
        tuple
            ``(U, T, centroids)``: the fitted FCM membership matrix, an
            all-zero typicality matrix of shape (n_samples, n_clusters),
            and the fitted FCM centroids.
        """
        n_samples = X.shape[0]

        # Initialize using FCM with this model's shared settings.
        fcm = FCM(
            n_clusters=self.n_clusters,
            fuzzifier=self.fuzzifier,
            max_iter=self.max_iter,
            epsilon=self.epsilon,
            random_seed=self.random_seed,
            distance_fn=self.distance_fn,
            plot_steps=False
        )
        fcm.fit(X)

        U = fcm.U_
        centroids = fcm.centroids_

        # Initialize T as zeros
        T = jnp.zeros((n_samples, self.n_clusters))

        return U, T, centroids

    @partial(jit, static_argnums=(0,))
    def _compute_gamma_phase0(
        self,
        X: chex.Array,
        U: chex.Array,
        centroids: chex.Array
    ) -> chex.Array:
        """Compute :math:`\\gamma` for phase 0.

        .. math::

            \\gamma_j = \\frac{\\sum_i u_{ij}^{m_f} \\cdot d_{ij}^2}{\\sum_i u_{ij}^{m_f}}
        """
        D_sq = self.distance_fn(X, centroids)
        U_fuzz = jnp.power(U, self.fuzzifier)

        numerator = jnp.sum(U_fuzz * D_sq, axis=0)
        # Floor the denominator so an empty cluster does not divide by zero.
        denominator = jnp.maximum(jnp.sum(U_fuzz, axis=0), 1e-10)

        gamma = numerator / denominator
        return gamma

    @partial(jit, static_argnums=(0,))
    def _compute_gamma_phase1(
        self,
        X: chex.Array,
        U: chex.Array,
        T: chex.Array,
        centroids: chex.Array
    ) -> chex.Array:
        """Compute :math:`\\gamma` for phase 1.

        .. math::

            \\gamma_j = \\frac{\\sum_i u_{ij}^{m_f} \\cdot t_{ij}^{m_p} \\cdot d_{ij}^2}{\\sum_i u_{ij}^{m_f} \\cdot t_{ij}^{m_p}}
        """
        D_sq = self.distance_fn(X, centroids)
        U_fuzz = jnp.power(U, self.fuzzifier)
        T_fuzz = jnp.power(T, self.tipifier)

        prod = U_fuzz * T_fuzz
        numerator = jnp.sum(prod * D_sq, axis=0)
        denominator = jnp.maximum(jnp.sum(prod, axis=0), 1e-10)

        gamma = numerator / denominator
        return gamma

    @partial(jit, static_argnums=(0,))
    def _update_T(
        self,
        X: chex.Array,
        centroids: chex.Array,
        gamma: chex.Array
    ) -> chex.Array:
        """Update typicality matrix with exponential.

        .. math::

            t_{ij} = \\exp\\left(-\\frac{d_{ij}^2}{\\gamma_j}\\right)
        """
        D_sq = self.distance_fn(X, centroids)
        # Same 1e-10 floor on squared distances as in _update_U.
        D_sq = jnp.maximum(D_sq, 1e-10)

        # Exponential update
        ratio = D_sq / gamma[None, :]
        T = jnp.exp(-ratio)

        return T

    @partial(jit, static_argnums=(0,))
    def _update_U(
        self,
        X: chex.Array,
        centroids: chex.Array,
        gamma: chex.Array
    ) -> chex.Array:
        """Update fuzzy membership matrix (IPCM2-specific).

        .. math::

            u_{ij} = \\frac{\\left(\\frac{1}{\\gamma_j (1 - \\exp(-d_{ij}^2/\\gamma_j))}\\right)^{2/(m_f-1)}}{\\sum_k \\left(\\frac{1}{\\gamma_k (1 - \\exp(-d_{ik}^2/\\gamma_k))}\\right)^{2/(m_f-1)}}

        The normalization is evaluated in log space, as a softmax over
        clusters of :math:`-\\frac{2}{m_f-1}\\log(\\gamma_j (1 - \\exp(-d_{ij}^2/\\gamma_j)))`.
        This is algebraically identical to raising to the power and dividing
        by the row sum. It avoids float overflow (``inf / inf``) and underflow
        (``0 / 0``), and the resulting NaN memberships, when the exponent
        :math:`2/(m_f-1)` is large, i.e. for fuzzifiers close to 1.
        """
        D_sq = self.distance_fn(X, centroids)
        D_sq = jnp.maximum(D_sq, 1e-10)

        # Modified distance: gamma_j * (1 - exp(-d^2_ij / gamma_j)), floored
        # so that its logarithm stays finite.
        ratio = D_sq / gamma[None, :]
        modified_dist = gamma[None, :] * (1.0 - jnp.exp(-ratio))
        modified_dist = jnp.maximum(modified_dist, 1e-10)

        power = 2.0 / (self.fuzzifier - 1.0)

        # log((1 / md) ** power) == -power * log(md); softmax normalizes each
        # row in a numerically stable way (max-shifted exponentials).
        log_powered = -power * jnp.log(modified_dist)
        U = jax.nn.softmax(log_powered, axis=1)

        return U

    @partial(jit, static_argnums=(0,))
    def _compute_centroids(
        self,
        X: chex.Array,
        U: chex.Array,
        T: chex.Array
    ) -> chex.Array:
        """Compute cluster centroids.

        .. math::

            v_j = \\frac{\\sum_i u_{ij}^{m_f} \\cdot t_{ij} \\cdot x_i}{\\sum_i u_{ij}^{m_f} \\cdot t_{ij}}

        Note: :math:`T` is NOT raised to power :math:`m_p` here!
        """
        U_fuzz = jnp.power(U, self.fuzzifier)

        # Product of U^m_f and T (NOT T^m_p!)
        weights = U_fuzz * T

        # Weighted mean of the samples per cluster.
        numerator = weights.T @ X
        denominator = jnp.sum(weights, axis=0, keepdims=True).T
        denominator = jnp.maximum(denominator, 1e-10)

        centroids = numerator / denominator
        return centroids

    @partial(jit, static_argnums=(0,))
    def _compute_objective(
        self,
        X: chex.Array,
        U: chex.Array,
        T: chex.Array,
        centroids: chex.Array,
        gamma: chex.Array
    ) -> chex.Array:
        """Compute IPCM2 objective function.

        .. math::

            J = \\sum_i \\sum_j u_{ij}^{m_f} \\cdot t_{ij} \\cdot d_{ij}^2
                + \\sum_j \\gamma_j \\sum_i (t_{ij} \\log t_{ij} - t_{ij} + 1) \\cdot u_{ij}^{m_f}
        """
        D_sq = self.distance_fn(X, centroids)
        U_fuzz = jnp.power(U, self.fuzzifier)

        # First term: sum_i sum_j [u_ij^m_f * t_ij * d^2_ij]
        term1 = jnp.sum(U_fuzz * T * D_sq)

        # Second term: sum_j[gamma_j * sum_i((t*log(t) - t + 1) * u^m_f)]
        # Handle log(0) by clamping T inside the log only; T * log(T_safe)
        # is then 0 wherever T == 0.
        T_safe = jnp.maximum(T, 1e-10)
        entropy_like = T * jnp.log(T_safe) - T + 1.0
        inner_sum = jnp.sum(entropy_like * U_fuzz, axis=0)
        term2 = jnp.sum(gamma * inner_sum)

        objective = term1 + term2
        return objective

    @partial(jit, static_argnums=(0,))
    def _iteration_step(self, state: IPCM2State, X: chex.Array) -> IPCM2State:
        """Single IPCM2 iteration step.

        Updates T and U from the current centroids and fixed gamma, then
        recomputes the centroids. Convergence is declared when the Frobenius
        norm of the centroid shift is <= ``epsilon``.
        """
        # Update T
        T_new = self._update_T(X, state.centroids, state.gamma)

        # Update U
        U_new = self._update_U(X, state.centroids, state.gamma)

        # Update centroids
        centroids_new = self._compute_centroids(X, U_new, T_new)

        # Compute objective
        objective = self._compute_objective(X, U_new, T_new, centroids_new, state.gamma)

        # Check convergence
        centroid_change = jnp.linalg.norm(centroids_new - state.centroids, ord='fro')
        converged = centroid_change <= self.epsilon

        new_state = IPCM2State(
            centroids=centroids_new,
            U=U_new,
            T=T_new,
            gamma=state.gamma,
            objective=objective,
            iteration=state.iteration + 1,
            converged=converged,
            phase=state.phase
        )

        return new_state

    def _run_phase(
        self,
        X: chex.Array,
        U_init: chex.Array,
        T_init: chex.Array,
        centroids_init: chex.Array,
        gamma_init: chex.Array,
        phase: int
    ):
        """Run one phase of IPCM2.

        Iterates ``_iteration_step`` with a fixed gamma for at most
        ``max_iter`` steps, stopping early on convergence or when the
        patience check fires.

        Returns
        -------
        tuple
            ``(state, states_history)``: the final state (or the
            lowest-objective state when ``restore_best`` is set) and the list
            of all states, starting with the initial one.
        """
        initial_objective = self._compute_objective(X, U_init, T_init, centroids_init, gamma_init)

        state = IPCM2State(
            centroids=centroids_init,
            U=U_init,
            T=T_init,
            gamma=gamma_init,
            objective=initial_objective,
            iteration=0,
            converged=False,
            phase=phase
        )

        states_history = [state]
        objectives = [float(state.objective)]
        best_state = None
        best_obj = float('inf')

        for _ in range(self.max_iter):
            self._notify_iteration(self._build_info(state, state.iteration))
            state = self._iteration_step(state, X)
            states_history.append(state)
            obj = float(state.objective)
            objectives.append(obj)

            if self.restore_best and obj < best_obj:
                best_obj = obj
                best_state = state

            if state.converged:
                break
            if self.patience is not None and self._check_patience(objectives, self.patience):
                break

        if self.restore_best and best_state is not None:
            state = best_state
            self.best_loss_ = best_obj

        return state, states_history

    def _build_info(self, state, iteration):
        """Build the info dict passed to iteration/fit-end callbacks.

        Labels are the argmax of U; ``weights`` is the per-sample maximum of
        U * T.
        """
        labels = jnp.argmax(state.U, axis=1)
        weights = jnp.max(state.U * state.T, axis=1)
        return {
            'centroids': state.centroids,
            'labels': labels,
            'weights': weights,
            'iteration': iteration,
            'objective': float(state.objective),
            'max_iter': self.max_iter,
        }

    def fit(self, X: chex.Array, initial_centroids=None, resume=False) -> Self:
        """Fit IPCM2 model to data.

        Parameters
        ----------
        X : array, shape (n_samples, n_features)
            Training data.
        initial_centroids : array, shape (n_clusters, n_features), optional
            Starting centroids. When given, the FCM initialization is
            skipped and the initial :math:`U` and :math:`T` are derived from
            these centroids, using a :math:`\\gamma` computed from uniform
            memberships.
        resume : bool, default=False
            If True, skip phase 0 and run only phase 1 starting from the
            previously fitted ``centroids_``, ``U_`` and ``T_``.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If both ``resume=True`` and ``initial_centroids`` are given.
        """
        if resume and initial_centroids is not None:
            raise ValueError("Cannot use both resume=True and initial_centroids")

        X = self._validate_input(X)
        self._notify_fit_start(X)

        if resume:
            # Skip phase 0, run only phase 1 with fitted state
            self._check_fitted()
            gamma_1 = self._compute_gamma_phase1(
                X, self.U_, self.T_, self.centroids_
            )
            state_final, history_phase1 = self._run_phase(
                X, self.U_, self.T_, self.centroids_, gamma_1, phase=1
            )
            self._notify_fit_end(self._build_info(state_final, state_final.iteration))

            all_objectives = [s.objective for s in history_phase1]
            self.objective_history_ = jnp.array(all_objectives)
        else:
            if initial_centroids is not None:
                centroids_init = self._validate_initial_centroids(X, initial_centroids)
                # Derive U and T from centroids, with gamma estimated from
                # uniform memberships.
                uniform_U = jnp.ones((X.shape[0], self.n_clusters)) / self.n_clusters
                gamma_0 = self._compute_gamma_phase0(X, uniform_U, centroids_init)
                T_init = self._update_T(X, centroids_init, gamma_0)
                U_init = self._update_U(X, centroids_init, gamma_0)
            else:
                U_init, T_init, centroids_init = self._initialize_phase0(X)

            # Phase 0: gamma from fuzzy memberships only
            gamma_0 = self._compute_gamma_phase0(X, U_init, centroids_init)
            state_phase0, history_phase0 = self._run_phase(
                X, U_init, T_init, centroids_init, gamma_0, phase=0
            )

            # Phase 1: gamma re-estimated from both U and T
            gamma_1 = self._compute_gamma_phase1(
                X, state_phase0.U, state_phase0.T, state_phase0.centroids
            )
            state_final, history_phase1 = self._run_phase(
                X, state_phase0.U, state_phase0.T, state_phase0.centroids,
                gamma_1, phase=1
            )
            self._notify_fit_end(self._build_info(state_final, state_final.iteration))

            all_objectives = (
                [s.objective for s in history_phase0]
                + [s.objective for s in history_phase1]
            )
            self.objective_history_ = jnp.array(all_objectives)

        self.centroids_ = state_final.centroids
        self.U_ = state_final.U
        self.T_ = state_final.T
        self.gamma_ = state_final.gamma
        self.n_iter_ = int(state_final.iteration)
        self.objective_ = float(state_final.objective)

        return self

    def predict(self, X: chex.Array) -> chex.Array:
        """Predict cluster labels for new data.

        Labels are the argmax over clusters of the fuzzy membership
        :math:`U` computed with the fitted centroids and gamma.
        """
        self._check_fitted()
        X = jnp.asarray(X)
        U = self._update_U(X, self.centroids_, self.gamma_)
        labels = jnp.argmax(U, axis=1)
        return labels

    def predict_proba(self, X: chex.Array) -> chex.Array:
        """Predict fuzzy membership probabilities.

        Returns the membership matrix :math:`U`, shape (n_samples, n_clusters),
        computed with the fitted centroids and gamma; rows sum to 1.
        """
        self._check_fitted()
        X = jnp.asarray(X)
        U = self._update_U(X, self.centroids_, self.gamma_)
        return U

    def get_typicality(self, X: chex.Array) -> chex.Array:
        """Compute typicality values.

        Returns :math:`T` with :math:`t_{ij} = \\exp(-d_{ij}^2 / \\gamma_j)`,
        shape (n_samples, n_clusters), using the fitted centroids and gamma.
        """
        self._check_fitted()
        X = jnp.asarray(X)
        T = self._update_T(X, self.centroids_, self.gamma_)
        return T