Source code for prosemble.models.knn

"""
JAX implementation of K-Nearest Neighbors (KNN)

This is a GPU-accelerated implementation using JAX.
"""

# Author: Nana Abeka Otoo <abekaotoo@gmail.com>
# License: MIT

from typing import Self
from functools import partial

import chex
import jax
import jax.numpy as jnp
from jax import jit

from prosemble.core.distance import batch_squared_euclidean


class KNN:
    """
    K-Nearest Neighbors (KNN) classifier with JAX.

    KNN is a simple, non-parametric classifier that predicts based on the
    k nearest training samples.

    Algorithm:
        1. For each test sample, compute squared Euclidean distances to all
           training samples.
        2. Find the k nearest neighbors.
        3. Predict the label as the mode (most common) of the neighbors'
           labels; ties are resolved in favor of the smallest label.
        4. Report the probability as the vote frequency among the neighbors.

    Parameters
    ----------
    n_neighbors : int, default=5
        Number of neighbors to use. If it exceeds the number of training
        samples, all training samples are used.

    Attributes
    ----------
    X_train_ : array of shape (n_samples, n_features) or None
        Training data stored by ``fit``.
    y_train_ : array of shape (n_samples,) or None
        Training labels stored by ``fit``.
    classes_ : array of shape (n_classes,) or None
        Sorted unique labels seen during ``fit``.
    n_classes_ : int or None
        Number of unique labels seen during ``fit``.

    Notes
    -----
    The jitted kernels are pure static functions that receive the training
    arrays as arguments. Jitting bound methods with ``self`` as a static
    argument would bake the training data into the compiled trace, so a
    second ``fit`` on the same instance would silently keep predicting from
    the old data.
    """

    def __init__(self, n_neighbors: int = 5):
        self.n_neighbors = n_neighbors

        # Fitted attributes
        self.X_train_ = None
        self.y_train_ = None
        self.classes_ = None
        self.n_classes_ = None

    # ------------------------------------------------------------------
    # Validation helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _check_n_neighbors(value) -> int:
        """Return ``value`` as a positive ``int`` or raise ``ValueError``."""
        try:
            k = int(value)
        except (TypeError, ValueError):
            k = 0
        if k < 1 or k != value:
            raise ValueError(
                f"n_neighbors must be a positive integer, got {value!r}."
            )
        return k

    def _check_fitted(self) -> None:
        """Raise ``ValueError`` if ``fit`` has not been called yet."""
        if self.X_train_ is None:
            raise ValueError("Model not fitted. Call fit() first.")

    def _effective_k(self) -> int:
        """Number of neighbors actually used: ``min(n_neighbors, n_train)``."""
        k = self._check_n_neighbors(self.n_neighbors)
        return min(k, self.X_train_.shape[0])

    # ------------------------------------------------------------------
    # Pure jitted kernels (no instance state, so no stale jit caches)
    # ------------------------------------------------------------------
    @staticmethod
    @partial(jit, static_argnames=("k",))
    def _vote_counts(
        X: chex.Array,
        X_train: chex.Array,
        y_train: chex.Array,
        classes: chex.Array,
        k: int,
    ) -> chex.Array:
        """
        Count, per test sample, the votes each class gets from the k nearest
        training samples.

        Returns
        -------
        votes : integer array of shape (n_test, n_classes)
            ``votes[i, j]`` is the number of the k nearest neighbors of
            ``X[i]`` whose label equals ``classes[j]``.
        """
        D_sq = batch_squared_euclidean(X, X_train)  # (n_test, n_train)
        nearest = jnp.argsort(D_sq, axis=1)[:, :k]  # (n_test, k)
        neighbor_labels = y_train[nearest]  # (n_test, k)
        # Compare against the actual label values rather than 0..n_classes-1
        # so arbitrary label sets (e.g. {1, 2, 3} or {-1, 1}) are handled.
        matches = neighbor_labels[:, :, None] == classes[None, None, :]
        return jnp.sum(matches, axis=1)

    @staticmethod
    @jit
    def _sorted_distances(X: chex.Array, X_train: chex.Array) -> chex.Array:
        """Row-wise sorted Euclidean distances from ``X`` to ``X_train``."""
        D_sq = batch_squared_euclidean(X, X_train)  # (n_test, n_train)
        # The floor guards against tiny negative values from round-off before
        # the square root; as a consequence distances are never below 1e-5.
        D = jnp.sqrt(jnp.maximum(D_sq, 1e-10))
        return jnp.sort(D, axis=1)

    def _votes(self, X: chex.Array) -> chex.Array:
        """Vote counts of shape (n_samples, n_classes) for the fitted model."""
        return self._vote_counts(
            X,
            self.X_train_,
            self.y_train_,
            self.classes_,
            k=self._effective_k(),
        )

    # ------------------------------------------------------------------
    # Fitting
    # ------------------------------------------------------------------
    def fit(self, X: chex.Array, y: chex.Array) -> Self:
        """
        Fit KNN model (store training data).

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data
        y : array-like of shape (n_samples,)
            Target labels

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``n_neighbors`` is not a positive integer, if ``X`` is not a
            non-empty 2-D array, or if ``y`` is not 1-D with one label per
            row of ``X``.
        """
        self._check_n_neighbors(self.n_neighbors)

        X = jnp.asarray(X)
        y = jnp.asarray(y)

        if X.ndim != 2 or X.shape[0] == 0:
            raise ValueError(
                "X must be a non-empty array of shape (n_samples, n_features), "
                f"got shape {X.shape}."
            )
        if y.ndim != 1 or y.shape[0] != X.shape[0]:
            raise ValueError(
                "y must have shape (n_samples,) matching X, "
                f"got X shape {X.shape} and y shape {y.shape}."
            )

        self.X_train_ = X
        self.y_train_ = y
        # jnp.unique returns the labels sorted, which fixes the column order
        # of predict_proba and the tie-breaking order of predict.
        self.classes_ = jnp.unique(y)
        self.n_classes_ = len(self.classes_)

        return self

    # ------------------------------------------------------------------
    # Internal prediction helpers (names kept for backward compatibility)
    # ------------------------------------------------------------------
    def _get_mode(self, labels: chex.Array) -> tuple[chex.Array, chex.Array]:
        """
        Get mode (most common label) and its probability.

        Parameters
        ----------
        labels : array of shape (k,)
            Labels of k nearest neighbors

        Returns
        -------
        mode_label : scalar array
            Most common label (smallest label on ties)
        probability : scalar array
            Frequency of mode label (count / k)
        """
        n_labels = labels.shape[0]
        # A fixed ``size`` keeps the output shape static so this also works
        # under jit/vmap; padded entries get a count of 0 and never win.
        unique_labels, counts = jnp.unique(
            labels, return_counts=True, size=n_labels
        )
        best = jnp.argmax(counts)
        return unique_labels[best], counts[best] / n_labels

    def _predict_one(self, x: chex.Array) -> tuple[chex.Array, chex.Array]:
        """
        Predict label and probability for a single sample.

        Parameters
        ----------
        x : array of shape (n_features,)
            Test sample

        Returns
        -------
        label : scalar array
            Predicted label
        probability : scalar array
            Confidence of prediction
        """
        votes = self._votes(x[None, :])[0]  # (n_classes,)
        best = jnp.argmax(votes)
        return self.classes_[best], votes[best] / self._effective_k()

    def _predict_labels(self, X: chex.Array) -> chex.Array:
        """Predict labels for multiple samples."""
        votes = self._votes(X)
        # argmax returns the first maximum, i.e. the smallest label on ties.
        return self.classes_[jnp.argmax(votes, axis=1)]

    def _get_probabilities(self, X: chex.Array) -> chex.Array:
        """Get prediction probabilities for multiple samples."""
        votes = self._votes(X)
        return jnp.max(votes, axis=1) / self._effective_k()

    def _get_distance_space(self, X: chex.Array) -> chex.Array:
        """
        Get sorted distances to training samples.

        For each test sample, return sorted distances to all training samples.
        """
        return self._sorted_distances(X, self.X_train_)

    def _predict_proba_full(self, X: chex.Array) -> chex.Array:
        """
        Predict class probabilities for all classes.

        For each sample, return the probability distribution over all classes
        based on the k nearest neighbors. Each row sums to 1.
        """
        return self._votes(X) / self._effective_k()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def predict(self, X: chex.Array) -> chex.Array:
        """
        Predict class labels for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Data to predict

        Returns
        -------
        labels : array of shape (n_samples,)
            Predicted class labels

        Raises
        ------
        ValueError
            If the model has not been fitted.
        """
        self._check_fitted()
        return self._predict_labels(jnp.asarray(X))

    def get_proba(self, X: chex.Array) -> chex.Array:
        """
        Get prediction probabilities (frequency of predicted label among
        k neighbors).

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Data to predict

        Returns
        -------
        probabilities : array of shape (n_samples,)
            Confidence values (frequency of mode label)

        Raises
        ------
        ValueError
            If the model has not been fitted.
        """
        self._check_fitted()
        return self._get_probabilities(jnp.asarray(X))

    def distance_space(self, X: chex.Array) -> chex.Array:
        """
        Get distance space (sorted distances to all training samples).

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Data to compute distances for

        Returns
        -------
        distances : array of shape (n_samples, n_train_samples)
            Sorted Euclidean distances to training samples

        Raises
        ------
        ValueError
            If the model has not been fitted.
        """
        self._check_fitted()
        return self._get_distance_space(jnp.asarray(X))

    def predict_proba(self, X: chex.Array) -> chex.Array:
        """
        Predict class probabilities for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Data to predict

        Returns
        -------
        probabilities : array of shape (n_samples, n_classes)
            Class probability distributions; column ``j`` corresponds to
            ``classes_[j]``.

        Raises
        ------
        ValueError
            If the model has not been fitted.
        """
        self._check_fitted()
        return self._predict_proba_full(jnp.asarray(X))