Source code for prosemble.models.pfcm

"""
JAX implementation of Possibilistic Fuzzy C-Means (PFCM) clustering algorithm.

PFCM combines FCM and PCM by using both fuzzy membership (:math:`U`) and typicality (:math:`T`) matrices.
It provides better clustering by simultaneously considering membership and typicality.

Mathematical Background:
------------------------
PFCM uses both membership and typicality with weighting parameters :math:`a` and :math:`b`.

Objective Function:

.. math::

    J = \\sum_i \\sum_j \\left[a \\cdot u_{ij}^m + b \\cdot t_{ij}^\\eta\\right] \\|x_i - v_j\\|^2

Centroid Update:

.. math::

    v_j = \\frac{\\sum_i \\left[a \\cdot u_{ij}^m + b \\cdot t_{ij}^\\eta\\right] x_i}{\\sum_i \\left[a \\cdot u_{ij}^m + b \\cdot t_{ij}^\\eta\\right]}

Membership Update (like FCM):

.. math::

    u_{ij} = \\frac{1}{\\sum_k \\left(\\frac{d_{ij}}{d_{ik}}\\right)^{2/(m-1)}}

Typicality Update (like PCM):

.. math::

    t_{ij} = \\frac{1}{1 + \\left(\\frac{b \\cdot \\|x_i - v_j\\|^2}{\\gamma_j}\\right)^{1/(\\eta-1)}}

Gamma:

.. math::

    \\gamma_j = k \\cdot \\frac{\\sum_i u_{ij}^m \\cdot \\|x_i - v_j\\|^2}{\\sum_i u_{ij}^m}

Author: Prosemble Contributors
License: MIT
"""

import jax
import jax.numpy as jnp
import numpy as np
from jax import jit
from functools import partial
from typing import NamedTuple, Self
import chex

from .fcm import FCM
from .base import FuzzyClusteringBase, ScanFitMixin


class PFCMState(NamedTuple):
    """
    Snapshot of the PFCM algorithm at a single iteration.

    As a ``NamedTuple`` the state is immutable: fields cannot be
    reassigned after construction.

    Attributes:
        centroids: Cluster centroids, shape (c, d).
        U: Fuzzy membership matrix, shape (n, c).
        T: Typicality matrix, shape (n, c).
        gamma: Scale parameters, shape (c,).
        objective: Scalar value of the objective function.
        iteration: Index of the current iteration.
        converged: Boolean flag signalling convergence.
    """
    centroids: chex.Array  # (c, d)
    U: chex.Array  # (n, c)
    T: chex.Array  # (n, c)
    gamma: chex.Array  # (c,)
    objective: chex.Array  # scalar
    iteration: int
    converged: bool


