"""
JAX-based Adaptive Fuzzy C-Means (AFCM) clustering implementation.
This module provides a GPU-accelerated implementation of AFCM using JAX
with JIT compilation for high performance.
"""
from typing import NamedTuple, Self
from functools import partial
import jax
import jax.numpy as jnp
import chex
from jax import jit, lax
from prosemble.models.base import FuzzyClusteringBase, ScanFitMixin
from prosemble.models.fcm import FCM
class AFCMState(NamedTuple):
"""Immutable state for AFCM iteration.
Attributes:
centroids: Cluster centroids, shape (n_clusters, n_features)
U: Fuzzy membership matrix, shape (n_samples, n_clusters)
T: Typicality matrix, shape (n_samples, n_clusters)
gamma: Scale parameters, shape (n_clusters,)
objective: Current objective function value
iteration: Current iteration number
converged: Whether algorithm has converged
"""
centroids: chex.Array
U: chex.Array
T: chex.Array
gamma: chex.Array
objective: chex.Array
iteration: int
converged: bool
[docs]
class AFCM(ScanFitMixin, FuzzyClusteringBase):
"""
Adaptive Fuzzy C-Means clustering with JAX.
AFCM is an adaptive variant that combines fuzzy and possibilistic approaches
with specific parameter combinations.
Key features:
- Centroids use :math:`a \\cdot U^m + b \\cdot T` (:math:`T` to power 1, not :math:`m`!)
- :math:`\\gamma` computed with Euclidean distance (not squared)
- Exponential :math:`T` update with parameter :math:`b`
- Standard FCM :math:`U` update
Algorithm:
1. Initialize :math:`U` using FCM
2. Compute :math:`\\gamma` parameters using Euclidean distance
3. Update :math:`T` using exponential update
4. Update :math:`U` using standard FCM rule
5. Update centroids using combined fuzzy-possibilistic weights
6. Repeat until convergence
Objective function:
.. math::
J = \\sum_i \\sum_j \\left[d_{ij}^2 \\cdot (a \\cdot u_{ij}^m + b \\cdot t_{ij})\\right] + \\sum_j \\left[\\gamma_j \\cdot \\sum_i (t_{ij} \\log t_{ij} - t_{ij})\\right]
Parameters
----------
fuzzifier : float, default=2.0
Fuzziness parameter (must be > 1.0).
a : float, default=1.0
Weight for fuzzy membership term (must be > 0).
b : float, default=1.0
Weight for typicality term (must be > 0).
k : float, default=1.0
Scaling parameter for :math:`\\gamma` (must be > 0).
init_method : {'fcm'}, default='fcm'
Initialization method.
n_clusters : int
Number of clusters (must be >= 2).
max_iter : int
Maximum number of iterations.
epsilon : float
Convergence threshold.
random_seed : int
Random seed for reproducibility.
distance_fn : callable, optional
Pairwise distance function. Default: squared Euclidean.
patience : int, optional
Epochs with no improvement before early stopping. Default: None.
restore_best : bool
If True, restore centroids from the lowest-objective epoch.
Default: False.
plot_steps : bool
Whether to visualize clustering progress. Default: False.
show_confidence : bool
Whether to show confidence in visualization. Default: True.
show_pca_variance : bool
Whether to show PCA variance in visualization. Default: True.
save_plot_path : str, optional
Path to save final plot.
callbacks : list, optional
List of Callback objects for monitoring/visualization.
Attributes
----------
centroids_ : array, shape (n_clusters, n_features)
Final cluster centroids
U_ : array, shape (n_samples, n_clusters)
Final fuzzy membership matrix
T_ : array, shape (n_samples, n_clusters)
Final typicality matrix
gamma_ : array, shape (n_clusters,)
Final scale parameters
n_iter_ : int
Number of iterations until convergence
objective_ : float
Final objective function value
objective_history_ : array
Objective values at each iteration
Examples
--------
>>> import jax.numpy as jnp
>>> from prosemble.models import AFCM
>>> X = jnp.array([[1, 2], [1.5, 1.8], [5, 8], [8, 8]])
>>> model = AFCM(n_clusters=2, fuzzifier=2.0, a=1.0, b=1.0, k=1.0, random_seed=42)
>>> model.fit(X)
>>> labels = model.predict(X)
"""
_hyperparams = ('fuzzifier', 'a', 'b', 'k', 'init_method')
_fitted_array_names = ('U_', 'T_', 'gamma_')
def __init__(
self,
n_clusters: int,
fuzzifier: float = 2.0,
a: float = 1.0,
b: float = 1.0,
k: float = 1.0,
max_iter: int = 100,
epsilon: float = 1e-5,
init_method: str = 'fcm',
random_seed: int = 42,
distance_fn=None,
patience: int | None = None,
restore_best: bool = False,
plot_steps: bool = False,
show_confidence: bool = True,
show_pca_variance: bool = True,
save_plot_path: str | None = None,
callbacks=None,
):
# Validate model-specific parameters
if fuzzifier <= 1.0:
raise ValueError("fuzzifier must be > 1.0")
if a <= 0:
raise ValueError("a must be > 0")
if b <= 0:
raise ValueError("b must be > 0")
if k <= 0:
raise ValueError("k must be > 0")
if init_method != 'fcm':
raise ValueError("init_method must be 'fcm' for AFCM")
super().__init__(
n_clusters=n_clusters,
max_iter=max_iter,
epsilon=epsilon,
random_seed=random_seed,
distance_fn=distance_fn,
patience=patience,
restore_best=restore_best,
plot_steps=plot_steps,
show_confidence=show_confidence,
show_pca_variance=show_pca_variance,
save_plot_path=save_plot_path,
callbacks=callbacks,
)
self.fuzzifier = fuzzifier
self.a = a
self.b = b
self.k = k
self.init_method = init_method
# Model-specific fitted attributes
self.U_ = None
self.T_ = None
self.gamma_ = None
def _initialize(self, X: chex.Array):
"""Initialize using FCM."""
n_samples = X.shape[0]
# Initialize using FCM
fcm = FCM(
n_clusters=self.n_clusters,
fuzzifier=self.fuzzifier,
max_iter=self.max_iter,
epsilon=self.epsilon,
random_seed=self.random_seed,
distance_fn=self.distance_fn,
plot_steps=False
)
fcm.fit(X)
U = fcm.U_
centroids = fcm.centroids_
# Initialize T as zeros
T = jnp.zeros((n_samples, self.n_clusters))
return U, T, centroids
@partial(jit, static_argnums=(0,))
def _compute_gamma(
self, X: chex.Array, U: chex.Array, centroids: chex.Array
) -> chex.Array:
"""Compute :math:`\\gamma` using Euclidean distance (not squared!).
.. math::
\\gamma_j = k \\cdot \\frac{\\sum_i u_{ij}^m \\cdot d_{ij}}{\\sum_i u_{ij}^m}
"""
D_sq = self.distance_fn(X, centroids)
D = jnp.sqrt(jnp.maximum(D_sq, 1e-10)) # Euclidean distance
U_fuzz = jnp.power(U, self.fuzzifier)
numerator = jnp.sum(U_fuzz * D, axis=0)
denominator = jnp.sum(U_fuzz, axis=0)
gamma = self.k * numerator / denominator
return gamma
@partial(jit, static_argnums=(0,))
def _update_T(
self, X: chex.Array, centroids: chex.Array, gamma: chex.Array
) -> chex.Array:
"""Update typicality matrix with exponential and parameter :math:`b`.
.. math::
t_{ij} = \\exp\\left(-\\frac{b \\cdot d_{ij}^2}{\\gamma_j}\\right)
"""
D_sq = self.distance_fn(X, centroids)
D_sq = jnp.maximum(D_sq, 1e-10)
# Exponential update with b parameter
ratio = self.b * D_sq / gamma[None, :]
T = jnp.exp(-ratio)
return T
@partial(jit, static_argnums=(0,))
def _update_U(
self, X: chex.Array, centroids: chex.Array
) -> chex.Array:
"""Update fuzzy membership matrix (standard FCM)."""
# Compute distances
D = self.distance_fn(X, centroids)
D = jnp.maximum(D, 1e-10)
# Compute power for FCM update
power = 1.0 / (self.fuzzifier - 1.0)
# Compute distance ratios
def compute_membership_row(distances_i):
ratios = distances_i[:, None] / distances_i[None, :]
powered_ratios = jnp.power(ratios, power)
denominators = jnp.sum(powered_ratios, axis=1)
memberships = 1.0 / denominators
return memberships
U = jax.vmap(compute_membership_row)(D)
# Normalize
U = U / jnp.sum(U, axis=1, keepdims=True)
return U
@partial(jit, static_argnums=(0,))
def _compute_centroids(
self, X: chex.Array, U: chex.Array, T: chex.Array
) -> chex.Array:
"""Compute cluster centroids.
.. math::
v_j = \\frac{\\sum_i \\left[a \\cdot u_{ij}^m + b \\cdot t_{ij}\\right] x_i}{\\sum_i \\left[a \\cdot u_{ij}^m + b \\cdot t_{ij}\\right]}
Note: :math:`T` is NOT raised to a power!
"""
U_fuzz = jnp.power(U, self.fuzzifier)
# Combined weights: a*U^m + b*T
weights = self.a * U_fuzz + self.b * T
# Compute centroids
numerator = weights.T @ X
denominator = jnp.sum(weights, axis=0, keepdims=True).T
denominator = jnp.maximum(denominator, 1e-10)
centroids = numerator / denominator
return centroids
@partial(jit, static_argnums=(0,))
def _compute_objective(
self, X: chex.Array, U: chex.Array, T: chex.Array,
centroids: chex.Array, gamma: chex.Array
) -> chex.Array:
"""Compute AFCM objective function.
.. math::
J = \\sum_i \\sum_j \\left[d_{ij}^2 \\cdot (a \\cdot u_{ij}^m + b \\cdot t_{ij})\\right] + \\sum_j \\left[\\gamma_j \\cdot \\sum_i (t_{ij} \\log t_{ij} - t_{ij})\\right]
"""
D_sq = self.distance_fn(X, centroids)
U_fuzz = jnp.power(U, self.fuzzifier)
# First term: sum_i sum_j [d^2_ij * (a*u_ij^m + b*t_ij)]
weights = self.a * U_fuzz + self.b * T
term1 = jnp.sum(D_sq * weights)
# Second term: sum_j[gamma_j * sum_i(t*log(t) - t)]
T_safe = jnp.maximum(T, 1e-10)
entropy_like = T * jnp.log(T_safe) - T
inner_sum = jnp.sum(entropy_like, axis=0)
term2 = jnp.sum(gamma * inner_sum)
objective = term1 + term2
return objective
@partial(jit, static_argnums=(0,))
def _iteration_step(
self, state: AFCMState, X: chex.Array
) -> tuple[AFCMState, dict]:
"""Single AFCM iteration step."""
# Update T
T_new = self._update_T(X, state.centroids, state.gamma)
# Update U
U_new = self._update_U(X, state.centroids)
# Update centroids
centroids_new = self._compute_centroids(X, U_new, T_new)
# Recompute gamma with new U and centroids
gamma_new = self._compute_gamma(X, U_new, centroids_new)
# Compute objective
objective = self._compute_objective(X, U_new, T_new, centroids_new, gamma_new)
# Check convergence
centroid_change = jnp.linalg.norm(centroids_new - state.centroids, ord='fro')
converged = centroid_change <= self.epsilon
new_state = AFCMState(
centroids=centroids_new,
U=U_new,
T=T_new,
gamma=gamma_new,
objective=objective,
iteration=state.iteration + 1,
converged=converged
)
metrics = {
'objective': objective,
'centroid_change': centroid_change,
'converged': converged
}
return new_state, metrics
def _build_info(self, state, iteration):
labels = jnp.argmax(state.U, axis=1)
weights = jnp.max(state.U * state.T, axis=1)
return {
'centroids': state.centroids, 'labels': labels,
'weights': weights, 'iteration': iteration,
'objective': float(state.objective), 'max_iter': self.max_iter,
}
[docs]
def fit(self, X: chex.Array, initial_centroids=None, resume=False) -> Self:
"""Fit AFCM model to data."""
if resume and initial_centroids is not None:
raise ValueError("Cannot use both resume=True and initial_centroids")
X = self._validate_input(X)
if resume:
self._check_fitted()
centroids_init = self.centroids_
U_init = self._update_U(X, centroids_init)
gamma_init = self._compute_gamma(X, U_init, centroids_init)
T_init = self._update_T(X, centroids_init, gamma_init)
elif initial_centroids is not None:
centroids_init = self._validate_initial_centroids(X, initial_centroids)
U_init = self._update_U(X, centroids_init)
gamma_init = self._compute_gamma(X, U_init, centroids_init)
T_init = self._update_T(X, centroids_init, gamma_init)
else:
U_init, T_init, centroids_init = self._initialize(X)
gamma_init = self._compute_gamma(X, U_init, centroids_init)
initial_objective = self._compute_objective(X, U_init, T_init, centroids_init, gamma_init)
initial_state = AFCMState(
centroids=centroids_init, U=U_init, T=T_init, gamma=gamma_init,
objective=initial_objective, iteration=0, converged=False
)
final_state, self.history_ = self._run_training(X, initial_state)
# Store results
self.centroids_ = final_state.centroids
self.U_ = final_state.U
self.T_ = final_state.T
self.gamma_ = final_state.gamma
self.n_iter_ = int(final_state.iteration)
self.objective_ = float(final_state.objective)
self.objective_history_ = self.history_['objective']
return self
[docs]
def predict(self, X: chex.Array) -> chex.Array:
"""Predict cluster labels for new data."""
self._check_fitted()
U = self._update_U(X, self.centroids_)
labels = jnp.argmax(U, axis=1)
return labels
[docs]
def predict_proba(self, X: chex.Array) -> chex.Array:
"""Predict fuzzy membership probabilities."""
self._check_fitted()
X = jnp.asarray(X)
U = self._update_U(X, self.centroids_)
return U
def get_typicality(self, X: chex.Array) -> chex.Array:
"""Compute typicality values."""
self._check_fitted()
X = jnp.asarray(X)
T = self._update_T(X, self.centroids_, self.gamma_)
return T