Source code for prosemble.core.activations
"""
Transfer/activation functions for prototype-based learning.
These functions are used to shape the GLVQ loss (mu values)
before summation, controlling the optimization landscape.
"""
import jax
import jax.numpy as jnp
from jax import jit
[docs]
@jit
def identity(x, beta=0.0):
    """Return the input unchanged (no-op transfer function).

    Parameters
    ----------
    x : array
        Values to pass through.
    beta : float
        Unused; accepted only so that this function shares the
        ``(x, beta)`` call signature of the other activations.

    Returns
    -------
    array
        The input ``x``, unmodified.
    """
    # ``beta`` is deliberately not referenced here.
    return x
[docs]
@jit
def sigmoid_beta(x, beta=10.0):
    """Apply a logistic sigmoid whose slope is scaled by ``beta``.

    Computes ``f(x) = 1 / (1 + exp(-beta * x))``.

    Parameters
    ----------
    x : array
        Values to transform.
    beta : float
        Steepness factor multiplied into ``x`` before the sigmoid;
        larger values make the transition around zero sharper.

    Returns
    -------
    array
        Sigmoid-transformed values in (0, 1).
    """
    # Scale first, then squash with the library sigmoid.
    scaled = beta * x
    return jax.nn.sigmoid(scaled)
[docs]
@jit
def swish_beta(x, beta=10.0):
    """Apply the swish function with a ``beta``-scaled sigmoid gate.

    Computes ``f(x) = x * sigmoid(beta * x)``.

    Parameters
    ----------
    x : array
        Values to transform.
    beta : float
        Steepness factor applied inside the sigmoid gate.

    Returns
    -------
    array
        Swish-transformed values.
    """
    # The gate is the sigmoid of the scaled input; the output is the
    # input multiplied element-wise by that gate.
    gate = jax.nn.sigmoid(beta * x)
    return x * gate
# Lookup table mapping each activation's string name to its function;
# ``get_activation`` resolves string names through this mapping.
ACTIVATIONS = dict(
    identity=identity,
    sigmoid_beta=sigmoid_beta,
    swish_beta=swish_beta,
)
[docs]
def get_activation(name):
    """Resolve an activation given either its name or the function itself.

    Parameters
    ----------
    name : str or callable
        A key of ``ACTIVATIONS`` ('identity', 'sigmoid_beta',
        'swish_beta'), or a callable, which is returned as-is.

    Returns
    -------
    callable
        The activation function.

    Raises
    ------
    ValueError
        If ``name`` is not callable and is not a key of ``ACTIVATIONS``.
    """
    # Callables bypass the registry entirely.
    if callable(name):
        return name

    # Guard clause: reject unknown names and list the valid ones.
    if name not in ACTIVATIONS:
        raise ValueError(
            f"Unknown activation '{name}'. "
            f"Available: {list(ACTIVATIONS.keys())}"
        )

    return ACTIVATIONS[name]