"""
JAX implementation of Noise Possibilistic C-Means (NPC)
This is a GPU-accelerated implementation using JAX.
"""
# Author: Nana Abeka Otoo <abekaotoo@gmail.com>
# License: MIT
from typing import Self
from functools import partial
import chex
import jax
import jax.numpy as jnp
from jax import jit
from prosemble.core.distance import batch_squared_euclidean
# (Sphinx "[docs]" source-link marker from the rendered page; not part of the code)
class NPC:
"""
Noise Possibilistic C-Means (NPC) with JAX
NPC is a supervised prototype-based classifier that iteratively optimizes
prototypes based on accuracy metric. Uses softmin for probability estimation.
Algorithm:
1. Initialize prototypes (one per class)
2. Predict labels using nearest prototype
3. Compute accuracy
4. If accuracy >= threshold or max_iter reached, stop
5. Otherwise, recompute prototypes and repeat
Softmin function: :math:`\\text{softmin}(x_i) = \\exp(-x_i) / \\sum_j \\exp(-x_j)`
Parameters
----------
n_classes : int
Number of classes
max_iter : int, default=10
Maximum optimization steps
tol : float, default=0.8
Accuracy threshold for convergence
random_state : int, optional
Random seed for reproducibility
"""
def __init__(
self,
n_classes: int = 3,
max_iter: int = 10,
tol: float = 0.8,
random_state: int | None = None
):
self.n_classes = n_classes
self.max_iter = max_iter
self.tol = tol
self.random_state = random_state
# Fitted attributes
self.prototypes_ = None
self.n_iter_ = 0
def _compute_prototypes(self, X: chex.Array, y: chex.Array) -> chex.Array:
"""
Compute prototype for each class as the mean of samples in that class
Parameters
----------
X : array of shape (n_samples, n_features)
y : array of shape (n_samples,)
Returns
-------
prototypes : array of shape (n_classes, n_features)
"""
# Cannot JIT this due to boolean indexing
# Use vectorized approach instead
prototypes = []
for class_idx in range(self.n_classes):
# Get samples belonging to this class
mask = (y == class_idx)
class_samples = X[mask]
# Compute mean
if jnp.sum(mask) > 0:
prototype = jnp.mean(class_samples, axis=0)
else:
# If no samples, use random initialization
prototype = jnp.zeros(X.shape[1])
prototypes.append(prototype)
return jnp.array(prototypes)
@partial(jit, static_argnums=(0,))
def _predict_labels(self, X: chex.Array, prototypes: chex.Array) -> chex.Array:
"""Predict class labels using nearest prototype"""
D_sq = batch_squared_euclidean(X, prototypes)
labels = jnp.argmin(D_sq, axis=1)
return labels
@partial(jit, static_argnums=(0,))
def _compute_accuracy(self, y_true: chex.Array, y_pred: chex.Array) -> float:
"""Compute classification accuracy"""
return jnp.mean(y_true == y_pred)
[docs]
def fit(self, X: chex.Array, y: chex.Array) -> Self:
"""
Fit NPC model to labeled data
Parameters
----------
X : array-like of shape (n_samples, n_features)
Training data
y : array-like of shape (n_samples,)
Target labels (integers from 0 to n_classes-1)
Returns
-------
self
"""
X = jnp.asarray(X)
y = jnp.asarray(y)
# Initialize prototypes
prototypes = self._compute_prototypes(X, y)
# Optimization loop
for iteration in range(self.max_iter):
# Predict
y_pred = self._predict_labels(X, prototypes)
# Compute accuracy
accuracy = self._compute_accuracy(y, y_pred)
# Check convergence
if accuracy >= self.tol:
self.n_iter_ = iteration + 1
break
# Update prototypes
prototypes = self._compute_prototypes(X, y)
# Store iteration count
self.n_iter_ = iteration + 1
self.prototypes_ = prototypes
return self
[docs]
def predict(self, X: chex.Array) -> chex.Array:
"""
Predict class labels for samples
Parameters
----------
X : array-like of shape (n_samples, n_features)
Data to predict
Returns
-------
labels : array of shape (n_samples,)
Predicted class labels
"""
if self.prototypes_ is None:
raise ValueError("Model not fitted. Call fit() first.")
X = jnp.asarray(X)
return self._predict_labels(X, self.prototypes_)
@partial(jit, static_argnums=(0,))
def _softmin(self, x: chex.Array) -> chex.Array:
"""
Softmin function: :math:`\\text{softmin}(x_i) = \\exp(-x_i) / \\sum_j \\exp(-x_j)`.
Parameters
----------
x : array of shape (n_prototypes,)
Distance values
Returns
-------
probs : array of shape (n_prototypes,)
Softmin probabilities
"""
neg_x = -x
neg_x_shifted = neg_x - jnp.max(neg_x)
exp_neg_x = jnp.exp(neg_x_shifted)
return exp_neg_x / jnp.sum(exp_neg_x)
@partial(jit, static_argnums=(0,))
def _compute_distance_space(self, X: chex.Array) -> chex.Array:
"""
Compute distance matrix to all prototypes
Parameters
----------
X : array of shape (n_samples, n_features)
Returns
-------
distances : array of shape (n_samples, n_classes)
Euclidean distances to each prototype
"""
D_sq = batch_squared_euclidean(X, self.prototypes_)
D = jnp.sqrt(jnp.maximum(D_sq, 1e-10))
return D
@partial(jit, static_argnums=(0,))
def _predict_proba(self, X: chex.Array) -> chex.Array:
"""Compute class probabilities using softmin"""
distances = self._compute_distance_space(X)
# Apply softmin to each sample
probs = jax.vmap(self._softmin)(distances)
return probs
def predict_proba(self, X: chex.Array) -> chex.Array:
"""
Predict class probabilities for samples using softmin
Parameters
----------
X : array-like of shape (n_samples, n_features)
Data to predict
Returns
-------
probs : array of shape (n_samples, n_classes)
Class probabilities (softmin of distances)
"""
if self.prototypes_ is None:
raise ValueError("Model not fitted. Call fit() first.")
X = jnp.asarray(X)
return self._predict_proba(X)
def get_distance_space(self, X: chex.Array) -> chex.Array:
"""
Get distance space (Euclidean distances to prototypes)
Parameters
----------
X : array-like of shape (n_samples, n_features)
Data to compute distances for
Returns
-------
distances : array of shape (n_samples, n_classes)
Euclidean distances to each prototype
"""
if self.prototypes_ is None:
raise ValueError("Model not fitted. Call fit() first.")
X = jnp.asarray(X)
return self._compute_distance_space(X)