# Source code for prosemble.models.kipcm

"""
JAX-based Kernel Improved Possibilistic C-Means (KIPCM) clustering implementation.
"""

from typing import NamedTuple, Self
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
import chex
from jax import jit
from prosemble.core.kernel import batch_gaussian_kernel
from prosemble.models.kfcm import KFCM
from prosemble.models.base import FuzzyClusteringBase


class KIPCMState(NamedTuple):
    """Snapshot of the KIPCM optimisation after a given iteration.

    Instances are created in ``KIPCM._run_phase`` (initial state) and
    ``KIPCM._iteration_step`` (one per update), and are threaded through the
    jitted step function.
    """

    # Cluster prototypes; produced by ``weights.T @ X`` in
    # ``_compute_centroids``, i.e. one row per cluster.
    centroids: chex.Array
    # Fuzzy membership matrix; each row is normalised to sum to 1 in
    # ``_update_U``.
    U: chex.Array
    # Typicality matrix; initialised as zeros of shape
    # (n_samples, n_clusters) and refreshed by ``_update_T``.
    T: chex.Array
    # Per-cluster scale parameter; held fixed within a phase (the step
    # function copies it unchanged).
    gamma: chex.Array
    # Scalar objective value evaluated for this state.
    objective: chex.Array
    # Number of update steps taken in the current phase (starts at 0).
    # NOTE(review): after passing through the jitted step this is likely a
    # JAX scalar rather than a Python int — confirm.
    iteration: int
    # True once the Frobenius norm of the centroid change is <= epsilon.
    # NOTE(review): likewise probably a JAX boolean scalar after jit.
    converged: bool
    # Phase tag passed to ``_run_phase`` (0 or 1 in the visible callers).
    phase: int