class PFCM(ScanFitMixin, FuzzyClusteringBase):
    """
    JAX implementation of Possibilistic Fuzzy C-Means clustering.

    PFCM combines fuzzy membership and typicality for robust clustering.

    Parameters
    ----------
    fuzzifier : float, default=2.0
        Fuzzification parameter for membership (:math:`m`). Must be > 1.
    eta : float, default=2.0
        Fuzzification parameter for typicality (:math:`\\eta`). Must be > 1.
    a : float, default=1.0
        Weight for fuzzy membership term. Must be >= 0.
    b : float, default=1.0
        Weight for typicality term. Must be >= 0.
    k : float, default=1.0
        Parameter for :math:`\\gamma` computation. Must be > 0.
    init_method : str, default='fcm'
        Initialization method: 'fcm' or 'random'.
    n_clusters : int
        Number of clusters (must be >= 2).
    max_iter : int
        Maximum number of iterations.
    epsilon : float
        Convergence threshold.
    random_seed : int
        Random seed for reproducibility.
    distance_fn : callable, optional
        Pairwise distance function. Default: squared Euclidean.
    patience : int, optional
        Epochs with no improvement before early stopping. Default: None.
    restore_best : bool
        If True, restore centroids from the lowest-objective epoch.
        Default: False.
    plot_steps : bool
        Whether to visualize clustering progress. Default: False.
    show_confidence : bool
        Whether to show confidence in visualization. Default: True.
    show_pca_variance : bool
        Whether to show PCA variance in visualization. Default: True.
    save_plot_path : str, optional
        Path to save final plot.
    callbacks : list, optional
        List of Callback objects for monitoring/visualization.

    Attributes
    ----------
    centroids_ : array
        Final cluster centroids
    U_ : array
        Final fuzzy membership matrix
    T_ : array
        Final typicality matrix
    gamma_ : array
        Final scale parameters
    objective_ : float
        Final objective value
    n_iter_ : int
        Number of iterations performed
    """

    _hyperparams = ('fuzzifier', 'a', 'b', 'eta', 'k', 'init_method')
    _fitted_array_names = ('U_', 'T_', 'gamma_')

    def __init__(
        self,
        n_clusters: int,
        fuzzifier: float = 2.0,
        eta: float = 2.0,
        a: float = 1.0,
        b: float = 1.0,
        k: float = 1.0,
        max_iter: int = 100,
        epsilon: float = 1e-5,
        init_method: str = 'fcm',
        random_seed: int = 42,
        distance_fn=None,
        patience: int | None = None,
        restore_best: bool = False,
        plot_steps: bool = False,
        show_confidence: bool = True,
        show_pca_variance: bool = True,
        save_plot_path: str | None = None,
        callbacks=None,
    ):
        """
        Validate the PFCM-specific hyperparameters and store them.

        The shared arguments are forwarded unchanged to the parent
        constructor.

        Raises
        ------
        ValueError
            If ``fuzzifier <= 1``, ``eta <= 1``, ``a < 0``, ``b < 0``,
            ``k <= 0`` or ``init_method`` is not 'fcm' or 'random'.
        """
        # Model-specific validation first
        if fuzzifier <= 1.0:
            raise ValueError(f"fuzzifier must be > 1, got {fuzzifier}")
        if eta <= 1.0:
            raise ValueError(f"eta must be > 1, got {eta}")
        if a < 0 or b < 0:
            raise ValueError(f"a and b must be >= 0, got a={a}, b={b}")
        if k <= 0:
            raise ValueError(f"k must be > 0, got {k}")
        if init_method not in ['fcm', 'random']:
            raise ValueError(f"init_method must be 'fcm' or 'random', got {init_method}")

        super().__init__(
            n_clusters=n_clusters,
            max_iter=max_iter,
            epsilon=epsilon,
            random_seed=random_seed,
            distance_fn=distance_fn,
            patience=patience,
            restore_best=restore_best,
            plot_steps=plot_steps,
            show_confidence=show_confidence,
            show_pca_variance=show_pca_variance,
            save_plot_path=save_plot_path,
            callbacks=callbacks,
        )

        self.fuzzifier = fuzzifier
        self.eta = eta
        self.a = a
        self.b = b
        self.k = k
        self.init_method = init_method

        # Model-specific fitted attributes
        self.U_ = None
        self.T_ = None
        self.gamma_ = None
        self.history_ = None

    @partial(jit, static_argnums=(0,))
    def _initialize_matrices(self, X: chex.Array, key: chex.PRNGKey) -> tuple[chex.Array, chex.Array]:
        """
        Initialize :math:`U` and :math:`T` matrices.

        Both matrices are drawn independently from a Dirichlet distribution
        with all concentration parameters equal to one, one draw per sample,
        so every row sums to one.

        Returns
        -------
        tuple
            ``(U, T)``, each with one row per sample of ``X`` and
            ``n_clusters`` columns.
        """
        n_samples = X.shape[0]

        # Independent sub-keys so that U and T are not identical draws.
        key1, key2 = jax.random.split(key)
        U = jax.random.dirichlet(key1, alpha=jnp.ones(self.n_clusters), shape=(n_samples,))
        T = jax.random.dirichlet(key2, alpha=jnp.ones(self.n_clusters), shape=(n_samples,))

        return U, T

    @partial(jit, static_argnums=(0,))
    def _compute_centroids(self, X: chex.Array, U: chex.Array, T: chex.Array) -> chex.Array:
        """
        Compute centroids using weighted combination of U and T.

        .. math::

            v_j = \\frac{\\sum_i \\left[a \\cdot u_{ij}^m + b \\cdot t_{ij}^\\eta\\right] x_i}{\\sum_i \\left[a \\cdot u_{ij}^m + b \\cdot t_{ij}^\\eta\\right]}
        """
        U_fuzz = jnp.power(U, self.fuzzifier)  # (n, c)
        T_fuzz = jnp.power(T, self.eta)  # (n, c)

        # Weighted combination
        weights = self.a * U_fuzz + self.b * T_fuzz  # (n, c)

        numerator = weights.T @ X  # (c, d)
        denominator = jnp.sum(weights, axis=0, keepdims=True).T  # (c, 1)
        # Guard against division by zero for clusters with vanishing weight.
        denominator = jnp.maximum(denominator, 1e-10)

        centroids = numerator / denominator
        return centroids

    @partial(jit, static_argnums=(0,))
    def _update_U(self, X: chex.Array, centroids: chex.Array) -> chex.Array:
        """
        Update fuzzy membership matrix (same as FCM).

        .. math::

            u_{ij} = \\frac{1}{\\sum_k \\left(\\frac{d_{ij}}{d_{ik}}\\right)^{2/(m-1)}}

        The exponent applied in code is ``1 / (m - 1)`` on the output of
        ``distance_fn``; with the default squared Euclidean distance this
        matches the ``2 / (m - 1)`` exponent on the plain distance above.
        """
        D = self.distance_fn(X, centroids)
        # Clamp so that a sample lying on a centroid does not divide by zero.
        D = jnp.maximum(D, 1e-10)

        power = 1.0 / (self.fuzzifier - 1)
        # ratios[i, j, k] = (D[i, j] / D[i, k]) ** power
        ratios = jnp.power(D[:, :, None] / D[:, None, :], power)
        denominators = jnp.sum(ratios, axis=2)

        U = 1.0 / denominators
        # Renormalise each row so memberships sum to one despite rounding.
        U = U / jnp.sum(U, axis=1, keepdims=True)

        return U

    @partial(jit, static_argnums=(0,))
    def _compute_gamma(self, X: chex.Array, U: chex.Array, centroids: chex.Array) -> chex.Array:
        """
        Compute gamma parameters using fuzzy membership.

        .. math::

            \\gamma_j = k \\cdot \\frac{\\sum_i u_{ij}^m \\cdot \\|x_i - v_j\\|^2}{\\sum_i u_{ij}^m}
        """
        D_sq = self.distance_fn(X, centroids)  # (n, c)
        U_fuzz = jnp.power(U, self.fuzzifier)  # (n, c)

        numerator = jnp.sum(U_fuzz * D_sq, axis=0)  # (c,)
        denominator = jnp.sum(U_fuzz, axis=0)  # (c,)
        denominator = jnp.maximum(denominator, 1e-10)

        gamma = self.k * numerator / denominator
        # Keep gamma strictly positive; it is used as a divisor in _update_T.
        gamma = jnp.maximum(gamma, 1e-10)

        return gamma

    @partial(jit, static_argnums=(0,))
    def _update_T(self, X: chex.Array, centroids: chex.Array, gamma: chex.Array) -> chex.Array:
        """
        Update typicality matrix.

        .. math::

            t_{ij} = \\frac{1}{1 + \\left(\\frac{b \\cdot \\|x_i - v_j\\|^2}{\\gamma_j}\\right)^{1/(\\eta-1)}}
        """
        D_sq = self.distance_fn(X, centroids)  # (n, c)

        power = 1.0 / (self.eta - 1)
        ratio = self.b * D_sq / gamma[None, :]  # (n, c)
        # Clip negatives before the fractional power is applied.
        ratio = jnp.maximum(ratio, 0)

        T = 1.0 / (1.0 + jnp.power(ratio, power))

        return T

    @partial(jit, static_argnums=(0,))
    def _compute_objective(
        self,
        X: chex.Array,
        centroids: chex.Array,
        U: chex.Array,
        T: chex.Array
    ) -> chex.Array:
        """
        Compute PFCM objective function.

        .. math::

            J = \\sum_i \\sum_j \\left[a \\cdot u_{ij}^m + b \\cdot t_{ij}^\\eta\\right] \\|x_i - v_j\\|^2
        """
        D_sq = self.distance_fn(X, centroids)
        U_fuzz = jnp.power(U, self.fuzzifier)
        T_fuzz = jnp.power(T, self.eta)

        weights = self.a * U_fuzz + self.b * T_fuzz
        objective = jnp.sum(weights * D_sq)

        return objective

    @partial(jit, static_argnums=(0,))
    def _check_convergence(
        self,
        centroids_old: chex.Array,
        centroids_new: chex.Array
    ) -> chex.Array:
        """
        Check convergence based on centroid change.

        Returns a boolean array that is true when the Frobenius norm of the
        centroid difference is below ``epsilon``.
        """
        diff = jnp.linalg.norm(centroids_new - centroids_old, ord='fro')
        return diff < self.epsilon

    @partial(jit, static_argnums=(0,))
    def _iteration_step(self, state: PFCMState, X: chex.Array) -> tuple[PFCMState, dict]:
        """
        Single iteration of PFCM.

        ``U``, ``gamma`` and ``T`` are all computed from the centroids of
        the incoming state; the centroids are then recomputed from the new
        ``U`` and ``T``, and the objective is evaluated at the new centroids.

        Returns
        -------
        tuple
            The new state, and a metrics dict with keys ``'objective'`` and
            ``'converged'``.
        """
        # Update U (membership)
        U_new = self._update_U(X, state.centroids)

        # Compute gamma
        gamma_new = self._compute_gamma(X, U_new, state.centroids)

        # Update T (typicality)
        T_new = self._update_T(X, state.centroids, gamma_new)

        # Compute new centroids
        centroids_new = self._compute_centroids(X, U_new, T_new)

        # Compute objective
        obj_new = self._compute_objective(X, centroids_new, U_new, T_new)

        # Check convergence
        converged = self._check_convergence(state.centroids, centroids_new)

        new_state = PFCMState(
            centroids=centroids_new,
            U=U_new,
            T=T_new,
            gamma=gamma_new,
            objective=obj_new,
            iteration=state.iteration + 1,
            converged=converged
        )

        metrics = {
            'objective': obj_new,
            'converged': converged
        }

        return new_state, metrics

    def _build_info(self, state, iteration):
        """
        Build a summary dict for the given state.

        Labels are the arg-max of ``U`` per sample; the per-sample weight is
        the mean of the row maxima of ``U`` and ``T``.
        """
        labels = jnp.argmax(state.U, axis=1)
        weights = (jnp.max(state.U, axis=1) + jnp.max(state.T, axis=1)) / 2.0
        return {
            'centroids': state.centroids,
            'labels': labels,
            'weights': weights,
            'iteration': iteration,
            'objective': float(state.objective),
            'max_iter': self.max_iter,
        }

    def _initialize_from_fcm(self, X: chex.Array) -> tuple[chex.Array, chex.Array, chex.Array]:
        """
        Initialize using FCM results.

        A helper FCM model is fitted for at most 50 iterations (fewer if
        ``max_iter`` is smaller) with this model's fuzzifier, epsilon and
        distance function.

        Returns
        -------
        tuple
            ``(centroids, U, T)`` where ``T`` starts as a copy of the FCM
            membership matrix.
        """
        # Seed the helper with this model's seed rather than a fixed value, so
        # that changing ``random_seed`` actually changes the initialisation.
        # NOTE(review): assumes the base class stores ``random_seed`` on the
        # instance; if it does not, fall back to the previous constant 42.
        seed = getattr(self, 'random_seed', 42)

        fcm = FCM(
            n_clusters=self.n_clusters,
            fuzzifier=self.fuzzifier,
            max_iter=min(50, self.max_iter),
            epsilon=self.epsilon,
            random_seed=seed,
            distance_fn=self.distance_fn,
        )
        fcm.fit(X)

        centroids_init = fcm.centroids_
        U_init = fcm.U_

        # Initialize T similar to U
        T_init = U_init.copy()

        return centroids_init, U_init, T_init

    def fit(self, X: jnp.ndarray, initial_centroids=None, resume=False) -> Self:
        """Fit PFCM model to data.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Training data
        initial_centroids : array-like, shape (n_clusters, n_features), optional
            Pre-computed centroids for warm starting
        resume : bool, default=False
            If True, resume from the model's current fitted state

        Returns
        -------
        Self
            The fitted model (``self``).

        Raises
        ------
        ValueError
            If both ``resume=True`` and ``initial_centroids`` are given, or
            if ``X`` contains NaN or Inf values.
        """
        X = self._validate_input(X)

        if resume and initial_centroids is not None:
            raise ValueError("Cannot use both resume=True and initial_centroids")

        if not jnp.all(jnp.isfinite(X)):
            raise ValueError("X contains NaN or Inf values")

        # Initialize
        if resume:
            # Continue from the arrays stored by a previous fit.
            self._check_fitted()
            centroids_init = self.centroids_
            U_init = self.U_
            T_init = self.T_
            gamma_init = self.gamma_
        elif initial_centroids is not None:
            # Warm start: derive U, gamma and T from the supplied centroids.
            centroids_init = self._validate_initial_centroids(X, initial_centroids)
            U_init = self._update_U(X, centroids_init)
            gamma_init = self._compute_gamma(X, U_init, centroids_init)
            T_init = self._update_T(X, centroids_init, gamma_init)
        else:
            if self.init_method == 'fcm':
                centroids_init, U_init, T_init = self._initialize_from_fcm(X)
            else:
                # Advance the model's PRNG key so repeated fits differ.
                self.key, subkey = jax.random.split(self.key)
                U_init, T_init = self._initialize_matrices(X, subkey)
                centroids_init = self._compute_centroids(X, U_init, T_init)
            gamma_init = self._compute_gamma(X, U_init, centroids_init)

        obj_init = self._compute_objective(X, centroids_init, U_init, T_init)

        initial_state = PFCMState(
            centroids=centroids_init,
            U=U_init,
            T=T_init,
            gamma=gamma_init,
            objective=obj_init,
            iteration=0,
            converged=False
        )

        final_state, self.history_ = self._run_training(X, initial_state)

        # Store results
        self.centroids_ = final_state.centroids
        self.U_ = final_state.U
        self.T_ = final_state.T
        self.gamma_ = final_state.gamma
        self.objective_ = final_state.objective
        self.n_iter_ = int(final_state.iteration)

        return self

    def predict(self, X: jnp.ndarray) -> jnp.ndarray:
        """Predict cluster labels.

        Each sample gets the index of the cluster with the highest fuzzy
        membership with respect to the fitted centroids.
        """
        self._check_fitted()
        X = jnp.asarray(X)
        U = self._update_U(X, self.centroids_)
        labels = jnp.argmax(U, axis=1)
        return labels

    def predict_proba(self, X: jnp.ndarray) -> jnp.ndarray:
        """Predict fuzzy membership.

        Returns the membership matrix of ``X`` with respect to the fitted
        centroids.
        """
        self._check_fitted()
        X = jnp.asarray(X)
        return self._update_U(X, self.centroids_)

    def predict_typicality(self, X: jnp.ndarray) -> jnp.ndarray:
        """Predict typicality values.

        Returns the typicality matrix of ``X`` computed from the fitted
        centroids and the fitted ``gamma_``.
        """
        self._check_fitted()
        self._check_fitted('gamma_')
        X = jnp.asarray(X)
        return self._update_T(X, self.centroids_, self.gamma_)

    def get_objective_history(self) -> jnp.ndarray:
        """Return objective function history.

        Raises
        ------
        RuntimeError
            If no training history has been recorded yet.
        """
        if self.history_ is None:
            raise RuntimeError("Model not fitted yet. Call fit() first.")
        return self.history_['objective']