Source code for prosemble.core.data

"""
Data loading and batching utilities for prosemble.

Provides composable primitives for mini-batch iteration, suitable for
custom training loops.  The built-in ``fit()`` methods already handle
batching internally — these utilities are for advanced users who need
explicit control over data iteration.
"""

from __future__ import annotations

import jax
import jax.numpy as jnp


[docs] def shuffle_arrays(key: jax.Array, *arrays: jnp.ndarray) -> tuple[jnp.ndarray, ...]: """Shuffle multiple arrays with the same random permutation. Parameters ---------- key : JAX PRNG key Random key for generating the permutation. *arrays : jnp.ndarray Arrays to shuffle. All must have the same length along axis 0. Returns ------- tuple of jnp.ndarray Shuffled arrays in the same order as the inputs. Examples -------- >>> key = jax.random.PRNGKey(0) >>> X = jnp.arange(12).reshape(4, 3) >>> y = jnp.array([0, 1, 0, 1]) >>> X_s, y_s = shuffle_arrays(key, X, y) """ if len(arrays) == 0: return () n = arrays[0].shape[0] perm = jax.random.permutation(key, n) return tuple(a[perm] for a in arrays)
[docs] def padded_batches( X: jnp.ndarray, y: jnp.ndarray | None = None, batch_size: int = 32, key: jax.Array | None = None, ) -> tuple[jnp.ndarray, jnp.ndarray | None]: """Split data into static-shaped batches, padding the last batch. Returns arrays of shape ``(n_batches, batch_size, ...)`` suitable for ``jax.lax.scan``. If the data length is not divisible by ``batch_size``, the last batch is padded by repeating initial samples. Parameters ---------- X : jnp.ndarray of shape (n_samples, ...) Feature array. y : jnp.ndarray of shape (n_samples,), optional Label array. batch_size : int Number of samples per batch. key : JAX PRNG key, optional If provided, data is shuffled before batching. Returns ------- X_batches : jnp.ndarray of shape (n_batches, batch_size, ...) y_batches : jnp.ndarray of shape (n_batches, batch_size) or None Examples -------- >>> X = jnp.ones((10, 3)) >>> y = jnp.arange(10) >>> X_b, y_b = padded_batches(X, y, batch_size=4) >>> X_b.shape # (3, 4, 3) — 10 samples padded to 12 (3, 4, 3) """ n_samples = X.shape[0] if key is not None: if y is not None: X, y = shuffle_arrays(key, X, y) else: X, = shuffle_arrays(key, X) n_batches = (n_samples + batch_size - 1) // batch_size padded_size = n_batches * batch_size if padded_size > n_samples: pad_n = padded_size - n_samples X = jnp.concatenate([X, X[:pad_n]], axis=0) if y is not None: y = jnp.concatenate([y, y[:pad_n]], axis=0) X_batches = X.reshape(n_batches, batch_size, *X.shape[1:]) y_batches = y.reshape(n_batches, batch_size) if y is not None else None return X_batches, y_batches
[docs] def batched_iterator( X: jnp.ndarray, y: jnp.ndarray | None = None, batch_size: int = 32, shuffle: bool = True, key: jax.Array | None = None, drop_last: bool = False, ): """Yield mini-batches from data arrays. For use in custom Python training loops. Parameters ---------- X : jnp.ndarray of shape (n_samples, ...) Feature array. y : jnp.ndarray of shape (n_samples,), optional Label array. batch_size : int Number of samples per batch. shuffle : bool If True and *key* is provided, shuffle before iterating. key : JAX PRNG key, optional Required if ``shuffle=True``. drop_last : bool If True, drop the last batch when it is smaller than ``batch_size``. If False (default), the last batch may be smaller. Yields ------ X_batch : jnp.ndarray of shape (batch_size, ...) or smaller y_batch : jnp.ndarray of shape (batch_size,) or None Examples -------- >>> key = jax.random.PRNGKey(0) >>> X = jnp.ones((10, 3)) >>> for X_b, y_b in batched_iterator(X, batch_size=4, key=key): ... print(X_b.shape) (4, 3) (4, 3) (2, 3) """ n_samples = X.shape[0] if shuffle and key is not None: if y is not None: X, y = shuffle_arrays(key, X, y) else: X, = shuffle_arrays(key, X) for start in range(0, n_samples, batch_size): end = min(start + batch_size, n_samples) if drop_last and (end - start) < batch_size: break X_batch = X[start:end] y_batch = y[start:end] if y is not None else None yield X_batch, y_batch