class KIPCM(FuzzyClusteringBase):
    """Kernel Improved Possibilistic C-Means with JAX.

    KIPCM uses two-phase approach in kernel space with product-based centroids.

    Parameters
    ----------
    fuzzifier : float, default=2.0
        Fuzziness parameter for :math:`U` matrix (must be > 1.0).
    tipifier : float, default=2.0
        Possibilistic parameter for :math:`T` matrix (must be > 1.0).
    k : float, default=1.0
        Scaling parameter for :math:`\\gamma` in phase 1 (must be > 0).
    sigma : float, default=1.0
        Kernel bandwidth parameter (must be > 0).
    init_method : {'kfcm'}, default='kfcm'
        Initialization method.
    n_clusters : int
        Number of clusters (must be >= 2).
    max_iter : int
        Maximum number of iterations.
    epsilon : float
        Convergence threshold.
    random_seed : int
        Random seed for reproducibility.
    distance_fn : callable, optional
        Pairwise distance function. Default: squared Euclidean.
    patience : int, optional
        Epochs with no improvement before early stopping. Default: None.
    restore_best : bool
        If True, restore centroids from the lowest-objective epoch.
        Default: False.
    plot_steps : bool
        Whether to visualize clustering progress. Default: False.
    show_confidence : bool
        Whether to show confidence in visualization. Default: True.
    show_pca_variance : bool
        Whether to show PCA variance in visualization. Default: True.
    save_plot_path : str, optional
        Path to save final plot.
    callbacks : list, optional
        List of Callback objects for monitoring/visualization.
    """

    _hyperparams = ('fuzzifier', 'tipifier', 'k', 'sigma', 'init_method')
    _fitted_array_names = ('U_', 'T_', 'gamma_')

    def __init__(
        self,
        n_clusters: int,
        fuzzifier: float = 2.0,
        tipifier: float = 2.0,
        k: float = 1.0,
        sigma: float = 1.0,
        max_iter: int = 100,
        epsilon: float = 1e-5,
        init_method: str = 'kfcm',
        random_seed: int = 42,
        distance_fn=None,
        patience: int | None = None,
        restore_best: bool = False,
        plot_steps: bool = False,
        show_confidence: bool = True,
        show_pca_variance: bool = True,
        save_plot_path: str | None = None,
        callbacks=None,
    ):
        """Validate the KIPCM-specific hyperparameters and store them.

        Raises
        ------
        ValueError
            If ``fuzzifier <= 1.0``, ``tipifier <= 1.0``, ``k <= 0`` or
            ``sigma <= 0``.
        """
        # Validate model-specific parameters
        if fuzzifier <= 1.0:
            raise ValueError("fuzzifier must be > 1.0")
        if tipifier <= 1.0:
            raise ValueError("tipifier must be > 1.0")
        if k <= 0:
            raise ValueError("k must be > 0")
        if sigma <= 0:
            raise ValueError("sigma must be > 0")

        super().__init__(
            n_clusters=n_clusters,
            max_iter=max_iter,
            epsilon=epsilon,
            random_seed=random_seed,
            distance_fn=distance_fn,
            patience=patience,
            restore_best=restore_best,
            plot_steps=plot_steps,
            show_confidence=show_confidence,
            show_pca_variance=show_pca_variance,
            save_plot_path=save_plot_path,
            callbacks=callbacks,
        )

        self.fuzzifier = fuzzifier
        self.tipifier = tipifier
        self.k = k
        self.sigma = sigma
        # Stored as a hyperparameter; no method in this class reads it
        # (initialisation always goes through KFCM, see _initialize_phase0).
        self.init_method = init_method

        # Fitted attributes
        self.U_ = None
        self.T_ = None
        self.gamma_ = None

    def _initialize_phase0(self, X: chex.Array):
        """Seed phase 0 by fitting a KFCM model on ``X``.

        Returns
        -------
        tuple
            ``(U, T, centroids)`` where ``U`` and ``centroids`` come from the
            fitted KFCM model and ``T`` is an all-zero
            ``(n_samples, n_clusters)`` array.
        """
        n_samples = X.shape[0]
        # NOTE(review): positional arguments rely on KFCM's constructor
        # order, which is not visible from this file — confirm.
        kfcm = KFCM(
            self.n_clusters,
            self.fuzzifier,
            self.sigma,
            self.max_iter,
            self.epsilon,
            'random',
            self.random_seed,
            False,
        )
        kfcm.fit(X)
        U = kfcm.U_
        centroids = kfcm.centroids_
        T = jnp.zeros((n_samples, self.n_clusters))
        return U, T, centroids

    @partial(jit, static_argnums=(0,))
    def _compute_gamma_phase0(
        self, X: chex.Array, U: chex.Array, centroids: chex.Array
    ) -> chex.Array:
        """Return the per-cluster gamma used in phase 0.

        Gamma is the ``U**fuzzifier``-weighted mean of the kernel distance
        ``2 * (1 - K)`` over the sample axis.
        """
        K = batch_gaussian_kernel(X, centroids, self.sigma)
        kernel_dist = 2.0 * (1.0 - K)
        U_fuzz = jnp.power(U, self.fuzzifier)
        numerator = jnp.sum(U_fuzz * kernel_dist, axis=0)
        # Floor the weight sum so an empty cluster does not divide by zero.
        denominator = jnp.maximum(jnp.sum(U_fuzz, axis=0), 1e-10)
        return numerator / denominator

    @partial(jit, static_argnums=(0,))
    def _compute_gamma_phase1(
        self,
        X: chex.Array,
        U: chex.Array,
        T: chex.Array,
        centroids: chex.Array,
    ) -> chex.Array:
        """Return the per-cluster gamma used in phase 1.

        Same weighted mean as phase 0, but weighted by
        ``U**fuzzifier * T**tipifier`` and scaled by ``self.k``.
        """
        K = batch_gaussian_kernel(X, centroids, self.sigma)
        kernel_dist = 2.0 * (1.0 - K)
        U_fuzz = jnp.power(U, self.fuzzifier)
        T_fuzz = jnp.power(T, self.tipifier)
        prod = U_fuzz * T_fuzz
        numerator = jnp.sum(prod * kernel_dist, axis=0)
        denominator = jnp.maximum(jnp.sum(prod, axis=0), 1e-10)
        return self.k * numerator / denominator

    @partial(jit, static_argnums=(0,))
    def _update_T(
        self, X: chex.Array, centroids: chex.Array, gamma: chex.Array
    ) -> chex.Array:
        """Return the typicality matrix for ``X`` given centroids and gamma.

        Computes ``1 / (1 + (d / gamma) ** (1 / (tipifier - 1)))`` with
        ``d = 2 * (1 - K)``.
        """
        K = batch_gaussian_kernel(X, centroids, self.sigma)
        kernel_dist = jnp.maximum(2.0 * (1.0 - K), 1e-10)
        # A gamma of exactly zero (e.g. every sample coincides with its
        # centroid) would make the ratio infinite and T zero everywhere,
        # which then yields 0/0 = NaN in _update_U.  Use the same floor as
        # the distance guard.
        gamma_safe = jnp.maximum(gamma, 1e-10)
        power = 1.0 / (self.tipifier - 1.0)
        ratio = kernel_dist / gamma_safe[None, :]
        return 1.0 / (1.0 + jnp.power(ratio, power))

    @partial(jit, static_argnums=(0,))
    def _update_U(
        self, X: chex.Array, T: chex.Array, centroids: chex.Array
    ) -> chex.Array:
        """Return the fuzzy membership matrix for ``X``.

        Each entry is proportional to
        ``(T**(tipifier - 1) / d) ** (1 / (fuzzifier - 1))`` and every row is
        normalised to sum to one.
        """
        K = batch_gaussian_kernel(X, centroids, self.sigma)
        kernel_dist = jnp.maximum(2.0 * (1.0 - K), 1e-10)
        T_pow = jnp.power(T, self.tipifier - 1.0)
        base_values = (1.0 / kernel_dist) * T_pow
        power = 1.0 / (self.fuzzifier - 1.0)
        powered_values = jnp.power(base_values, power)
        denominators = jnp.sum(powered_values, axis=1, keepdims=True)
        return powered_values / denominators

    @partial(jit, static_argnums=(0,))
    def _compute_centroids(
        self,
        X: chex.Array,
        U: chex.Array,
        T: chex.Array,
        centroids: chex.Array,
    ) -> chex.Array:
        """Return updated centroids as a weighted mean of the samples.

        Weights are ``U**fuzzifier * T**tipifier * K`` where ``K`` is the
        kernel evaluated against the *current* centroids.
        """
        K = batch_gaussian_kernel(X, centroids, self.sigma)
        U_fuzz = jnp.power(U, self.fuzzifier)
        T_fuzz = jnp.power(T, self.tipifier)
        weights = (U_fuzz * T_fuzz) * K
        numerator = weights.T @ X
        denominator = jnp.sum(weights, axis=0, keepdims=True).T
        return numerator / jnp.maximum(denominator, 1e-10)

    @partial(jit, static_argnums=(0,))
    def _compute_objective(
        self,
        X: chex.Array,
        U: chex.Array,
        T: chex.Array,
        centroids: chex.Array,
        gamma: chex.Array,
    ) -> chex.Array:
        """Return the scalar KIPCM objective.

        The value is
        ``sum(U**m * T**p * d) + sum_j gamma_j * sum_i (1 - T_ij)**p * U_ij**m``
        with ``m = fuzzifier``, ``p = tipifier`` and ``d = 2 * (1 - K)``.
        """
        K = batch_gaussian_kernel(X, centroids, self.sigma)
        kernel_dist = 2.0 * (1.0 - K)
        U_fuzz = jnp.power(U, self.fuzzifier)
        T_fuzz = jnp.power(T, self.tipifier)
        term1 = jnp.sum(U_fuzz * T_fuzz * kernel_dist)

        one_minus_T_fuzz = jnp.power(1.0 - T, self.tipifier)
        inner_sum = jnp.sum(one_minus_T_fuzz * U_fuzz, axis=0)
        term2 = jnp.sum(gamma * inner_sum)
        return term1 + term2

    @partial(jit, static_argnums=(0,))
    def _iteration_step(self, state: KIPCMState, X: chex.Array) -> KIPCMState:
        """Perform one T -> U -> centroid update and return the new state.

        Gamma and phase are carried over unchanged; convergence is flagged
        when the Frobenius norm of the centroid change is <= ``epsilon``.
        """
        T_new = self._update_T(X, state.centroids, state.gamma)
        U_new = self._update_U(X, T_new, state.centroids)
        centroids_new = self._compute_centroids(
            X, U_new, T_new, state.centroids
        )
        objective = self._compute_objective(
            X, U_new, T_new, centroids_new, state.gamma
        )
        centroid_change = jnp.linalg.norm(
            centroids_new - state.centroids, ord='fro'
        )
        converged = centroid_change <= self.epsilon
        return KIPCMState(
            centroids_new,
            U_new,
            T_new,
            state.gamma,
            objective,
            state.iteration + 1,
            converged,
            state.phase,
        )

    def _build_info(self, state, iteration):
        """Build the info dict handed to the iteration/fit-end notifications.

        Labels are the argmax of ``U``; weights are the row-wise max of
        ``U * T``.
        """
        labels = jnp.argmax(state.U, axis=1)
        weights = jnp.max(state.U * state.T, axis=1)
        return {
            'centroids': state.centroids,
            'labels': labels,
            'weights': weights,
            'iteration': iteration,
            'objective': float(state.objective),
            'max_iter': self.max_iter,
        }

    def _run_phase(
        self,
        X: chex.Array,
        U_init: chex.Array,
        T_init: chex.Array,
        centroids_init: chex.Array,
        gamma_init: chex.Array,
        phase: int,
    ):
        """Iterate one phase until convergence or ``max_iter`` steps.

        Returns
        -------
        tuple
            ``(final_state, states_history)`` where the history list starts
            with the initial state and has one entry per completed step.
        """
        initial_objective = self._compute_objective(
            X, U_init, T_init, centroids_init, gamma_init
        )
        state = KIPCMState(
            centroids_init,
            U_init,
            T_init,
            gamma_init,
            initial_objective,
            0,
            False,
            phase,
        )
        states_history = [state]

        for _ in range(self.max_iter):
            # Notify with the state *before* the step is taken.
            self._notify_iteration(self._build_info(state, state.iteration))
            state = self._iteration_step(state, X)
            states_history.append(state)
            if state.converged:
                break

        return state, states_history

    def fit(self, X: chex.Array, initial_centroids=None, resume=False) -> Self:
        """Fit the model to ``X`` and return ``self``.

        Parameters
        ----------
        X : array
            Training data; passed through ``self._validate_input``.
        initial_centroids : array, optional
            Starting centroids. When given, KFCM initialisation is skipped
            and ``U``/``T`` are derived from these centroids using a uniform
            membership to estimate the phase-0 gamma.
        resume : bool, default=False
            If True, skip phase 0 and rerun phase 1 from the currently
            fitted ``U_``, ``T_`` and ``centroids_``.

        Raises
        ------
        ValueError
            If both ``resume=True`` and ``initial_centroids`` are supplied.
        """
        if resume and initial_centroids is not None:
            raise ValueError("Cannot use both resume=True and initial_centroids")

        X = self._validate_input(X)
        self._notify_fit_start(X)

        if resume:
            self._check_fitted()
            U_start, T_start, centroids_start = (
                self.U_, self.T_, self.centroids_
            )
            history_phase0 = []
        else:
            if initial_centroids is not None:
                centroids_init = self._validate_initial_centroids(
                    X, initial_centroids
                )
                U_uniform = (
                    jnp.ones((X.shape[0], self.n_clusters)) / self.n_clusters
                )
                gamma_0 = self._compute_gamma_phase0(
                    X, U_uniform, centroids_init
                )
                T_init = self._update_T(X, centroids_init, gamma_0)
                U_init = self._update_U(X, T_init, centroids_init)
            else:
                U_init, T_init, centroids_init = self._initialize_phase0(X)
                gamma_0 = self._compute_gamma_phase0(X, U_init, centroids_init)

            state_phase0, history_phase0 = self._run_phase(
                X, U_init, T_init, centroids_init, gamma_0, phase=0
            )
            U_start, T_start, centroids_start = (
                state_phase0.U, state_phase0.T, state_phase0.centroids
            )

        # Phase 1 re-estimates gamma from the product U**m * T**p and then
        # iterates again from the phase-0 (or previously fitted) solution.
        gamma_1 = self._compute_gamma_phase1(
            X, U_start, T_start, centroids_start
        )
        state_final, history_phase1 = self._run_phase(
            X, U_start, T_start, centroids_start, gamma_1, phase=1
        )
        self._notify_fit_end(
            self._build_info(state_final, state_final.iteration)
        )

        all_objectives = [
            s.objective for s in history_phase0 + history_phase1
        ]
        self.objective_history_ = jnp.array(all_objectives)

        self.centroids_ = state_final.centroids
        self.U_ = state_final.U
        self.T_ = state_final.T
        self.gamma_ = state_final.gamma
        # Only the iteration count of the final (phase 1) run is recorded.
        self.n_iter_ = int(state_final.iteration)
        self.objective_ = float(state_final.objective)
        return self

    def predict(self, X: chex.Array) -> chex.Array:
        """Return the hard cluster label (argmax of ``U``) for each sample."""
        self._check_fitted()
        # Convert like predict_proba/get_typicality do, so list or NumPy
        # input is handled the same way by all three prediction methods.
        X = jnp.asarray(X)
        T = self._update_T(X, self.centroids_, self.gamma_)
        U = self._update_U(X, T, self.centroids_)
        return jnp.argmax(U, axis=1)

    def predict_proba(self, X: chex.Array) -> chex.Array:
        """Return the fuzzy membership matrix ``U`` for ``X``."""
        self._check_fitted()
        X = jnp.asarray(X)
        T = self._update_T(X, self.centroids_, self.gamma_)
        return self._update_U(X, T, self.centroids_)

    def get_typicality(self, X: chex.Array) -> chex.Array:
        """Return the typicality matrix ``T`` for ``X``."""
        self._check_fitted()
        X = jnp.asarray(X)
        return self._update_T(X, self.centroids_, self.gamma_)