Source code for prosemble.core.competitions
"""
Competition mechanisms for prototype-based classification.
These functions determine class predictions from distance matrices
and prototype labels using different strategies.
"""
import jax.numpy as jnp
from jax import jit
from functools import partial
[docs]
@jit
def wtac(distances, prototype_labels):
    """Winner-Takes-All Competition.

    Every sample receives the label of the single prototype it is closest
    to. When several prototypes are equally close, the one with the lowest
    index wins, because ``jnp.argmin`` reports the first minimum.

    Parameters
    ----------
    distances : array of shape (n_samples, n_prototypes)
        Distance from each sample to each prototype.
    prototype_labels : array of shape (n_prototypes,)
        Class label attached to each prototype.

    Returns
    -------
    array of shape (n_samples,)
        Label of the nearest prototype for each sample.
    """
    # Column index of the nearest prototype per row, then look up its label.
    return prototype_labels[jnp.argmin(distances, axis=1)]
[docs]
@partial(jit, static_argnums=(2, 3))
def knnc(distances, prototype_labels, k=1, n_classes=None):
    """K-Nearest Neighbors Competition.

    Assigns each sample the majority label among its ``k`` closest
    prototypes. Ties between equally frequent labels are broken in favour
    of the smallest label value. This is the same rule as taking ``argmax``
    over per-class vote counts.

    Parameters
    ----------
    distances : array of shape (n_samples, n_prototypes)
        Distance matrix.
    prototype_labels : array of shape (n_prototypes,)
        Class label for each prototype. Arbitrary (including negative)
        label values are supported.
    k : int
        Number of neighbors. Static under ``jit``. If ``k`` exceeds
        ``n_prototypes``, every prototype votes.
    n_classes : int or None
        Accepted for backward compatibility only. The vote compares the
        neighbour labels with each other, so it does not need the number
        of classes.

    Returns
    -------
    array of shape (n_samples,)
        Predicted labels, with the dtype of ``prototype_labels``.

    Raises
    ------
    ValueError
        If ``k < 1``.
    """
    # The previous one-hot vote needed n_classes and inferred it with
    # int(jnp.max(prototype_labels)). That always fails under jit because
    # prototype_labels is a tracer. The pairwise vote below needs no class count.
    del n_classes
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k!r}")

    nearest = jnp.argsort(distances, axis=1)[:, :k]  # (n, k)
    # Sort each row so that argmax's first-occurrence rule selects the
    # smallest label among equally frequent (tied) labels.
    k_labels = jnp.sort(prototype_labels[nearest], axis=1)  # (n, k)
    # votes[i, j] = number of sample i's neighbours sharing neighbour j's label.
    votes = jnp.sum(k_labels[:, :, None] == k_labels[:, None, :], axis=2)
    winner = jnp.argmax(votes, axis=1)  # (n,) position of a majority label
    return jnp.take_along_axis(k_labels, winner[:, None], axis=1)[:, 0]
[docs]
@jit
def cbcc(detections, reasonings):
    """Classification-By-Components Competition.

    Scores every class by combining component detections with positive
    and negative reasoning. With

        p = clip(reasonings[:, :, 0], 0, 1)
        n = (1 - p) * clip(reasonings[:, :, 1], 0, 1)

    the score of class ``c`` for one sample with detections ``d`` is

        (sum_k d_k * (p_kc - n_kc) + sum_k n_kc) / (sum_k (p_kc + n_kc) + 1e-8)

    Each class is scored independently. The rows are not renormalised to
    sum to one. For detections in [0, 1], every score lies in [0, 1].

    Parameters
    ----------
    detections : array of shape (n_samples, n_components)
        Similarity/detection scores for each component.
    reasonings : array of shape (n_components, n_classes, 2)
        Reasoning matrices. Last dim: [positive, negative_raw].

    Returns
    -------
    array of shape (n_samples, n_classes)
        Per-class probabilities.
    """
    bounded = jnp.clip(reasonings, 0.0, 1.0)
    positive = bounded[:, :, 0]                     # (n_components, n_classes)
    negative = (1.0 - positive) * bounded[:, :, 1]  # (n_components, n_classes)
    # (n_samples, n_components) @ (n_components, n_classes): no transpose needed.
    evidence = jnp.matmul(detections, positive - negative)
    numerator = evidence + negative.sum(axis=0)
    # The small epsilon guards against classes with no reasoning mass at all.
    denominator = (positive + negative).sum(axis=0) + 1e-8
    return numerator / denominator