Source code for prosemble.core.model_selection

"""JAX-native model selection: clone, cross_val_score, GridSearchCV.

Pure JAX implementation — no sklearn dependency, no numpy roundtrips.
"""

import inspect
import itertools
from typing import Callable, Self

import jax.numpy as jnp
import chex

from prosemble.core.utils import k_fold_split_jax, accuracy_score_jax
from prosemble.core.pipeline import NotFittedError


def clone(estimator, **override_params):
    """Create a fresh (unfitted) copy of an estimator.

    Parameters
    ----------
    estimator : object
        Estimator with ``get_params()`` protocol.
    **override_params
        Parameters to override in the clone.

    Returns
    -------
    new_estimator : same type as estimator
        Fresh, unfitted instance.

    Raises
    ------
    TypeError
        If ``estimator`` has no ``get_params()`` method.
    ValueError
        If an override parameter is not accepted by the estimator's
        constructor. Such overrides used to be dropped silently, which made
        e.g. a misspelled grid-search key look like it had been applied.

    Examples
    --------
    >>> model = GLVQ(n_prototypes_per_class=2, lr=0.01)
    >>> model2 = clone(model, lr=0.05)
    """
    from prosemble.core.pipeline import Pipeline

    if isinstance(estimator, Pipeline):
        return _clone_pipeline(estimator, **override_params)

    if not hasattr(estimator, 'get_params'):
        raise TypeError(
            f"Cannot clone {type(estimator).__name__}: no get_params() method."
        )

    # Copy so that applying overrides can never write back into whatever
    # mapping get_params() hands out.
    params = dict(estimator.get_params(deep=False))
    params.update(override_params)

    init_params = inspect.signature(type(estimator).__init__).parameters
    has_var_keyword = any(
        p.kind == inspect.Parameter.VAR_KEYWORD for p in init_params.values()
    )

    if has_var_keyword:
        # Constructor accepts **kwargs, pass all params.
        valid_params = {k: v for k, v in params.items() if k != 'self'}
    else:
        # Params reported by get_params() but unknown to __init__ are still
        # filtered out, but explicit user overrides must not vanish silently.
        unknown = sorted(k for k in override_params if k not in init_params)
        if unknown:
            accepted = sorted(k for k in init_params if k != 'self')
            raise ValueError(
                f"Invalid parameter(s) {unknown} for "
                f"{type(estimator).__name__}. Valid parameters: {accepted}."
            )
        valid_params = {
            k: v for k, v in params.items()
            if k in init_params and k != 'self'
        }
    return type(estimator)(**valid_params)
def _clone_pipeline(pipeline, **override_params):
    """Return a new Pipeline whose steps hold clones of the original estimators.

    Step names are kept; each step's estimator goes through ``clone``.
    Non-empty ``override_params`` (e.g. ``step_name__param`` keys) are then
    applied to the new pipeline via ``set_params``.
    """
    from prosemble.core.pipeline import Pipeline

    cloned_pipe = Pipeline(
        [(step_name, clone(step_est)) for step_name, step_est in pipeline.steps]
    )
    if override_params:
        cloned_pipe.set_params(**override_params)
    return cloned_pipe


def _is_supervised(estimator):
    """Return True when ``estimator.fit`` declares a parameter named ``y``."""
    fit_params = inspect.signature(estimator.fit).parameters
    return 'y' in fit_params


def _get_scorer(scoring):
    """Turn a ``scoring`` spec into a scoring callable.

    The string ``'accuracy'`` maps to ``accuracy_score_jax``. Any callable is
    returned unchanged. Anything else raises ``ValueError``.
    """
    if scoring == 'accuracy':
        return accuracy_score_jax
    if callable(scoring):
        return scoring
    raise ValueError(
        f"Unknown scoring '{scoring}'. Use 'accuracy' or a callable."
    )
def cross_val_score(
    estimator,
    X: chex.Array,
    y: chex.Array | None = None,
    cv: int = 5,
    scoring: str | Callable = 'accuracy',
    random_seed: int = 42,
) -> chex.Array:
    """Evaluate estimator with cross-validation.

    Parameters
    ----------
    estimator : object
        Estimator with fit/predict and get_params. A fresh clone is fitted
        on every fold; ``estimator`` itself is never fitted.
    X : array of shape (n_samples, n_features)
    y : array of shape (n_samples,), optional
        Required for supervised estimators (those whose ``fit`` has a ``y``
        parameter).
    cv : int, default=5
        Number of cross-validation folds.
    scoring : str or callable, default='accuracy'
        'accuracy' or callable ``scorer(y_true, y_pred) -> float``.
        For unsupervised without y: ``scorer(estimator, X_test) -> float``.
    random_seed : int, default=42
        Seed passed to the fold splitter.

    Returns
    -------
    scores : jnp.ndarray of shape (cv,)
        Score for each fold.

    Raises
    ------
    ValueError
        If ``X`` and ``y`` have different numbers of samples. Also if the
        estimator is supervised and ``y`` is None. Also if the estimator is
        unsupervised, ``y`` is None and ``scoring='accuracy'``, because
        accuracy needs true labels.

    Examples
    --------
    >>> scores = cross_val_score(GLVQ(max_iter=30), X, y, cv=5)
    >>> print(f"Mean: {scores.mean():.3f} +/- {scores.std():.3f}")
    """
    X = jnp.asarray(X)
    if y is not None:
        y = jnp.asarray(y)
        if y.shape[0] != X.shape[0]:
            raise ValueError(
                "X and y have inconsistent numbers of samples: "
                f"{X.shape[0]} != {y.shape[0]}."
            )

    scorer = _get_scorer(scoring)
    supervised = _is_supervised(estimator)

    # Validate the configuration once, before any fold is fitted.
    if supervised and y is None:
        raise ValueError("y must be provided for supervised estimators.")
    if not supervised and y is None and scoring == 'accuracy':
        # Without this check, accuracy_score_jax would be called as
        # scorer(estimator, X_test), which is meaningless.
        raise ValueError(
            "scoring='accuracy' requires y. For unsupervised estimators "
            "without labels pass a callable scorer(estimator, X_test) -> float."
        )

    fold_scores = []
    for train_idx, test_idx in k_fold_split_jax(X.shape[0], cv, random_seed):
        est = clone(estimator)
        X_train, X_test = X[train_idx], X[test_idx]

        if supervised:
            est.fit(X_train, y[train_idx])
        else:
            est.fit(X_train)

        if y is None:
            # Unsupervised custom scorer: scorer(estimator, X_test)
            score = scorer(est, X_test)
        else:
            score = scorer(y[test_idx], est.predict(X_test))
        fold_scores.append(float(score))

    return jnp.array(fold_scores)
