# Source code for prosemble.datasets.dataset

"""
JAX-compatible dataset module.

Provides dataset loaders that return JAX arrays instead of NumPy arrays,
optimized for use with JAX-based clustering models.
"""

from dataclasses import dataclass

try:
    import jax.numpy as jnp
    from jax import Array as JaxArray
    JAX_AVAILABLE = True
except ImportError:
    JAX_AVAILABLE = False
    JaxArray = None

from sklearn.datasets import load_breast_cancer, load_iris, make_moons, make_blobs


@dataclass
class DATASET_JAX:
    """
    Container pairing a feature matrix with its label vector.

    Attributes
    ----------
    input_data : jax.numpy.ndarray
        Feature matrix as JAX array
    labels : jax.numpy.ndarray
        Labels as JAX array
    """

    input_data: 'JaxArray'
    labels: 'JaxArray'

    def to_numpy(self):
        """Return a dict holding both fields converted to NumPy arrays."""
        import numpy as np

        # Keys are inserted in field order: 'input_data' first, then 'labels'.
        converted = {}
        for field_name in ('input_data', 'labels'):
            converted[field_name] = np.array(getattr(self, field_name))
        return converted
def iris_dataset_jax(dtype=None) -> DATASET_JAX:
    """
    Load the Iris dataset and wrap it in a ``DATASET_JAX``.

    Parameters
    ----------
    dtype : jax.numpy.dtype, default=jnp.float32
        Data type for features

    Returns
    -------
    DATASET_JAX
        Features converted with ``dtype`` and labels converted to ``jnp.int32``
        (150 samples, 4 features, 3 classes)

    Raises
    ------
    ImportError
        If JAX could not be imported when this module was loaded

    Examples
    --------
    >>> from prosemble.datasets import iris_dataset_jax
    >>> dataset = iris_dataset_jax()
    >>> dataset.input_data.shape
    (150, 4)
    """
    if not JAX_AVAILABLE:
        raise ImportError("JAX is required for JAX datasets. Install with: pip install jax")

    # Resolve the default lazily: ``jnp`` only exists when JAX imported cleanly.
    feature_dtype = jnp.float32 if dtype is None else dtype

    features, targets = load_iris(return_X_y=True)
    return DATASET_JAX(
        input_data=jnp.array(features, dtype=feature_dtype),
        labels=jnp.array(targets, dtype=jnp.int32),
    )
def breast_cancer_dataset_jax(dtype=None) -> DATASET_JAX:
    """
    Load the Wisconsin Breast Cancer dataset and wrap it in a ``DATASET_JAX``.

    Parameters
    ----------
    dtype : jax.numpy.dtype, default=jnp.float32
        Data type for features

    Returns
    -------
    DATASET_JAX
        Features converted with ``dtype`` and labels converted to ``jnp.int32``
        (569 samples, 30 features, 2 classes)

    Raises
    ------
    ImportError
        If JAX could not be imported when this module was loaded

    Examples
    --------
    >>> from prosemble.datasets import breast_cancer_dataset_jax
    >>> dataset = breast_cancer_dataset_jax()
    >>> dataset.input_data.shape
    (569, 30)
    """
    if not JAX_AVAILABLE:
        raise ImportError("JAX is required for JAX datasets. Install with: pip install jax")

    # Resolve the default lazily: ``jnp`` only exists when JAX imported cleanly.
    feature_dtype = jnp.float32 if dtype is None else dtype

    features, targets = load_breast_cancer(return_X_y=True)
    return DATASET_JAX(
        input_data=jnp.array(features, dtype=feature_dtype),
        labels=jnp.array(targets, dtype=jnp.int32),
    )
def moons_dataset_jax(
    n_samples: int = 150,
    noise: float | None = None,
    random_state: int | None = None,
    dtype=None
) -> DATASET_JAX:
    """
    Generate the two-moons toy dataset and wrap it in a ``DATASET_JAX``.

    Parameters
    ----------
    n_samples : int, default=150
        Total number of points generated
    noise : float, optional
        Standard deviation of Gaussian noise added to data
    random_state : int, optional
        Random seed for reproducibility
    dtype : jax.numpy.dtype, default=jnp.float32
        Data type for features

    Returns
    -------
    DATASET_JAX
        Features converted with ``dtype`` and labels converted to ``jnp.int32``
        (n_samples, 2 features, 2 classes)

    Raises
    ------
    ImportError
        If JAX could not be imported when this module was loaded

    Examples
    --------
    >>> from prosemble.datasets import moons_dataset_jax
    >>> dataset = moons_dataset_jax(n_samples=200, noise=0.1, random_state=42)
    >>> dataset.input_data.shape
    (200, 2)
    """
    if not JAX_AVAILABLE:
        raise ImportError("JAX is required for JAX datasets. Install with: pip install jax")

    # Resolve the default lazily: ``jnp`` only exists when JAX imported cleanly.
    feature_dtype = jnp.float32 if dtype is None else dtype

    # Shuffling is always enabled; ``random_state`` makes it reproducible.
    generator_kwargs = {
        "n_samples": n_samples,
        "shuffle": True,
        "noise": noise,
        "random_state": random_state,
    }
    points, targets = make_moons(**generator_kwargs)

    return DATASET_JAX(
        input_data=jnp.array(points, dtype=feature_dtype),
        labels=jnp.array(targets, dtype=jnp.int32),
    )
