# Source code for prosemble.core.distributed
"""Distributed training utilities for multi-device data parallelism.
Provides functions for sharding data across devices and replicating
model parameters, enabling data-parallel training on multi-GPU/TPU setups.
"""
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
# [docs]
def create_mesh(devices=None):
    """Create a 1D device mesh for data parallelism.

    Parameters
    ----------
    devices : list of jax.Device or None
        Devices to use. If None, returns None (single-device mode).

    Returns
    -------
    Mesh or None
        A mesh whose only axis is named ``'data'`` and spans ``devices``,
        or None when ``devices`` is None.
    """
    # No device list means the caller wants plain single-device execution,
    # which is signalled by the absence of a mesh.
    if devices is None:
        return None
    return Mesh(devices, axis_names=('data',))
# [docs]
def shard_data(X, y, mesh):
    """Shard data arrays along the batch dimension across devices.

    Parameters
    ----------
    X : jnp.ndarray of shape (n_samples, ...)
        Input data.
    y : jnp.ndarray of shape (n_samples,)
        Labels.
    mesh : Mesh or None
        Device mesh from create_mesh(). None (the single-device result of
        create_mesh()) leaves the inputs untouched.

    Returns
    -------
    X_sharded, y_sharded : tuple of sharded arrays
        ``X`` and ``y`` placed with their leading axis partitioned over the
        mesh's ``'data'`` axis, or the original ``X`` and ``y`` when
        ``mesh`` is None.

    Notes
    -----
    NOTE(review): JAX is expected to reject a leading dimension that is not
    evenly divisible by the number of devices in the mesh -- confirm that
    callers pad or trim batches accordingly.
    """
    # create_mesh() returns None for single-device mode; there is nothing
    # to shard in that case, so hand the arrays back unchanged.
    if mesh is None:
        return X, y

    # P('data') partitions only the first (batch) axis.
    data_sharding = NamedSharding(mesh, P('data'))
    X_sharded = jax.device_put(X, data_sharding)
    y_sharded = jax.device_put(y, data_sharding)
    return X_sharded, y_sharded
# [docs]
def replicate_params(params, mesh):
    """Replicate params across all devices (no partitioning).

    Parameters
    ----------
    params : dict (pytree)
        Model parameters.
    mesh : Mesh or None
        Device mesh from create_mesh(). None (the single-device result of
        create_mesh()) leaves ``params`` untouched.

    Returns
    -------
    params replicated across mesh
        The same pytree structure with every leaf placed under a fully
        replicated sharding, or ``params`` itself when ``mesh`` is None.
    """
    # Single-device mode: nothing to replicate.
    if mesh is None:
        return params

    # An empty PartitionSpec partitions no axis, i.e. full replication.
    replicated = NamedSharding(mesh, P())
    return jax.device_put(params, replicated)
# [docs]
def replicate_opt_state(opt_state, mesh):
    """Replicate optimizer state across all devices.

    Parameters
    ----------
    opt_state : pytree
        Optax optimizer state.
    mesh : Mesh or None
        Device mesh from create_mesh(). None (the single-device result of
        create_mesh()) leaves ``opt_state`` untouched.

    Returns
    -------
    opt_state replicated across mesh
        The same pytree structure with every leaf placed under a fully
        replicated sharding, or ``opt_state`` itself when ``mesh`` is None.
    """
    # Single-device mode: nothing to replicate.
    if mesh is None:
        return opt_state

    # An empty PartitionSpec partitions no axis, i.e. full replication.
    replicated = NamedSharding(mesh, P())
    return jax.device_put(opt_state, replicated)
# [docs]
def unshard_params(params):
    """Bring params back to a single device.

    Used after training to store results as plain arrays for
    predict/export operations.

    Parameters
    ----------
    params : pytree
        Potentially sharded parameters.

    Returns
    -------
    params on default device
        The same pytree structure with every leaf placed on
        ``jax.devices()[0]``.
    """
    # Look the target device up once instead of once per leaf.
    target_device = jax.devices()[0]
    # jax.device_put accepts a whole pytree and places each leaf on the
    # given device, so no explicit tree traversal is needed.
    return jax.device_put(params, target_device)