Source code for prosemble.core.quantization

"""Shared mixins for model base classes."""

import inspect

import jax.numpy as jnp


[docs] class MetadataCollectorMixin: """Mixin that auto-collects _hyperparams and _fitted_array_names from MRO. Any base class that declares per-class ``_hyperparams`` and ``_fitted_array_names`` tuples can inherit this mixin to get ``_all_hyperparams`` and ``_all_fitted_array_names`` aggregated automatically across the entire class hierarchy. """ _hyperparams: tuple[str, ...] = () _fitted_array_names: tuple[str, ...] = () def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) all_hp: list[str] = [] all_fa: list[str] = [] for klass in reversed(cls.__mro__): for name in getattr(klass, '_hyperparams', ()): if name not in all_hp: all_hp.append(name) for name in getattr(klass, '_fitted_array_names', ()): if name not in all_fa: all_fa.append(name) cls._all_hyperparams = tuple(all_hp) cls._all_fitted_array_names = tuple(all_fa)
[docs] def get_params(self, deep=True): """Get parameters for this estimator. Follows the sklearn estimator protocol by inspecting ``__init__`` signatures across the MRO. Parameters ---------- deep : bool, default=True Ignored (present for sklearn compatibility). Returns ------- dict Parameter names mapped to their values. """ params = {} for klass in type(self).__mro__: init = getattr(klass, '__init__', None) if init is None: continue sig = inspect.signature(init) for name, p in sig.parameters.items(): if name == 'self' or p.kind in ( p.VAR_POSITIONAL, p.VAR_KEYWORD, ): continue if name not in params: params[name] = getattr(self, name, p.default) return params
[docs] def set_params(self, **params): """Set parameters on this estimator. Parameters ---------- **params Estimator parameters to set. Returns ------- self """ valid = self.get_params() for key, value in params.items(): if key not in valid: raise ValueError( f"Invalid parameter '{key}' for {type(self).__name__}. " f"Valid parameters: {sorted(valid.keys())}" ) setattr(self, key, value) return self
[docs] class QuantizationMixin: """Mixin for quantizing/dequantizing fitted model parameters. Supports float16, bfloat16, and int8 (with per-tensor scale factors). Subclasses override ``_get_quantizable_attrs`` to declare which fitted attributes are eligible for quantization. """ _VALID_DTYPES = { 'float16': jnp.float16, 'bfloat16': jnp.bfloat16, 'int8': 'int8', } def _get_quantizable_attrs(self) -> list[str]: """Return attribute names of quantizable parameters. Subclasses override to return e.g. ``['centroids_']`` or ``['prototypes_']``. Default returns empty list. """ return []
[docs] def quantize(self, dtype='float16'): """Quantize model parameters to lower precision. Post-training quantization for smaller model size and faster inference. Parameters ---------- dtype : str Target precision: 'float16', 'bfloat16', or 'int8'. Returns ------- self """ self._check_fitted() if dtype not in self._VALID_DTYPES: raise ValueError( f"dtype must be one of {list(self._VALID_DTYPES.keys())}, got '{dtype}'" ) if dtype == 'int8': self._quantize_int8() else: target = self._VALID_DTYPES[dtype] self._quantize_float(target) self._quantized_dtype = dtype return self
[docs] def dequantize(self): """Restore model parameters to float32. Returns ------- self """ self._check_fitted() if not hasattr(self, '_quantized_dtype') or self._quantized_dtype is None: return self if self._quantized_dtype == 'int8': self._dequantize_int8() else: self._dequantize_float() self._quantized_dtype = None return self
@property def is_quantized(self) -> bool: """Whether model parameters are currently quantized.""" return getattr(self, '_quantized_dtype', None) is not None @property def quantized_dtype(self) -> str | None: """Current quantization dtype, or None if not quantized.""" return getattr(self, '_quantized_dtype', None) def _quantize_float(self, target_dtype): """Convert parameters to float16/bfloat16.""" for attr in self._get_quantizable_attrs(): val = getattr(self, attr) if val is not None: setattr(self, attr, val.astype(target_dtype)) def _quantize_int8(self): """Quantize parameters to int8 with per-tensor scale factors.""" self._int8_scales = {} for attr in self._get_quantizable_attrs(): val = getattr(self, attr) if val is not None: val_f32 = val.astype(jnp.float32) abs_max = jnp.max(jnp.abs(val_f32)) scale = abs_max / 127.0 quantized = jnp.round(val_f32 / (scale + 1e-10)).astype(jnp.int8) setattr(self, attr, quantized) self._int8_scales[attr] = scale def _dequantize_float(self): """Restore float16/bfloat16 parameters to float32.""" for attr in self._get_quantizable_attrs(): val = getattr(self, attr) if val is not None: setattr(self, attr, val.astype(jnp.float32)) def _dequantize_int8(self): """Restore int8 parameters to float32 using stored scales.""" scales = getattr(self, '_int8_scales', {}) for attr in self._get_quantizable_attrs(): val = getattr(self, attr) if val is not None and attr in scales: setattr(self, attr, val.astype(jnp.float32) * scales[attr]) self._int8_scales = {}