def blobs_dataset_jax(
    n_samples: list | None = None,
    centers: list | None = None,
    cluster_std: list | None = None,
    random_state: int | None = None,
    dtype=None
) -> DATASET_JAX:
    """
    Generate isotropic Gaussian blobs as JAX arrays.

    Parameters
    ----------
    n_samples : list, default=[120, 80]
        Number of samples per cluster
    centers : list, default=[[0.0, 0.0], [2.0, 2.0]]
        Centers of the clusters
    cluster_std : list, default=[1.2, 0.5]
        Standard deviation of clusters
    random_state : int, optional
        Random seed for reproducibility
    dtype : jax.numpy.dtype, default=jnp.float32
        Data type for features

    Returns
    -------
    DATASET_JAX
        Dataset with JAX arrays (sum(n_samples), 2 features); labels are
        converted to ``jnp.int32``

    Raises
    ------
    ImportError
        If JAX could not be imported when this module was loaded

    Examples
    --------
    >>> from prosemble.datasets import blobs_dataset_jax
    >>> dataset = blobs_dataset_jax(random_state=42)
    >>> dataset.input_data.shape
    (200, 2)
    """
    if not JAX_AVAILABLE:
        raise ImportError("JAX is required for JAX datasets. Install with: pip install jax")

    if dtype is None:
        dtype = jnp.float32

    # ``None`` sentinels keep the list defaults from being shared across calls.
    if n_samples is None:
        n_samples = [120, 80]
    if centers is None:
        centers = [[0.0, 0.0], [2.0, 2.0]]
    if cluster_std is None:
        cluster_std = [1.2, 0.5]

    # shuffle=False: samples are requested without shuffling.
    data, labels = make_blobs(
        n_samples=n_samples,
        centers=centers,
        cluster_std=cluster_std,
        random_state=random_state,
        shuffle=False,
    )

    return DATASET_JAX(
        input_data=jnp.array(data, dtype=dtype),
        labels=jnp.array(labels, dtype=jnp.int32),
    )
@dataclass
class DATA_JAX:
    """
    JAX dataset collection with convenient property access.

    Every property calls the corresponding loader again, so each access
    returns a newly built ``DATASET_JAX``.

    Parameters
    ----------
    random : int, default=4
        Random seed for dataset generation
    sample_size : int, default=1000
        Default sample size (not currently used)
    dtype : jax.numpy.dtype, default=jnp.float32
        Data type for features (``None`` when JAX is not installed)

    Examples
    --------
    >>> from prosemble.datasets import DATA_JAX
    >>> data = DATA_JAX(random=42)
    >>> moons = data.S_1  # Moons dataset
    >>> blobs = data.S_2  # Blobs dataset
    >>> cancer = data.breast_cancer  # Breast cancer dataset
    """

    random: int = 4
    sample_size: int = 1000
    # The annotation is what makes ``dtype`` a dataclass field. Without it the
    # name is only a class attribute and ``DATA_JAX(dtype=...)`` raises
    # TypeError even though ``dtype`` is documented as a parameter.
    dtype: object = jnp.float32 if JAX_AVAILABLE else None

    @property
    def S_1(self) -> DATASET_JAX:
        """Moons dataset."""
        return moons_dataset_jax(
            n_samples=150,
            noise=None,
            random_state=self.random,
            dtype=self.dtype,
        )

    @property
    def S_2(self) -> DATASET_JAX:
        """Blobs dataset."""
        return blobs_dataset_jax(
            random_state=self.random,
            dtype=self.dtype,
        )

    @property
    def iris(self) -> DATASET_JAX:
        """Iris dataset."""
        return iris_dataset_jax(dtype=self.dtype)

    @property
    def breast_cancer(self) -> DATASET_JAX:
        """Breast cancer dataset."""
        return breast_cancer_dataset_jax(dtype=self.dtype)

    @property
    def moons(self) -> DATASET_JAX:
        """Alias for S_1 (moons dataset)."""
        return self.S_1

    @property
    def blobs(self) -> DATASET_JAX:
        """Alias for S_2 (blobs dataset)."""
        return self.S_2
# Convenience functions for quick dataset loading
def load_iris_jax(dtype=None) -> DATASET_JAX:
    """Shorthand for ``iris_dataset_jax``; forwards ``dtype`` unchanged."""
    dataset = iris_dataset_jax(dtype=dtype)
    return dataset
def load_breast_cancer_jax(dtype=None) -> DATASET_JAX:
    """Shorthand for ``breast_cancer_dataset_jax``; forwards ``dtype`` unchanged."""
    dataset = breast_cancer_dataset_jax(dtype=dtype)
    return dataset
def load_moons_jax(
    n_samples: int = 150,
    noise: float | None = None,
    random_state: int | None = None,
    dtype=None
) -> DATASET_JAX:
    """Shorthand for ``moons_dataset_jax``; forwards every argument unchanged."""
    return moons_dataset_jax(
        n_samples=n_samples,
        noise=noise,
        random_state=random_state,
        dtype=dtype,
    )
def load_blobs_jax(random_state: int | None = None, dtype=None) -> DATASET_JAX:
    """Shorthand for ``blobs_dataset_jax`` using its default cluster layout."""
    # Only the seed and dtype are forwarded; the other arguments are left unset.
    forwarded = {"random_state": random_state, "dtype": dtype}
    return blobs_dataset_jax(**forwarded)
# Backward-compatible alias DATA = DATA_JAX