[docs] class GridSearchCV: """Exhaustive search over a parameter grid with cross-validation. Parameters ---------- estimator : object Base estimator with fit/predict and get_params/set_params. Can be a Pipeline or any prosemble model. param_grid : dict Maps parameter names to lists of values to try. For Pipeline steps, use ``step_name__param`` notation. cv : int, default=5 Number of cross-validation folds. scoring : str or callable, default='accuracy' 'accuracy' or callable ``scorer(y_true, y_pred) -> float``. random_seed : int, default=42 refit : bool, default=True If True, refit the best model on the full dataset after search. verbose : int, default=0 0=silent, 1=per-combo summary, 2=per-fold detail. Attributes ---------- best_params_ : dict Parameters of the best model. best_score_ : float Mean CV score of the best model. best_estimator_ : object Fitted estimator with best params (only if refit=True). cv_results_ : dict Keys: 'params', 'mean_score', 'std_score', 'fold_scores', 'rank'. Examples -------- >>> gs = GridSearchCV( ... GLVQ(max_iter=30), ... {'n_prototypes_per_class': [1, 2], 'lr': [0.01, 0.05]}, ... cv=3, ... 
) >>> gs.fit(X, y) >>> print(gs.best_params_, gs.best_score_) """ def __init__( self, estimator, param_grid: dict, cv: int = 5, scoring: str | Callable = 'accuracy', random_seed: int = 42, refit: bool = True, verbose: int = 0, ): self.estimator = estimator self.param_grid = param_grid self.cv = cv self.scoring = scoring self.random_seed = random_seed self.refit = refit self.verbose = verbose self.best_params_ = None self.best_score_ = None self.best_estimator_ = None self.cv_results_ = None def _generate_param_combinations(self): """Generate all combinations from param_grid.""" keys = sorted(self.param_grid.keys()) values = [self.param_grid[k] for k in keys] return [dict(zip(keys, vals)) for vals in itertools.product(*values)] def _clone_with_params(self, params): """Clone base estimator with overridden params.""" from prosemble.core.pipeline import Pipeline if isinstance(self.estimator, Pipeline): return _clone_pipeline(self.estimator, **params) return clone(self.estimator, **params)
[docs] def fit(self, X, y=None) -> Self: """Run grid search with cross-validation. Parameters ---------- X : array of shape (n_samples, n_features) y : array of shape (n_samples,), optional Returns ------- self """ X = jnp.asarray(X) if y is not None: y = jnp.asarray(y) param_combos = self._generate_param_combinations() all_results = [] for combo_idx, params in enumerate(param_combos): est = self._clone_with_params(params) scores = cross_val_score( est, X, y, cv=self.cv, scoring=self.scoring, random_seed=self.random_seed, ) mean_score = float(jnp.mean(scores)) std_score = float(jnp.std(scores)) all_results.append({ 'params': params, 'mean_score': mean_score, 'std_score': std_score, 'fold_scores': scores.tolist(), }) if self.verbose >= 1: print( f"[{combo_idx + 1}/{len(param_combos)}] " f"{params} -> {mean_score:.4f} +/- {std_score:.4f}" ) # Rank by mean score (descending) sorted_indices = sorted( range(len(all_results)), key=lambda i: all_results[i]['mean_score'], reverse=True, ) ranks = [0] * len(all_results) for rank, idx in enumerate(sorted_indices, 1): ranks[idx] = rank self.cv_results_ = { 'params': [r['params'] for r in all_results], 'mean_score': [r['mean_score'] for r in all_results], 'std_score': [r['std_score'] for r in all_results], 'fold_scores': [r['fold_scores'] for r in all_results], 'rank': ranks, } best_idx = sorted_indices[0] self.best_params_ = all_results[best_idx]['params'] self.best_score_ = all_results[best_idx]['mean_score'] if self.refit: self.best_estimator_ = self._clone_with_params(self.best_params_) if _is_supervised(self.best_estimator_): self.best_estimator_.fit(X, y) else: self.best_estimator_.fit(X) return self
[docs] def predict(self, X): """Predict using best estimator.""" if self.best_estimator_ is None: raise NotFittedError( "GridSearchCV not fitted or refit=False. " "Call fit() with refit=True first." ) return self.best_estimator_.predict(X)
[docs] def predict_proba(self, X): """Predict probabilities using best estimator.""" if self.best_estimator_ is None: raise NotFittedError( "GridSearchCV not fitted or refit=False." ) return self.best_estimator_.predict_proba(X)
def __repr__(self): return ( f"GridSearchCV(estimator={type(self.estimator).__name__}, " f"param_grid={self.param_grid}, cv={self.cv})" )