# Source code for prosemble.models.crossentropy_lvq

"""
Cross-Entropy LVQ (CELVQ).

Uses cross-entropy loss on softmin of per-class minimum distances
instead of the GLVQ relative distance difference.

References
----------
.. [1] Villmann, T., et al. (2019). Analysis of variants of
       classification learning vector quantization by a stochastic
       setting.
"""

import jax.numpy as jnp

from prosemble.models.prototype_base import SupervisedPrototypeModel
from prosemble.core.losses import cross_entropy_lvq_loss


class CELVQ(SupervisedPrototypeModel):
    """Cross-Entropy Learning Vector Quantization.

    Rather than the GLVQ relative distance difference, training minimizes
    a cross-entropy objective. The minimum distance from a sample to each
    class's prototypes is negated to serve as that class's logit, and the
    logits are scored against the true labels.

    Parameters
    ----------
    n_prototypes_per_class : int
        Number of prototypes assigned to each class.
    max_iter : int
        Upper bound on the number of training iterations.
    lr : float
        Learning rate.

    See Also
    --------
    SupervisedPrototypeModel : Full list of base parameters (optimizer,
        distance_fn, lr_scheduler, callbacks, patience, etc.).
    """

    def _compute_loss(self, params, X, y, proto_labels):
        """Evaluate the CELVQ training objective on a batch.

        Parameters
        ----------
        params : mapping
            Trainable model state; only the ``'prototypes'`` entry is
            read here.
        X : array
            Input samples.
        y : array
            True class labels for ``X``.
        proto_labels : array
            Class label attached to each prototype.

        Returns
        -------
        loss
            The value returned by ``cross_entropy_lvq_loss``. It is
            presumably a scalar suitable for differentiation; confirm
            against that helper.
        """
        prototypes = params['prototypes']
        # The sample-to-prototype distances under the model's configured
        # metric are handed straight to the shared cross-entropy LVQ
        # loss, along with the class count needed to build per-class
        # logits.
        return cross_entropy_lvq_loss(
            self.distance_fn(X, prototypes),
            y,
            proto_labels,
            self.n_classes_,
        )