Source code for prosemble.core.onnx_export

"""
ONNX export for prosemble prototype-based models.

Converts a fitted model's predict function into an ONNX graph.
Only supports models whose distance function can be expressed with
standard ONNX operators.  Unsupported models raise
``NotImplementedError`` with a clear message.

Supported distance functions:

- ``squared_euclidean_distance_matrix``
- ``euclidean_distance_matrix``
- ``manhattan_distance_matrix``
- ``omega_distance_matrix`` (global projection matrix)
- ``lomega_distance_matrix`` (per-prototype local matrices)
- ``tangent_distance_matrix`` (per-prototype tangent subspace)
- ``relevance_weighted`` (per-feature relevance weighting)

Supported decision patterns:

- WTAC (supervised classification)
- ArgMin (unsupervised clustering)
- One-class hard nearest (OCGLVQ family)
- One-class Gaussian soft (OCRSLVQ family)
- One-class Gaussian+NG soft (OCRSLVQ_NG family)
- SVQ-OCC response model (SVQOCC family)
- CBC reasoning (CBC, ImageCBC)
- PLVQ Gaussian mixture soft assignment

Supported encoder models:

- MLP encoder (SiameseGLVQ, SiameseGMLVQ, SiameseGTLVQ, LVQMLN, PLVQ)
- CNN encoder (ImageGLVQ, ImageGMLVQ, ImageGTLVQ, ImageCBC)

Not supported:

- ``gaussian_kernel_matrix``, ``polynomial_kernel_matrix``
- Riemannian manifold distances (logm, expm have no ONNX equivalent)
"""

from __future__ import annotations

from functools import partial
from typing import Any

import numpy as np


def _check_onnx_installed():
    """Ensure the optional ``onnx`` dependency can be imported.

    Raises
    ------
    ImportError
        If ``import onnx`` fails. The message names the install extra and
        the original import failure is attached as ``__cause__``.
    """
    try:
        import onnx  # noqa: F401
    except ImportError as exc:
        # Chain explicitly so the traceback shows the underlying import
        # failure as the direct cause of this friendlier error.
        raise ImportError(
            "ONNX export requires the 'onnx' package. "
            "Install with: pip install prosemble[onnx]"
        ) from exc


# ---------------------------------------------------------------------------
# Numpy forward functions (for pre-computing latent prototypes at export time)
# ---------------------------------------------------------------------------

def _get_activation_np(name):
    """Return a numpy activation function by name.

    Parameters
    ----------
    name : str
        One of ``'sigmoid'``, ``'relu'``, ``'tanh'``, ``'leaky_relu'``
        or ``'selu'``.

    Returns
    -------
    callable
        Element-wise function taking an array and returning an array.

    Raises
    ------
    ValueError
        If ``name`` is not one of the supported activations.
    """
    if name == 'sigmoid':
        return lambda z: 1.0 / (1.0 + np.exp(-np.clip(z, -500, 500)))
    elif name == 'relu':
        return lambda z: np.maximum(0, z)
    elif name == 'tanh':
        return np.tanh
    elif name == 'leaky_relu':
        return lambda z: np.where(z > 0, z, 0.01 * z)
    elif name == 'selu':
        alpha = 1.6732632423543772
        scale = 1.0507009873554805
        # np.where evaluates both branches for every element, so only the
        # non-positive part is fed to the exponential: large positive
        # inputs no longer overflow. expm1 also keeps precision near zero.
        return lambda z: scale * np.where(
            z > 0, z, alpha * np.expm1(np.minimum(z, 0)))
    else:
        raise ValueError(f"Unknown activation: {name}")


def _mlp_forward_np(params, x, activation='sigmoid'):
    """Evaluate an MLP in numpy, in float32.

    Every layer computes ``activation(h @ weight + bias)``; the activation
    is applied after each layer, including the last one.

    Parameters
    ----------
    params : list of (weight, bias) tuples
        Layer parameters; each array is converted with ``np.asarray``.
    x : array of shape (n, d_in)
    activation : str
        Name understood by ``_get_activation_np``.

    Returns
    -------
    numpy array of shape (n, d_out)
    """
    activate = _get_activation_np(activation)
    hidden = np.asarray(x, dtype=np.float32)
    for weight, bias in params:
        weight_f32 = np.asarray(weight, dtype=np.float32)
        bias_f32 = np.asarray(bias, dtype=np.float32)
        pre_activation = hidden @ weight_f32 + bias_f32
        hidden = activate(pre_activation)
    return hidden


def _cnn_forward_np(params, x, activation='relu'):
    """Numpy CNN forward pass (for pre-computing latent prototypes).

    Applies every conv layer (stride 1, SAME padding, bias, activation),
    then global average pooling over the two spatial axes, then a linear
    head followed by the same activation.

    Parameters
    ----------
    params : dict with 'conv_layers' and 'linear'
        ``params['conv_layers']`` is an iterable of ``(kernel, bias)``
        pairs with kernels laid out as (kH, kW, C_in, C_out);
        ``params['linear']`` is a ``(weight, bias)`` pair.
    x : array of shape (N, H, W, C) — NHWC format
    activation : str
        Name understood by ``_get_activation_np``.

    Returns
    -------
    numpy array of shape (N, latent_dim)

    Notes
    -----
    The ``k - 1`` rows/columns of SAME padding are split so that the
    smaller half goes before the data and the remainder after it. This is
    the ``auto_pad='SAME_UPPER'`` convention used by the Conv nodes that
    ``_cnn_encoder_onnx`` emits, so both paths agree for even kernel sizes
    as well as odd ones.
    """
    act_fn = _get_activation_np(activation)
    x = np.asarray(x, dtype=np.float32)

    for kernel, bias in params['conv_layers']:
        kernel = np.asarray(kernel, dtype=np.float32)  # (kH, kW, C_in, C_out)
        bias = np.asarray(bias, dtype=np.float32)       # (C_out,)
        N, H, W, _ = x.shape
        kH, kW = kernel.shape[:2]
        # SAME padding, stride 1: total padding per axis is k - 1. Using
        # (k - 1) // 2 in front puts the extra element at the end when
        # k - 1 is odd (SAME_UPPER); k // 2 would put it at the beginning.
        pad_top, pad_left = (kH - 1) // 2, (kW - 1) // 2
        x_pad = np.pad(x, ((0, 0), (pad_top, kH - 1 - pad_top),
                            (pad_left, kW - 1 - pad_left), (0, 0)))
        out = np.zeros((N, H, W, kernel.shape[3]), dtype=np.float32)
        # Accumulate one kernel tap at a time (cross-correlation, no flip).
        for i in range(kH):
            for j in range(kW):
                out += np.einsum('nhwi,io->nhwo',
                                 x_pad[:, i:i + H, j:j + W, :],
                                 kernel[i, j])
        out += bias
        x = act_fn(out)

    # Global average pooling: (N, H, W, C) -> (N, C)
    x = np.mean(x, axis=(1, 2))

    # Linear head
    w, b = params['linear']
    x = act_fn(x @ np.asarray(w, dtype=np.float32)
               + np.asarray(b, dtype=np.float32))
    return x


# ---------------------------------------------------------------------------
# ONNX encoder builders
# ---------------------------------------------------------------------------

def _activation_node(input_name, output_name, activation):
    """Build the single ONNX node implementing ``activation``.

    Parameters
    ----------
    input_name : str
        Name of the tensor fed to the activation.
    output_name : str
        Name given to the activation's output tensor.
    activation : str
        One of ``'sigmoid'``, ``'relu'``, ``'tanh'``, ``'leaky_relu'``
        or ``'selu'``.

    Returns
    -------
    onnx.NodeProto

    Raises
    ------
    ValueError
        If ``activation`` is not one of the names above.
    """
    import onnx.helper as oh

    # (activation key, ONNX op type, extra node attributes)
    known_ops = (
        ('sigmoid', 'Sigmoid', {}),
        ('relu', 'Relu', {}),
        ('tanh', 'Tanh', {}),
        ('leaky_relu', 'LeakyRelu', {'alpha': 0.01}),
        ('selu', 'Selu', {}),
    )
    for key, op_type, attrs in known_ops:
        if activation == key:
            return oh.make_node(op_type, [input_name], [output_name],
                                **attrs)
    raise ValueError(f"Unknown activation for ONNX: {activation}")


def _mlp_encoder_onnx(input_name, params, activation):
    """Emit the ONNX nodes and initializers of an MLP encoder.

    Each ``(weight, bias)`` layer becomes ``MatMul -> Add -> activation``.
    The activation follows every layer, including the last one, whose
    output tensor is named ``'_enc_out'``.

    Parameters
    ----------
    input_name : str
        Name of the tensor fed to the first layer (e.g. 'X').
    params : list of (weight, bias) tuples
        Layer parameters; stored as float32 initializers.
    activation : str
        Name understood by ``_activation_node``.

    Returns
    -------
    nodes : list of onnx.NodeProto
    initializers : list of onnx.TensorProto
    output_name : str
        Always ``'_enc_out'``.
    """
    import onnx.helper as oh
    from onnx import TensorProto

    def _as_float_tensor(name, values):
        # Store an array-like as a float32 initializer under ``name``.
        arr = np.asarray(values, dtype=np.float32)
        return oh.make_tensor(
            name, TensorProto.FLOAT, list(arr.shape),
            arr.flatten().tolist(),
        )

    graph_nodes = []
    initializers = []
    upstream = input_name

    for layer, (weight, bias) in enumerate(params):
        weight_name, bias_name = f'_enc_w_{layer}', f'_enc_b_{layer}'
        matmul_out, add_out = f'_enc_mm_{layer}', f'_enc_add_{layer}'
        # Only the final layer's activation gets the public output name.
        if layer == len(params) - 1:
            layer_out = '_enc_out'
        else:
            layer_out = f'_enc_act_{layer}'

        initializers += [
            _as_float_tensor(weight_name, weight),
            _as_float_tensor(bias_name, bias),
        ]
        graph_nodes += [
            oh.make_node('MatMul', [upstream, weight_name], [matmul_out]),
            oh.make_node('Add', [matmul_out, bias_name], [add_out]),
            _activation_node(add_out, layer_out, activation),
        ]
        upstream = layer_out

    return graph_nodes, initializers, '_enc_out'


def _cnn_encoder_onnx(input_name, params, input_shape, activation):
    """Emit the ONNX nodes and initializers of a CNN encoder.

    The graph reshapes the flat input to NHWC, transposes it to NCHW,
    applies each conv layer (``Conv`` with ``auto_pad='SAME_UPPER'`` and
    unit strides, then the activation), global-average-pools, flattens,
    and finishes with a linear head ``MatMul -> Add -> activation``.

    Parameters
    ----------
    input_name : str
        Name of the flat input tensor (batch, H*W*C).
    params : dict with 'conv_layers' and 'linear'
        Conv kernels are given as (kH, kW, C_in, C_out) and are stored
        transposed to (C_out, C_in, kH, kW).
    input_shape : tuple
        (H, W, C) of the original images.
    activation : str
        Name understood by ``_activation_node``.

    Returns
    -------
    nodes : list of onnx.NodeProto
    initializers : list of onnx.TensorProto
    output_name : str
        Always ``'_enc_out'``.
    """
    import onnx.helper as oh
    from onnx import TensorProto

    def _as_float_tensor(name, arr):
        # Store a float32 array as an initializer under ``name``.
        return oh.make_tensor(
            name, TensorProto.FLOAT, list(arr.shape),
            arr.flatten().tolist(),
        )

    height, width, channels = input_shape

    # Flat (batch, H*W*C) -> (batch, H, W, C) -> (batch, C, H, W).
    target_shape = np.array([-1, height, width, channels], dtype=np.int64)
    initializers = [
        oh.make_tensor(
            '_enc_nhwc_shape', TensorProto.INT64, [4],
            target_shape.tolist(),
        ),
    ]
    graph_nodes = [
        oh.make_node(
            'Reshape', [input_name, '_enc_nhwc_shape'], ['_enc_nhwc'],
        ),
        oh.make_node(
            'Transpose', ['_enc_nhwc'], ['_enc_nchw'],
            perm=[0, 3, 1, 2],
        ),
    ]

    upstream = '_enc_nchw'
    for layer, (kernel, bias) in enumerate(params['conv_layers']):
        # (kH, kW, C_in, C_out) -> (C_out, C_in, kH, kW)
        kernel_oihw = np.asarray(kernel, dtype=np.float32).transpose(
            3, 2, 0, 1)
        bias_np = np.asarray(bias, dtype=np.float32)
        kernel_name, bias_name = f'_enc_ck_{layer}', f'_enc_cb_{layer}'
        conv_out = f'_enc_conv_{layer}'
        activated = f'_enc_conv_act_{layer}'

        initializers.append(_as_float_tensor(kernel_name, kernel_oihw))
        initializers.append(_as_float_tensor(bias_name, bias_np))
        graph_nodes.append(oh.make_node(
            'Conv', [upstream, kernel_name, bias_name], [conv_out],
            auto_pad='SAME_UPPER', strides=[1, 1],
        ))
        graph_nodes.append(_activation_node(conv_out, activated, activation))
        upstream = activated

    # (N, C_last, H, W) -> (N, C_last, 1, 1) -> (N, C_last)
    graph_nodes.append(
        oh.make_node('GlobalAveragePool', [upstream], ['_enc_gap']))
    graph_nodes.append(
        oh.make_node('Flatten', ['_enc_gap'], ['_enc_flat'], axis=1))

    # Linear head.
    head_weight, head_bias = params['linear']
    initializers.append(_as_float_tensor(
        '_enc_lin_w', np.asarray(head_weight, dtype=np.float32)))
    initializers.append(_as_float_tensor(
        '_enc_lin_b', np.asarray(head_bias, dtype=np.float32)))
    graph_nodes.extend([
        oh.make_node('MatMul', ['_enc_flat', '_enc_lin_w'],
                     ['_enc_lin_mm']),
        oh.make_node('Add', ['_enc_lin_mm', '_enc_lin_b'],
                     ['_enc_lin_add']),
        _activation_node('_enc_lin_add', '_enc_out', activation),
    ])

    return graph_nodes, initializers, '_enc_out'


# ---------------------------------------------------------------------------
# Distance function -> ONNX subgraph builders
# ---------------------------------------------------------------------------

def _squared_euclidean_onnx(builder, X_name, proto_name):
    """Build the ONNX subgraph for squared Euclidean distances.

    Uses ``||x - w||^2 = ||x||^2 + ||w||^2 - 2 * x @ w^T`` and passes the
    result through ``Relu`` so small negative round-off values become 0.

    Parameters
    ----------
    builder : object
        Not used by this function.
    X_name : str
        Name of the (batch, features) input tensor.
    proto_name : str
        Name of the (n_proto, features) prototype tensor.

    Returns
    -------
    nodes, initializers, output_name
        ``output_name`` is ``'distances'``, shape (batch, n_proto).
    """
    import onnx.helper as oh
    from onnx import TensorProto

    graph_nodes = [
        # W^T -> (features, n_proto)
        oh.make_node('Transpose', [proto_name], ['_w_t']),
        # ||x||^2 per row -> (batch, 1)
        oh.make_node('Mul', [X_name, X_name], ['_x_sq_full']),
        oh.make_node(
            'ReduceSum', ['_x_sq_full', '_axis1_const'],
            ['_x_sq'], keepdims=1,
        ),
        # ||w||^2 per prototype -> (n_proto, 1), then (1, n_proto)
        oh.make_node('Mul', [proto_name, proto_name], ['_w_sq_full']),
        oh.make_node(
            'ReduceSum', ['_w_sq_full', '_axis1_const'],
            ['_w_sq'], keepdims=1,
        ),
        oh.make_node('Transpose', ['_w_sq'], ['_w_sq_t']),
        # Cross term X @ W^T -> (batch, n_proto)
        oh.make_node('MatMul', [X_name, '_w_t'], ['_xw']),
        # ||x||^2 + ||w||^2 - 2 * x.w
        oh.make_node('Mul', ['_xw', '_two_const'], ['_2xw']),
        oh.make_node('Add', ['_x_sq', '_w_sq_t'], ['_xsq_wsq']),
        oh.make_node('Sub', ['_xsq_wsq', '_2xw'], ['_dist_raw']),
        # Clamp to >= 0
        oh.make_node('Relu', ['_dist_raw'], ['distances']),
    ]

    constants = [
        oh.make_tensor('_two_const', TensorProto.FLOAT, [], [2.0]),
        oh.make_tensor('_axis1_const', TensorProto.INT64, [1], [1]),
    ]
    return graph_nodes, constants, 'distances'


def _euclidean_onnx(builder, X_name, proto_name):
    """Build the Euclidean distance subgraph.

    Reuses the squared Euclidean subgraph and appends a ``Sqrt`` node.

    Returns
    -------
    nodes, initializers, output_name
        ``output_name`` is ``'distances_eucl'``.
    """
    import onnx.helper as oh

    graph_nodes, initializers, squared_name = _squared_euclidean_onnx(
        builder, X_name, proto_name
    )
    sqrt_node = oh.make_node('Sqrt', [squared_name], ['distances_eucl'])
    graph_nodes.append(sqrt_node)
    return graph_nodes, initializers, 'distances_eucl'


def _manhattan_onnx(builder, X_name, proto_name):
    """Build the ONNX subgraph for the Manhattan (L1) distance matrix.

    ``X`` is unsqueezed at axis 1 and the prototypes at axis 0 so that
    their difference broadcasts to (batch, n_proto, features); the
    absolute differences are then summed over axis 2.

    Parameters
    ----------
    builder : object
        Not used by this function.
    X_name : str
        Name of the input tensor.
    proto_name : str
        Name of the prototype tensor.

    Returns
    -------
    nodes, initializers, output_name
        ``output_name`` is ``'distances_manh'``.
    """
    import onnx.helper as oh
    from onnx import TensorProto

    def _axis(name, value):
        # One-element INT64 tensor holding an axis index.
        return oh.make_tensor(name, TensorProto.INT64, [1], [value])

    graph_nodes = [
        oh.make_node('Unsqueeze', [X_name, '_axis1_const'], ['_x_exp']),
        oh.make_node('Unsqueeze', [proto_name, '_axis0_const'], ['_w_exp']),
        oh.make_node('Sub', ['_x_exp', '_w_exp'], ['_diff']),
        oh.make_node('Abs', ['_diff'], ['_abs_diff']),
        oh.make_node(
            'ReduceSum', ['_abs_diff', '_axis2_const'],
            ['distances_manh'], keepdims=0,
        ),
    ]
    axis_constants = [
        _axis('_axis0_const', 0),
        _axis('_axis1_const', 1),
        _axis('_axis2_const', 2),
    ]
    return graph_nodes, axis_constants, 'distances_manh'


def _omega_onnx(builder, X_name, proto_name, omega_name):
    """Build the omega distance subgraph: ``||X @ omega - W @ omega||^2``.

    Both inputs are projected with ``omega`` and the squared Euclidean
    subgraph is reused on the projections. Every tensor name produced by
    that inner subgraph is prefixed with ``'_om'`` so it cannot collide
    with names used elsewhere in the graph.

    Parameters
    ----------
    builder : object
        Forwarded to ``_squared_euclidean_onnx``.
    X_name : str
        Name of the input tensor.
    proto_name : str
        Name of the prototype tensor.
    omega_name : str
        Name of the projection matrix tensor.

    Returns
    -------
    nodes, initializers, output_name
    """
    import onnx.helper as oh

    prefix = '_om'
    projected_x = '_om_x_proj'
    projected_w = '_om_w_proj'

    graph_nodes = [
        oh.make_node('MatMul', [X_name, omega_name], [projected_x]),
        oh.make_node('MatMul', [proto_name, omega_name], [projected_w]),
    ]

    inner_nodes, inner_inits, inner_out = _squared_euclidean_onnx(
        builder, projected_x, projected_w
    )

    # The projection outputs already carry the prefix; leave them alone.
    untouched = (projected_x, projected_w)
    for node in inner_nodes:
        for pos, name in enumerate(list(node.output)):
            node.output[pos] = prefix + name
        for pos, name in enumerate(list(node.input)):
            if name not in untouched and name.startswith('_'):
                node.input[pos] = prefix + name

    # Rebuild the constants under their prefixed names so they match the
    # renamed node inputs.
    prefixed_inits = [
        oh.make_tensor(
            prefix + tensor.name, tensor.data_type, list(tensor.dims),
            list(tensor.int64_data or tensor.float_data),
        )
        for tensor in inner_inits
    ]

    graph_nodes.extend(inner_nodes)
    return graph_nodes, prefixed_inits, prefix + inner_out


def _relevance_weighted_onnx(builder, X_name, proto_name, relevances_name):
    r"""Build the relevance-weighted squared Euclidean subgraph.

    Computes :math:`\sum_j \lambda_j (x_j - w_j)^2` by broadcasting the
    difference to (batch, n_proto, features), squaring it, multiplying by
    the relevance tensor and summing over axis 2.

    Parameters
    ----------
    builder : object
        Not used by this function.
    X_name : str
        Name of the input tensor.
    proto_name : str
        Name of the prototype tensor.
    relevances_name : str
        Name of the (d,) relevance vector initializer. It is used as
        given; no normalisation is applied here.

    Returns
    -------
    nodes, initializers, output_name
        Output shape: (batch, n_proto)
    """
    import onnx.helper as oh
    from onnx import TensorProto

    def _axis(name, value):
        # One-element INT64 tensor holding an axis index.
        return oh.make_tensor(name, TensorProto.INT64, [1], [value])

    graph_nodes = [
        oh.make_node('Unsqueeze', [X_name, '_rw_axis1'], ['_rw_x_exp']),
        oh.make_node('Unsqueeze', [proto_name, '_rw_axis0'], ['_rw_w_exp']),
        oh.make_node('Sub', ['_rw_x_exp', '_rw_w_exp'], ['_rw_diff']),
        oh.make_node('Mul', ['_rw_diff', '_rw_diff'], ['_rw_diff_sq']),
        # The relevance vector broadcasts over the trailing feature axis.
        oh.make_node(
            'Mul', [relevances_name, '_rw_diff_sq'], ['_rw_weighted'],
        ),
        oh.make_node(
            'ReduceSum', ['_rw_weighted', '_rw_axis2'],
            ['distances_rw'], keepdims=0,
        ),
    ]
    axis_constants = [
        _axis('_rw_axis0', 0),
        _axis('_rw_axis1', 1),
        _axis('_rw_axis2', 2),
    ]
    return graph_nodes, axis_constants, 'distances_rw'


def _local_omega_onnx(builder, X_name, proto_name, omegas_name):
    r"""Build the local omega distance subgraph.

    For each prototype k the difference ``x - w_k`` is multiplied by that
    prototype's matrix and the squared norm of the result is returned.
    The differences are arranged prototype-first so one batched MatMul,
    (p, n, d) @ (p, d, l) -> (p, n, l), handles all prototypes.

    Parameters
    ----------
    builder : object
        Not used by this function.
    X_name : str
        Name of the input tensor.
    proto_name : str
        Name of the prototype tensor.
    omegas_name : str
        Name of the (p, d, l) per-prototype omega matrices initializer.

    Returns
    -------
    nodes, initializers, output_name
        Output shape: (batch, n_proto)
    """
    import onnx.helper as oh
    from onnx import TensorProto

    def _axis(name, value):
        # One-element INT64 tensor holding an axis index.
        return oh.make_tensor(name, TensorProto.INT64, [1], [value])

    graph_nodes = [
        # Broadcast difference -> (batch, n_proto, features)
        oh.make_node('Unsqueeze', [X_name, '_lo_axis1'], ['_lo_x_exp']),
        oh.make_node('Unsqueeze', [proto_name, '_lo_axis0'], ['_lo_w_exp']),
        oh.make_node('Sub', ['_lo_x_exp', '_lo_w_exp'], ['_lo_diff']),
        # Prototype axis first so it lines up with the omega stack.
        oh.make_node(
            'Transpose', ['_lo_diff'], ['_lo_diff_t'],
            perm=[1, 0, 2],
        ),
        oh.make_node(
            'MatMul', ['_lo_diff_t', omegas_name], ['_lo_projected'],
        ),
        # Squared norm over the projected axis -> (p, n)
        oh.make_node(
            'Mul', ['_lo_projected', '_lo_projected'], ['_lo_proj_sq'],
        ),
        oh.make_node(
            'ReduceSum', ['_lo_proj_sq', '_lo_axis2'],
            ['_lo_dist_t'], keepdims=0,
        ),
        # Back to (batch, n_proto)
        oh.make_node(
            'Transpose', ['_lo_dist_t'], ['distances_lo'],
            perm=[1, 0],
        ),
    ]
    axis_constants = [
        _axis('_lo_axis0', 0),
        _axis('_lo_axis1', 1),
        _axis('_lo_axis2', 2),
    ]
    return graph_nodes, axis_constants, 'distances_lo'


def _tangent_onnx(builder, X_name, proto_name, omegas_name):
    r"""Build the tangent distance subgraph.

    For each prototype k this computes
    :math:`\|(I - \Omega_k \Omega_k^T)(x - w_k)\|^2`: the difference is
    projected onto the prototype's subspace, reconstructed, and the
    squared norm of the residual is returned.

    Batched MatMul shapes:
      proj  = (p, n, d) @ (p, d, s) -> (p, n, s)
      recon = (p, n, s) @ (p, s, d) -> (p, n, d)

    Parameters
    ----------
    builder : object
        Not used by this function.
    X_name : str
        Name of the input tensor.
    proto_name : str
        Name of the prototype tensor.
    omegas_name : str
        Name of the (p, d, s) per-prototype orthonormal subspace bases.

    Returns
    -------
    nodes, initializers, output_name
        Output shape: (batch, n_proto)
    """
    import onnx.helper as oh
    from onnx import TensorProto

    def _axis(name, value):
        # One-element INT64 tensor holding an axis index.
        return oh.make_tensor(name, TensorProto.INT64, [1], [value])

    graph_nodes = [
        # Broadcast difference -> (batch, n_proto, features)
        oh.make_node('Unsqueeze', [X_name, '_tg_axis1'], ['_tg_x_exp']),
        oh.make_node('Unsqueeze', [proto_name, '_tg_axis0'], ['_tg_w_exp']),
        oh.make_node('Sub', ['_tg_x_exp', '_tg_w_exp'], ['_tg_diff']),
        # Prototype axis first -> (n_proto, batch, features)
        oh.make_node(
            'Transpose', ['_tg_diff'], ['_tg_diff_t'],
            perm=[1, 0, 2],
        ),
        # Coordinates inside the subspace -> (p, n, s)
        oh.make_node('MatMul', ['_tg_diff_t', omegas_name], ['_tg_proj']),
        # Bases transposed to (p, s, d) for the reconstruction
        oh.make_node(
            'Transpose', [omegas_name], ['_tg_omegas_T'],
            perm=[0, 2, 1],
        ),
        oh.make_node(
            'MatMul', ['_tg_proj', '_tg_omegas_T'], ['_tg_recon'],
        ),
        # Residual orthogonal to the subspace, then its squared norm
        oh.make_node(
            'Sub', ['_tg_diff_t', '_tg_recon'], ['_tg_tang_diff'],
        ),
        oh.make_node(
            'Mul', ['_tg_tang_diff', '_tg_tang_diff'], ['_tg_tang_sq'],
        ),
        oh.make_node(
            'ReduceSum', ['_tg_tang_sq', '_tg_axis2'],
            ['_tg_dist_t'], keepdims=0,
        ),
        # Back to (batch, n_proto)
        oh.make_node(
            'Transpose', ['_tg_dist_t'], ['distances_tg'],
            perm=[1, 0],
        ),
    ]
    axis_constants = [
        _axis('_tg_axis0', 0),
        _axis('_tg_axis1', 1),
        _axis('_tg_axis2', 2),
    ]
    return graph_nodes, axis_constants, 'distances_tg'


# ---------------------------------------------------------------------------
# Differentiating Kernel distance builders
# ---------------------------------------------------------------------------

def _kernel_per_proto_onnx(builder, X_name, proto_name, sigmas_name):
    r"""Build the Gaussian kernel distance subgraph (per-prototype scale).

    Target formula:

    .. math::

        d_\kappa^2(x, w_k) = 2\left(1 - \exp\left(
            -\frac{\|x - w_k\|^2}{2\sigma_k^2}
        \right)\right)

    The graph computes ``2 * (1 - exp(sq_norms * S))`` where ``S`` is the
    tensor named by ``sigmas_name``.

    Parameters
    ----------
    builder : object
        Not used by this function.
    X_name : str
        Name of the input tensor.
    proto_name : str
        Name of the prototype tensor.
    sigmas_name : str
        Name of the (n_proto,) per-prototype initializer that is
        multiplied directly with the squared distances before ``Exp``.
        NOTE(review): to realise the formula above it has to hold
        ``-1 / (2 * sigma_k**2)`` rather than ``sigma_k`` itself; confirm
        against the caller that prepares this initializer.

    Returns
    -------
    nodes, initializers, output_name
        Output shape: (batch, n_proto). Values lie in [0, 2] when the
        multiplier is non-positive.
    """
    import onnx.helper as oh
    from onnx import TensorProto

    def _axis(name, value):
        # One-element INT64 tensor holding an axis index.
        return oh.make_tensor(name, TensorProto.INT64, [1], [value])

    def _scalar(name, value):
        # Rank-0 FLOAT constant.
        return oh.make_tensor(name, TensorProto.FLOAT, [], [value])

    graph_nodes = [
        # Squared norms of the broadcast difference -> (batch, n_proto)
        oh.make_node('Unsqueeze', [X_name, '_kpp_axis1'], ['_kpp_x_exp']),
        oh.make_node(
            'Unsqueeze', [proto_name, '_kpp_axis0'], ['_kpp_w_exp'],
        ),
        oh.make_node('Sub', ['_kpp_x_exp', '_kpp_w_exp'], ['_kpp_diff']),
        oh.make_node('Mul', ['_kpp_diff', '_kpp_diff'], ['_kpp_diff_sq']),
        oh.make_node(
            'ReduceSum', ['_kpp_diff_sq', '_kpp_axis2'],
            ['_kpp_sq_norms'], keepdims=0,
        ),
        # Per-prototype scaling broadcasts over the batch axis.
        oh.make_node(
            'Mul', ['_kpp_sq_norms', sigmas_name], ['_kpp_scaled'],
        ),
        oh.make_node('Exp', ['_kpp_scaled'], ['_kpp_K']),
        # 2 * (1 - K)
        oh.make_node('Sub', ['_kpp_one', '_kpp_K'], ['_kpp_one_minus_K']),
        oh.make_node(
            'Mul', ['_kpp_two', '_kpp_one_minus_K'], ['distances_kpp'],
        ),
    ]
    constants = [
        _axis('_kpp_axis0', 0),
        _axis('_kpp_axis1', 1),
        _axis('_kpp_axis2', 2),
        _scalar('_kpp_one', 1.0),
        _scalar('_kpp_two', 2.0),
    ]
    return graph_nodes, constants, 'distances_kpp'


def _kernel_relevance_onnx(builder, X_name, proto_name, sigmas_name,
                           relevances_name):
    r"""Build the relevance-weighted Gaussian kernel distance subgraph.

    Target formula:

    .. math::

        d_\kappa^2(x, w_k) = 2\left(1 - \exp\left(
            -\frac{\sum_j \lambda_j (x_j - w_{kj})^2}{2\sigma_k^2}
        \right)\right)

    where :math:`\lambda = \text{softmax}(\text{relevances})`. The graph
    computes ``2 * (1 - exp(weighted_norms * S))`` where ``S`` is the
    tensor named by ``sigmas_name``.

    Parameters
    ----------
    builder : object
        Not used by this function.
    X_name : str
        Name of the input tensor.
    proto_name : str
        Name of the prototype tensor.
    sigmas_name : str
        Name of the (n_proto,) per-prototype initializer that is
        multiplied directly with the weighted squared distances before
        ``Exp``. NOTE(review): to realise the formula above it has to
        hold ``-1 / (2 * sigma_k**2)`` rather than ``sigma_k`` itself;
        confirm against the caller that prepares this initializer.
    relevances_name : str
        Name of the (n_features,) raw relevance logits initializer.
        Softmax (axis 0) is applied inside this subgraph.

    Returns
    -------
    nodes, initializers, output_name
        Output shape: (batch, n_proto). Values lie in [0, 2] when the
        multiplier is non-positive.
    """
    import onnx.helper as oh
    from onnx import TensorProto

    def _axis(name, value):
        # One-element INT64 tensor holding an axis index.
        return oh.make_tensor(name, TensorProto.INT64, [1], [value])

    def _scalar(name, value):
        # Rank-0 FLOAT constant.
        return oh.make_tensor(name, TensorProto.FLOAT, [], [value])

    graph_nodes = [
        # Normalise the raw relevance logits.
        oh.make_node(
            'Softmax', [relevances_name], ['_kr_lambdas'], axis=0,
        ),
        # Broadcast difference -> (batch, n_proto, features), squared
        oh.make_node('Unsqueeze', [X_name, '_kr_axis1'], ['_kr_x_exp']),
        oh.make_node('Unsqueeze', [proto_name, '_kr_axis0'], ['_kr_w_exp']),
        oh.make_node('Sub', ['_kr_x_exp', '_kr_w_exp'], ['_kr_diff']),
        oh.make_node('Mul', ['_kr_diff', '_kr_diff'], ['_kr_diff_sq']),
        # Weight each feature, then sum over features -> (batch, n_proto)
        oh.make_node(
            'Mul', ['_kr_lambdas', '_kr_diff_sq'], ['_kr_weighted'],
        ),
        oh.make_node(
            'ReduceSum', ['_kr_weighted', '_kr_axis2'],
            ['_kr_weighted_norms'], keepdims=0,
        ),
        # Per-prototype scaling, exponential, then 2 * (1 - K)
        oh.make_node(
            'Mul', ['_kr_weighted_norms', sigmas_name], ['_kr_scaled'],
        ),
        oh.make_node('Exp', ['_kr_scaled'], ['_kr_K']),
        oh.make_node('Sub', ['_kr_one', '_kr_K'], ['_kr_one_minus_K']),
        oh.make_node(
            'Mul', ['_kr_two', '_kr_one_minus_K'], ['distances_kr'],
        ),
    ]
    constants = [
        _axis('_kr_axis0', 0),
        _axis('_kr_axis1', 1),
        _axis('_kr_axis2', 2),
        _scalar('_kr_one', 1.0),
        _scalar('_kr_two', 2.0),
    ]
    return graph_nodes, constants, 'distances_kr'


def _kernel_exponential_onnx(builder, X_name, proto_name, omega_hat_name):
    r"""Build the exponential kernel distance subgraph.

    .. math::

        \hat\Lambda = \hat\Omega \hat\Omega^T

        d_\kappa^2(x, w) = \exp(x^T \hat\Lambda x)
                          + \exp(w^T \hat\Lambda w)
                          - 2 \exp(x^T \hat\Lambda w)

    The result is passed through ``Relu`` so negative values are clamped
    to zero.

    Parameters
    ----------
    builder : object
        Not used by this function.
    X_name : str
        Name of the input tensor.
    proto_name : str
        Name of the prototype tensor.
    omega_hat_name : str
        Name of the (n_features, latent_dim) transformation matrix
        initializer.

    Returns
    -------
    nodes, initializers, output_name
        Output shape: (batch, n_proto).
    """
    import onnx.helper as oh
    from onnx import TensorProto

    graph_nodes = [
        # Lambda_hat = omega_hat @ omega_hat^T -> (d, d)
        oh.make_node('Transpose', [omega_hat_name], ['_ke_omega_hat_T']),
        oh.make_node(
            'MatMul', [omega_hat_name, '_ke_omega_hat_T'],
            ['_ke_lambda_hat'],
        ),
        # x^T Lambda x per row -> (batch, 1)
        oh.make_node('MatMul', [X_name, '_ke_lambda_hat'], ['_ke_XL']),
        oh.make_node('Mul', [X_name, '_ke_XL'], ['_ke_X_XL']),
        oh.make_node(
            'ReduceSum', ['_ke_X_XL', '_ke_axis1'],
            ['_ke_xLx'], keepdims=1,
        ),
        # w^T Lambda w per prototype -> (n_proto, 1), then (1, n_proto)
        oh.make_node('MatMul', [proto_name, '_ke_lambda_hat'], ['_ke_WL']),
        oh.make_node('Mul', [proto_name, '_ke_WL'], ['_ke_W_WL']),
        oh.make_node(
            'ReduceSum', ['_ke_W_WL', '_ke_axis1'],
            ['_ke_wLw'], keepdims=1,
        ),
        oh.make_node('Transpose', ['_ke_wLw'], ['_ke_wLw_T']),
        # Cross term x^T Lambda w -> (batch, n_proto)
        oh.make_node('Transpose', [proto_name], ['_ke_W_T']),
        oh.make_node('MatMul', ['_ke_XL', '_ke_W_T'], ['_ke_xLw']),
        # Exponentials of the three quadratic forms
        oh.make_node('Exp', ['_ke_xLx'], ['_ke_exp_xLx']),
        oh.make_node('Exp', ['_ke_wLw_T'], ['_ke_exp_wLw']),
        oh.make_node('Exp', ['_ke_xLw'], ['_ke_exp_xLw']),
        # exp(xLx) + exp(wLw) - 2 * exp(xLw), clamped at zero
        oh.make_node(
            'Mul', ['_ke_two', '_ke_exp_xLw'], ['_ke_2exp_xLw'],
        ),
        oh.make_node(
            'Add', ['_ke_exp_xLx', '_ke_exp_wLw'], ['_ke_sum_self'],
        ),
        oh.make_node(
            'Sub', ['_ke_sum_self', '_ke_2exp_xLw'], ['_ke_dist_raw'],
        ),
        oh.make_node('Relu', ['_ke_dist_raw'], ['distances_ke']),
    ]
    constants = [
        oh.make_tensor('_ke_axis1', TensorProto.INT64, [1], [1]),
        oh.make_tensor('_ke_two', TensorProto.FLOAT, [], [2.0]),
    ]
    return graph_nodes, constants, 'distances_ke'


# ---------------------------------------------------------------------------
# Riemannian model builders
# ---------------------------------------------------------------------------


def _riemannian_so_chordal_onnx(X_name, proto_name, n):
    r"""Build the chordal distance subgraph: d^2(R, S) = ||R - S||^2_F.

    ``X`` and the prototypes are taken in flattened form, (batch, n*n)
    and (p, n*n). The graph broadcasts their difference to
    (batch, p, n*n), squares it and sums over the last axis.

    Parameters
    ----------
    X_name : str
        Name of the flattened input tensor.
    proto_name : str
        Name of the flattened prototype tensor.
    n : int
        Not used by this function.

    Returns
    -------
    nodes, initializers, output_name
        Output shape: (batch, n_proto)
    """
    import onnx.helper as oh
    from onnx import TensorProto

    def _axis(name, value):
        # One-element INT64 tensor holding an axis index.
        return oh.make_tensor(name, TensorProto.INT64, [1], [value])

    graph_nodes = [
        oh.make_node('Unsqueeze', [X_name, '_rsc_axis1'], ['_rsc_x_exp']),
        oh.make_node(
            'Unsqueeze', [proto_name, '_rsc_axis0'], ['_rsc_w_exp'],
        ),
        oh.make_node('Sub', ['_rsc_x_exp', '_rsc_w_exp'], ['_rsc_diff']),
        oh.make_node('Mul', ['_rsc_diff', '_rsc_diff'], ['_rsc_diff_sq']),
        oh.make_node(
            'ReduceSum', ['_rsc_diff_sq', '_rsc_axis2'],
            ['distances_rsc'], keepdims=0,
        ),
    ]
    axis_constants = [
        _axis('_rsc_axis0', 0),
        _axis('_rsc_axis1', 1),
        _axis('_rsc_axis2', 2),
    ]
    return graph_nodes, axis_constants, 'distances_rsc'


def _riemannian_so_tangent_onnx(X_name, proto_name, n):
    r"""Build the SO(n) log-map subgraph: Log_W(X) = W @ skew(W^T @ X).

    Here skew(A) = (A - A^T) / 2.  Inputs and prototypes are paired through
    broadcasting, so one tangent vector is produced for every
    (prototype, sample) combination.

    Parameters
    ----------
    X_name : str
        Name of the graph tensor with the flattened inputs, (batch, n*n).
    proto_name : str
        Name of the graph tensor with the flattened prototypes, (p, n*n).
    n : int
        Side length of the square matrices.

    Returns
    -------
    nodes, initializers, output_name
        ``output_name`` is ``'tangent_so'``, laid out as (p, batch, n*n)
        for the downstream metric builders.
    """
    import onnx.helper as oh
    from onnx import TensorProto

    op = oh.make_node
    swap_last_two = [0, 1, 3, 2]
    as_matrices = [-1, n, n]

    constants = [
        oh.make_tensor('_rst_axis0', TensorProto.INT64, [1], [0]),
        oh.make_tensor('_rst_axis1', TensorProto.INT64, [1], [1]),
        oh.make_tensor('_rst_x_shape', TensorProto.INT64, [3], as_matrices),
        oh.make_tensor('_rst_w_shape', TensorProto.INT64, [3], as_matrices),
        # Zeros in a Reshape target keep the corresponding input dims, so
        # only the trailing (n, n) pair is collapsed.
        oh.make_tensor(
            '_rst_tang_shape', TensorProto.INT64, [3], [0, 0, n * n],
        ),
        oh.make_tensor('_rst_two', TensorProto.FLOAT, [], [2.0]),
    ]

    pipeline = [
        # Restore matrix structure: X -> (batch, n, n), W -> (p, n, n)
        op('Reshape', [X_name, '_rst_x_shape'], ['_rst_X3']),
        op('Reshape', [proto_name, '_rst_w_shape'], ['_rst_W3']),
        # Insert broadcast axes: X -> (batch, 1, n, n), W -> (1, p, n, n)
        op('Unsqueeze', ['_rst_X3', '_rst_axis1'], ['_rst_X4']),
        op('Unsqueeze', ['_rst_W3', '_rst_axis0'], ['_rst_W4']),
        # W^T over the two matrix axes
        op('Transpose', ['_rst_W4'], ['_rst_Wt'], perm=swap_last_two),
        # W^T @ X, broadcast to (batch, p, n, n)
        op('MatMul', ['_rst_Wt', '_rst_X4'], ['_rst_RtS']),
        # skew(W^T X) = (A - A^T) / 2
        op('Transpose', ['_rst_RtS'], ['_rst_RtS_T'], perm=swap_last_two),
        op('Sub', ['_rst_RtS', '_rst_RtS_T'], ['_rst_skew_raw']),
        op('Div', ['_rst_skew_raw', '_rst_two'], ['_rst_skew']),
        # W @ skew -> (batch, p, n, n)
        op('MatMul', ['_rst_W4', '_rst_skew'], ['_rst_tangent4d']),
        # Flatten the matrices again -> (batch, p, n*n)
        op(
            'Reshape', ['_rst_tangent4d', '_rst_tang_shape'],
            ['_rst_tangent3d'],
        ),
        # Prototype-major layout -> (p, batch, n*n)
        op('Transpose', ['_rst_tangent3d'], ['tangent_so'], perm=[1, 0, 2]),
    ]
    return pipeline, constants, 'tangent_so'


def _riemannian_gr_tangent_onnx(X_name, proto_name, n, k):
    r"""Build the Grassmannian log-map subgraph.

    Uses Log_{Q1}(Q2) = Q2 - Q1 (Q1^T Q2) with Q1 a prototype and Q2 an
    input point, evaluated for every (prototype, sample) pair through
    broadcasting.

    Parameters
    ----------
    X_name : str
        Name of the graph tensor with the flattened inputs, (batch, n*k).
    proto_name : str
        Name of the graph tensor with the flattened prototypes, (p, n*k).
    n, k : int
        Row and column counts of each manifold point.

    Returns
    -------
    nodes, initializers, output_name
        ``output_name`` is ``'tangent_gr'``, laid out as (p, batch, n*k)
        for the downstream metric builders.
    """
    import onnx.helper as oh
    from onnx import TensorProto

    op = oh.make_node
    as_matrices = [-1, n, k]

    constants = [
        oh.make_tensor('_rgt_axis0', TensorProto.INT64, [1], [0]),
        oh.make_tensor('_rgt_axis1', TensorProto.INT64, [1], [1]),
        oh.make_tensor('_rgt_x_shape', TensorProto.INT64, [3], as_matrices),
        oh.make_tensor('_rgt_w_shape', TensorProto.INT64, [3], as_matrices),
        # Zeros in a Reshape target keep the corresponding input dims, so
        # only the trailing (n, k) pair is collapsed.
        oh.make_tensor(
            '_rgt_tang_shape', TensorProto.INT64, [3], [0, 0, n * k],
        ),
    ]

    pipeline = [
        # Restore matrix structure: X -> (batch, n, k), W -> (p, n, k)
        op('Reshape', [X_name, '_rgt_x_shape'], ['_rgt_X3']),
        op('Reshape', [proto_name, '_rgt_w_shape'], ['_rgt_W3']),
        # Insert broadcast axes: X -> (batch, 1, n, k), W -> (1, p, n, k)
        op('Unsqueeze', ['_rgt_X3', '_rgt_axis1'], ['_rgt_X4']),
        op('Unsqueeze', ['_rgt_W3', '_rgt_axis0'], ['_rgt_W4']),
        # W^T -> (1, p, k, n)
        op('Transpose', ['_rgt_W4'], ['_rgt_Wt'], perm=[0, 1, 3, 2]),
        # Q1^T Q2 -> (batch, p, k, k)
        op('MatMul', ['_rgt_Wt', '_rgt_X4'], ['_rgt_Q1tQ2']),
        # Q1 (Q1^T Q2) -> (batch, p, n, k)
        op('MatMul', ['_rgt_W4', '_rgt_Q1tQ2'], ['_rgt_proj']),
        # Q2 minus the projected part -> (batch, p, n, k)
        op('Sub', ['_rgt_X4', '_rgt_proj'], ['_rgt_tangent4d']),
        # Flatten the matrices again -> (batch, p, n*k)
        op(
            'Reshape', ['_rgt_tangent4d', '_rgt_tang_shape'],
            ['_rgt_tangent3d'],
        ),
        # Prototype-major layout -> (p, batch, n*k)
        op('Transpose', ['_rgt_tangent3d'], ['tangent_gr'], perm=[1, 0, 2]),
    ]
    return pipeline, constants, 'tangent_gr'


def _riemannian_global_omega_onnx(tangent_name, omega_name):
    r"""Build the global-omega metric subgraph: d^2 = ||tangent @ Omega||^2.

    Parameters
    ----------
    tangent_name : str
        Name of the pre-computed tangent tensor, (p, batch, d_flat).
    omega_name : str
        Name of the shared projection matrix, (d_flat, latent_dim).

    Returns
    -------
    nodes, initializers, output_name
        ``output_name`` is ``'distances_rgo'`` with shape (batch, p).
    """
    import onnx.helper as oh
    from onnx import TensorProto

    latent_axis = oh.make_tensor('_rgo_axis2', TensorProto.INT64, [1], [2])

    steps = [
        # (p, batch, d) @ (d, l) -> (p, batch, l)
        oh.make_node(
            'MatMul', [tangent_name, omega_name], ['_rgo_projected'],
        ),
        # Square each projected coordinate
        oh.make_node(
            'Mul', ['_rgo_projected', '_rgo_projected'], ['_rgo_proj_sq'],
        ),
        # Sum over the latent axis -> (p, batch)
        oh.make_node(
            'ReduceSum', ['_rgo_proj_sq', '_rgo_axis2'], ['_rgo_dist_t'],
            keepdims=0,
        ),
        # Swap to sample-major layout -> (batch, p)
        oh.make_node(
            'Transpose', ['_rgo_dist_t'], ['distances_rgo'], perm=[1, 0],
        ),
    ]
    return steps, [latent_axis], 'distances_rgo'


def _riemannian_local_omega_onnx(tangent_name, omegas_name):
    r"""Build the local-omega metric subgraph: d^2_k = ||tangent_k @ Omega_k||^2.

    Each prototype has its own projection matrix; the batched MatMul pairs
    the k-th tangent slice with the k-th matrix.

    Parameters
    ----------
    tangent_name : str
        Name of the pre-computed tangent tensor, (p, batch, d_flat).
    omegas_name : str
        Name of the stacked projection matrices, (p, d_flat, latent_dim).

    Returns
    -------
    nodes, initializers, output_name
        ``output_name`` is ``'distances_rlo'`` with shape (batch, p).
    """
    import onnx.helper as oh
    from onnx import TensorProto

    latent_axis = oh.make_tensor('_rlo_axis2', TensorProto.INT64, [1], [2])

    steps = [
        # (p, batch, d) @ (p, d, l) -> (p, batch, l)
        oh.make_node(
            'MatMul', [tangent_name, omegas_name], ['_rlo_projected'],
        ),
        # Square each projected coordinate
        oh.make_node(
            'Mul', ['_rlo_projected', '_rlo_projected'], ['_rlo_proj_sq'],
        ),
        # Sum over the latent axis -> (p, batch)
        oh.make_node(
            'ReduceSum', ['_rlo_proj_sq', '_rlo_axis2'], ['_rlo_dist_t'],
            keepdims=0,
        ),
        # Swap to sample-major layout -> (batch, p)
        oh.make_node(
            'Transpose', ['_rlo_dist_t'], ['distances_rlo'], perm=[1, 0],
        ),
    ]
    return steps, [latent_axis], 'distances_rlo'


def _riemannian_tangent_subspace_onnx(tangent_name, omegas_name):
    r"""Build the tangent-subspace distance subgraph.

    Computes d^2 = ||(I - Omega_k Omega_k^T) tangent_k||^2 as the squared
    norm of what remains after projecting each tangent vector onto the
    prototype's subspace and mapping it back.

    Parameters
    ----------
    tangent_name : str
        Name of the pre-computed tangent tensor, (p, batch, d_flat).
    omegas_name : str
        Name of the stacked subspace bases, (p, d_flat, s).

    Returns
    -------
    nodes, initializers, output_name
        ``output_name`` is ``'distances_rts'`` with shape (batch, p).
    """
    import onnx.helper as oh
    from onnx import TensorProto

    feature_axis = oh.make_tensor('_rts_axis2', TensorProto.INT64, [1], [2])

    steps = [
        # Subspace coefficients: (p, batch, d) @ (p, d, s) -> (p, batch, s)
        oh.make_node('MatMul', [tangent_name, omegas_name], ['_rts_proj']),
        # Bases transposed per prototype -> (p, s, d)
        oh.make_node(
            'Transpose', [omegas_name], ['_rts_omegas_T'], perm=[0, 2, 1],
        ),
        # Reconstruction: (p, batch, s) @ (p, s, d) -> (p, batch, d)
        oh.make_node(
            'MatMul', ['_rts_proj', '_rts_omegas_T'], ['_rts_recon'],
        ),
        # Component outside the subspace
        oh.make_node('Sub', [tangent_name, '_rts_recon'], ['_rts_residual']),
        # Squared residual
        oh.make_node(
            'Mul', ['_rts_residual', '_rts_residual'], ['_rts_res_sq'],
        ),
        # Sum over the feature axis -> (p, batch)
        oh.make_node(
            'ReduceSum', ['_rts_res_sq', '_rts_axis2'], ['_rts_dist_t'],
            keepdims=0,
        ),
        # Swap to sample-major layout -> (batch, p)
        oh.make_node(
            'Transpose', ['_rts_dist_t'], ['distances_rts'], perm=[1, 0],
        ),
    ]
    return steps, [feature_axis], 'distances_rts'


def _export_riemannian_onnx(model, batch_size, opset_version, path):
    """Export a Riemannian supervised model to ONNX.

    Handles RiemannianSRNG (chordal distance on flattened points) and
    RiemannianSMNG/SLNG/STNG (tangent-space metric on SO(n) or
    Grassmannian), followed by a WTAC decision.

    Parameters
    ----------
    model
        Fitted model.  Reads ``manifold``, ``prototypes_`` and
        ``prototype_labels_``, plus ``omega_`` (SMNG) or ``omegas_``
        (SLNG / STNG).
    batch_size : int
        A positive value fixes the batch dimension of the graph input;
        any other value makes it the symbolic dimension ``'batch'``.
    opset_version : int
        Opset of the default ONNX domain recorded in the model.
    path : str or None
        When not ``None``, the model is also saved to this location.

    Returns
    -------
    onnx.ModelProto
        Checked model with float input ``'X'`` of shape
        (batch, n_features) and int64 output ``'predictions'``.

    Raises
    ------
    NotImplementedError
        If the manifold is neither SO nor Grassmannian, or the model class
        is not one of the supported Riemannian variants.
    """
    import onnx
    import onnx.helper as oh
    from onnx import TensorProto

    from prosemble.core.manifolds import SO, Grassmannian

    def _float_init(name, values):
        # Serialize an array-like as a float32 initializer tensor.
        arr = np.asarray(values, dtype=np.float32)
        return oh.make_tensor(
            name, TensorProto.FLOAT, list(arr.shape), arr.flatten().tolist(),
        )

    # Determine manifold and model variant
    manifold = model.manifold
    is_so = isinstance(manifold, SO)
    is_gr = isinstance(manifold, Grassmannian)
    model_name = type(model).__name__

    # Flattened feature count depends on the point shape of the manifold.
    if is_so:
        n = manifold.n
        n_features = n * n
    elif is_gr:
        n, k = manifold.n, manifold.k
        n_features = n * k
    else:
        # Fail explicitly instead of treating an unknown manifold as a
        # Grassmannian (which would fail on `.k` or export a wrong graph).
        raise NotImplementedError(
            "ONNX export of Riemannian models supports only SO and "
            f"Grassmannian manifolds, got: {type(manifold).__name__}"
        )

    batch_dim = batch_size if batch_size > 0 else 'batch'
    input_shape = [batch_dim, n_features]

    all_nodes = []
    initializers = []

    # Prototypes (flattened manifold points)
    initializers.append(_float_init('prototypes', model.prototypes_))

    # Prototype labels
    proto_labels = np.asarray(model.prototype_labels_).astype(np.int64)
    initializers.append(
        oh.make_tensor(
            'proto_labels', TensorProto.INT64,
            list(proto_labels.shape), proto_labels.flatten().tolist(),
        ),
    )

    # --- Distance computation ---
    if model_name == 'RiemannianSRNG':
        # Chordal distance: ||X - W||^2_F (works on flattened vectors directly)
        nodes, extra_inits, dist_out = _riemannian_so_chordal_onnx(
            'X', 'prototypes', n,
        )
        all_nodes.extend(nodes)
        initializers.extend(extra_inits)

    else:
        # SMNG, SLNG, STNG: compute tangent vectors first, then apply metric
        if is_so:
            tang_nodes, tang_inits, tangent_out = _riemannian_so_tangent_onnx(
                'X', 'prototypes', n,
            )
        else:  # Grassmannian
            tang_nodes, tang_inits, tangent_out = _riemannian_gr_tangent_onnx(
                'X', 'prototypes', n, k,
            )
        all_nodes.extend(tang_nodes)
        initializers.extend(tang_inits)

        # Apply metric based on model variant
        if model_name == 'RiemannianSMNG':
            # Global omega
            initializers.append(_float_init('omega', model.omega_))
            met_nodes, met_inits, dist_out = _riemannian_global_omega_onnx(
                tangent_out, 'omega',
            )

        elif model_name == 'RiemannianSLNG':
            # Per-prototype local omega
            initializers.append(_float_init('omegas', model.omegas_))
            met_nodes, met_inits, dist_out = _riemannian_local_omega_onnx(
                tangent_out, 'omegas',
            )

        elif model_name == 'RiemannianSTNG':
            # Tangent subspace
            initializers.append(_float_init('omegas', model.omegas_))
            met_nodes, met_inits, dist_out = _riemannian_tangent_subspace_onnx(
                tangent_out, 'omegas',
            )
        else:
            raise NotImplementedError(
                f"Unknown Riemannian model variant: {model_name}"
            )

        all_nodes.extend(met_nodes)
        initializers.extend(met_inits)

    # --- WTAC decision ---
    comp_nodes, comp_inits = _wtac_onnx_nodes(dist_out, 'proto_labels')
    all_nodes.extend(comp_nodes)
    initializers.extend(comp_inits)

    # --- Build graph ---
    X_input = oh.make_tensor_value_info('X', TensorProto.FLOAT, input_shape)
    Y_output = oh.make_tensor_value_info(
        'predictions', TensorProto.INT64, [batch_dim],
    )

    graph = oh.make_graph(
        all_nodes,
        'prosemble_riemannian_predict',
        [X_input],
        [Y_output],
        initializer=initializers,
    )

    onnx_model = oh.make_model(graph, opset_imports=[
        oh.make_opsetid('', opset_version),
    ])
    # IR version is pinned explicitly rather than left to the installed
    # onnx package's default.
    onnx_model.ir_version = 8
    onnx.checker.check_model(onnx_model)

    if path is not None:
        onnx.save(onnx_model, path)

    return onnx_model


# ---------------------------------------------------------------------------
# Competition / decision builders
# ---------------------------------------------------------------------------

def _wtac_onnx_nodes(dist_name, proto_labels_name):
    """Build the winner-takes-all decision subgraph.

    Emits ``predictions = proto_labels[argmin(distances, axis=1)]``.

    Parameters
    ----------
    dist_name : str
        Name of the distance tensor; the minimum is taken along axis 1.
    proto_labels_name : str
        Name of the tensor holding one label per prototype.

    Returns
    -------
    nodes, initializers
        Three nodes writing to ``'predictions'``; no extra initializers.
    """
    import onnx.helper as oh
    from onnx import TensorProto

    # Index of the closest prototype for each row.
    pick_winner = oh.make_node(
        'ArgMin', [dist_name], ['_winners'], axis=1, keepdims=0,
    )
    # Explicit int64 cast of the indices before they are used by Gather.
    to_int64 = oh.make_node(
        'Cast', ['_winners'], ['_winners_i64'], to=TensorProto.INT64,
    )
    # Look up the label of each winning prototype.
    lookup = oh.make_node(
        'Gather', [proto_labels_name, '_winners_i64'], ['predictions'],
        axis=0,
    )
    return [pick_winner, to_int64, lookup], []


def _wtac_with_rejection_onnx_nodes(dist_name, proto_labels_name, threshold):
    """Build the WTAC decision subgraph with a confidence-based reject option.

    The emitted graph computes:
        winner_label = proto_labels[argmin(distances, axis=1)]
        d_plus  = smallest distance to a prototype of the winner's class
        d_minus = smallest distance to a prototype of any other class
        confidence = (d_minus - d_plus) / (d_minus + d_plus + eps)
        predictions = where(confidence >= threshold, winner_label, -1)

    Two named tensors are produced: 'predictions' (INT64) and
    'confidence' (FLOAT).

    Parameters
    ----------
    dist_name : str
        Name of the distance tensor, reduced along axis 1.
    proto_labels_name : str
        Name of the tensor holding one label per prototype.
    threshold : float
        Minimum confidence for a prediction to be accepted.

    Returns
    -------
    nodes, initializers
    """
    import onnx.helper as oh
    from onnx import TensorProto

    op = oh.make_node

    consts = [
        oh.make_tensor('_rej_eps', TensorProto.FLOAT, [], [1e-10]),
        oh.make_tensor(
            '_rej_threshold', TensorProto.FLOAT, [], [float(threshold)],
        ),
        # Large finite stand-in for +inf used to mask out distances.
        oh.make_tensor('_rej_inf', TensorProto.FLOAT, [], [1e30]),
        oh.make_tensor('_rej_neg_one', TensorProto.INT64, [], [-1]),
        oh.make_tensor('_rej_unsqueeze_axis', TensorProto.INT64, [1], [1]),
        oh.make_tensor('_rej_unsqueeze_axis0', TensorProto.INT64, [1], [0]),
    ]

    graph_nodes = [
        # Winning prototype index per row -> (batch,)
        op('ArgMin', [dist_name], ['_rej_winner_idx'], axis=1, keepdims=0),
        op(
            'Cast', ['_rej_winner_idx'], ['_rej_winner_idx_i64'],
            to=TensorProto.INT64,
        ),
        # Label of the winning prototype -> (batch,)
        op(
            'Gather', [proto_labels_name, '_rej_winner_idx_i64'],
            ['_rej_winner_label'], axis=0,
        ),
        # (batch, 1) vs (1, n_protos) so Equal broadcasts to (batch, n_protos)
        op(
            'Unsqueeze', ['_rej_winner_label', '_rej_unsqueeze_axis'],
            ['_rej_winner_label_2d'],
        ),
        op(
            'Unsqueeze', [proto_labels_name, '_rej_unsqueeze_axis0'],
            ['_rej_proto_labels_2d'],
        ),
        op(
            'Equal', ['_rej_winner_label_2d', '_rej_proto_labels_2d'],
            ['_rej_same_mask'],
        ),
        op('Not', ['_rej_same_mask'], ['_rej_diff_mask']),
    ]

    # d_plus / d_minus: keep distances where the mask holds, replace the
    # rest by the large constant (Where broadcasts the scalar), then take
    # the row-wise minimum.
    # NOTE(review): ReduceMin receives `axes` as an attribute here; ONNX
    # turned it into an input in opset 18 -- confirm the supported opsets.
    for mask, masked, reduced in (
        ('_rej_same_mask', '_rej_dist_same', '_rej_d_plus'),
        ('_rej_diff_mask', '_rej_dist_diff', '_rej_d_minus'),
    ):
        graph_nodes.append(op('Where', [mask, dist_name, '_rej_inf'], [masked]))
        graph_nodes.append(
            op('ReduceMin', [masked], [reduced], axes=[1], keepdims=0),
        )

    graph_nodes += [
        # confidence = (d_minus - d_plus) / (d_minus + d_plus + eps)
        op('Sub', ['_rej_d_minus', '_rej_d_plus'], ['_rej_numerator']),
        op('Add', ['_rej_d_minus', '_rej_d_plus'], ['_rej_denom_raw']),
        op('Add', ['_rej_denom_raw', '_rej_eps'], ['_rej_denominator']),
        op('Div', ['_rej_numerator', '_rej_denominator'], ['confidence']),
        # Accept where the confidence reaches the threshold, else emit -1
        # (the scalar -1 broadcasts against the winner labels).
        op(
            'GreaterOrEqual', ['confidence', '_rej_threshold'],
            ['_rej_accept_mask'],
        ),
        op(
            'Where', ['_rej_accept_mask', '_rej_winner_label', '_rej_neg_one'],
            ['predictions'],
        ),
    ]

    return graph_nodes, consts


def _argmin_onnx_nodes(dist_name):
    """Build the unsupervised decision subgraph.

    Emits ``predictions = argmin(distances, axis=1)`` as int64, i.e. the
    index of the closest prototype for each row.

    Parameters
    ----------
    dist_name : str
        Name of the distance tensor; the minimum is taken along axis 1.

    Returns
    -------
    nodes, initializers
        Two nodes writing to ``'predictions'``; no extra initializers.
    """
    import onnx.helper as oh
    from onnx import TensorProto

    closest = oh.make_node(
        'ArgMin', [dist_name], ['_predictions_raw'], axis=1, keepdims=0,
    )
    as_int64 = oh.make_node(
        'Cast', ['_predictions_raw'], ['predictions'], to=TensorProto.INT64,
    )
    return [closest, as_int64], []


def _oc_hard_nearest_onnx(dist_name, model):
    """Build the one-class decision subgraph based on the nearest prototype.

    decision_function:
        nearest_idx = argmin(distances)
        d_nearest = distances[nearest_idx]
        theta_nearest = thetas[nearest_idx]
        mu = (d - theta) / (d + theta + eps)
        score = 1 - sigmoid(beta * mu)
    predict:
        score >= 0.5 -> target_label, else non_target_label

    Parameters
    ----------
    dist_name : str
        Name of the distance tensor; prototypes lie along axis 1.
    model
        Fitted model.  Reads ``beta``, ``_target_label``,
        ``_non_target_label`` and ``thetas_``.

    Returns
    -------
    nodes, initializers
        The final node writes the INT32 labels to ``'predictions'``.
    """
    import onnx.helper as oh
    from onnx import TensorProto

    beta = float(model.beta)
    target = int(model._target_label)
    non_target = int(model._non_target_label)
    thetas = np.asarray(model.thetas_).astype(np.float32)

    def scalar(name, dtype, value):
        # Rank-0 constant tensor.
        return oh.make_tensor(name, dtype, [], [value])

    inits = [
        oh.make_tensor(
            '_oc_thetas', TensorProto.FLOAT,
            list(thetas.shape), thetas.flatten().tolist(),
        ),
        scalar('_oc_beta', TensorProto.FLOAT, beta),
        scalar('_oc_eps', TensorProto.FLOAT, 1e-10),
        scalar('_oc_one', TensorProto.FLOAT, 1.0),
        scalar('_oc_half', TensorProto.FLOAT, 0.5),
        scalar('_oc_target', TensorProto.INT32, target),
        scalar('_oc_non_target', TensorProto.INT32, non_target),
        oh.make_tensor('_oc_axis1', TensorProto.INT64, [1], [1]),
    ]

    # (op_type, inputs, outputs, attributes), in execution order.
    plan = [
        # Index of the nearest prototype per row -> (n,)
        ('ArgMin', [dist_name], ['_oc_nearest_idx'],
         {'axis': 1, 'keepdims': 0}),
        ('Cast', ['_oc_nearest_idx'], ['_oc_nearest_i64'],
         {'to': TensorProto.INT64}),
        # GatherElements along axis 1 needs (n, 1) indices; squeeze after.
        ('Unsqueeze', ['_oc_nearest_i64', '_oc_axis1'], ['_oc_idx_2d'], {}),
        ('GatherElements', [dist_name, '_oc_idx_2d'], ['_oc_d_2d'],
         {'axis': 1}),
        ('Squeeze', ['_oc_d_2d', '_oc_axis1'], ['_oc_d_nearest'], {}),
        # Threshold of the nearest prototype -> (n,)
        ('Gather', ['_oc_thetas', '_oc_nearest_i64'], ['_oc_theta_nearest'],
         {'axis': 0}),
        # mu = (d - theta) / (d + theta + eps)
        ('Sub', ['_oc_d_nearest', '_oc_theta_nearest'], ['_oc_num'], {}),
        ('Add', ['_oc_d_nearest', '_oc_theta_nearest'], ['_oc_den_raw'], {}),
        ('Add', ['_oc_den_raw', '_oc_eps'], ['_oc_den'], {}),
        ('Div', ['_oc_num', '_oc_den'], ['_oc_mu'], {}),
        # score = 1 - sigmoid(beta * mu)
        ('Mul', ['_oc_beta', '_oc_mu'], ['_oc_beta_mu'], {}),
        ('Sigmoid', ['_oc_beta_mu'], ['_oc_sig'], {}),
        ('Sub', ['_oc_one', '_oc_sig'], ['_oc_score'], {}),
        # Threshold the score at 0.5 to pick the label
        ('GreaterOrEqual', ['_oc_score', '_oc_half'], ['_oc_mask'], {}),
        ('Where', ['_oc_mask', '_oc_target', '_oc_non_target'],
         ['predictions'], {}),
    ]
    nodes = [
        oh.make_node(op_type, ins, outs, **attrs)
        for op_type, ins, outs, attrs in plan
    ]

    return nodes, inits


def _oc_gaussian_soft_onnx(dist_name, model):
    """One-class Gaussian soft-weighted decision.

    decision_function:
        weights = softmax(-distances / (2*sigma^2))
        mu_k = (distances - thetas) / (distances + thetas + eps)
        weighted_mu = sum(weights * mu_k, axis=1)
        score = 1 - sigmoid(beta * weighted_mu)
    predict:
        score >= 0.5 -> target_label, else non_target_label

    Parameters
    ----------
    dist_name : str
        Name of the distance tensor, (n, K) with prototypes along axis 1.
    model
        Fitted model.  Reads ``beta``, ``sigma``, ``_target_label``,
        ``_non_target_label`` and ``thetas_``.

    Returns
    -------
    nodes, initializers
        The final node writes the INT32 labels to ``'predictions'``.
    """
    import onnx.helper as oh
    from onnx import TensorProto

    nodes = []

    beta = float(model.beta)
    sigma = float(model.sigma)
    target = int(model._target_label)
    non_target = int(model._non_target_label)
    thetas = np.asarray(model.thetas_).astype(np.float32)
    two_sigma_sq = 2.0 * sigma * sigma

    inits = [
        oh.make_tensor('_gs_thetas', TensorProto.FLOAT,
                        list(thetas.shape), thetas.flatten().tolist()),
        oh.make_tensor('_gs_beta', TensorProto.FLOAT, [], [beta]),
        oh.make_tensor('_gs_eps', TensorProto.FLOAT, [], [1e-10]),
        oh.make_tensor('_gs_one', TensorProto.FLOAT, [], [1.0]),
        oh.make_tensor('_gs_half', TensorProto.FLOAT, [], [0.5]),
        oh.make_tensor('_gs_neg_inv_2s2', TensorProto.FLOAT, [],
                        [-1.0 / two_sigma_sq]),
        oh.make_tensor('_gs_target', TensorProto.INT32, [], [target]),
        oh.make_tensor('_gs_non_target', TensorProto.INT32, [], [non_target]),
        oh.make_tensor('_gs_axis1', TensorProto.INT64, [1], [1]),
        oh.make_tensor('_gs_axis0', TensorProto.INT64, [1], [0]),
    ]

    # Gaussian weights = softmax(-d / (2*sigma^2), axis=1)
    nodes.append(oh.make_node(
        'Mul', [dist_name, '_gs_neg_inv_2s2'], ['_gs_logits'],
    ))
    nodes.append(oh.make_node(
        'Softmax', ['_gs_logits'], ['_gs_weights'],
        axis=1,
    ))

    # Per-prototype mu_k = (d - theta) / (d + theta + eps)
    # thetas (K,) -> (1, K) so they broadcast against the (n, K) distances
    nodes.append(oh.make_node(
        'Unsqueeze', ['_gs_thetas', '_gs_axis0'], ['_gs_thetas_2d'],
    ))

    nodes.append(oh.make_node(
        'Sub', [dist_name, '_gs_thetas_2d'], ['_gs_num'],
    ))
    nodes.append(oh.make_node(
        'Add', [dist_name, '_gs_thetas_2d'], ['_gs_den_raw'],
    ))
    nodes.append(oh.make_node(
        'Add', ['_gs_den_raw', '_gs_eps'], ['_gs_den'],
    ))
    nodes.append(oh.make_node(
        'Div', ['_gs_num', '_gs_den'], ['_gs_mu_k'],
    ))

    # weighted_mu = sum(weights * mu_k, axis=1)
    nodes.append(oh.make_node(
        'Mul', ['_gs_weights', '_gs_mu_k'], ['_gs_w_mu'],
    ))
    nodes.append(oh.make_node(
        'ReduceSum', ['_gs_w_mu', '_gs_axis1'], ['_gs_weighted_mu'],
        keepdims=0,
    ))

    # score = 1 - sigmoid(beta * weighted_mu)
    nodes.append(oh.make_node(
        'Mul', ['_gs_beta', '_gs_weighted_mu'], ['_gs_beta_mu'],
    ))
    nodes.append(oh.make_node(
        'Sigmoid', ['_gs_beta_mu'], ['_gs_sig'],
    ))
    nodes.append(oh.make_node(
        'Sub', ['_gs_one', '_gs_sig'], ['_gs_score'],
    ))

    # predictions = Where(score >= 0.5, target, non_target)
    nodes.append(oh.make_node(
        'GreaterOrEqual', ['_gs_score', '_gs_half'], ['_gs_mask'],
    ))
    nodes.append(oh.make_node(
        'Where', ['_gs_mask', '_gs_target', '_gs_non_target'], ['predictions'],
    ))

    return nodes, inits


def _oc_gaussian_ng_onnx(dist_name, model, n_proto):
    """Build the one-class decision subgraph with Gaussian x NG weights.

    Mirrors the Gaussian soft decision, except that the per-prototype
    weights are the Gaussian weights multiplied by neighbourhood-rank
    weights (using the converged ``gamma_``) and then re-normalized.

    NG rank weights:
        order = argsort(distances)
        ranks = argsort(order)  # inverse permutation
        h = exp(-ranks / gamma)
        h_norm = h / sum(h)
    Combined: combined = gauss * h_norm (re-normalized)

    Parameters
    ----------
    dist_name : str
        Name of the distance tensor, (n, K) with prototypes along axis 1.
    model
        Fitted model.  Reads ``beta``, ``sigma``, ``gamma_``,
        ``_target_label``, ``_non_target_label`` and ``thetas_``.
    n_proto : int
        Number of prototypes K; used as the TopK size and rank range.

    Returns
    -------
    nodes, initializers
        The final node writes the INT32 labels to ``'predictions'``.
    """
    import onnx.helper as oh
    from onnx import TensorProto

    beta = float(model.beta)
    sigma = float(model.sigma)
    gamma = float(model.gamma_)
    target = int(model._target_label)
    non_target = int(model._non_target_label)
    thetas = np.asarray(model.thetas_).astype(np.float32)
    two_sigma_sq = 2.0 * sigma * sigma
    K = n_proto

    # Rank values 0..K-1 that get scattered back to prototype positions.
    rank_values = np.arange(K, dtype=np.int64)

    def scalar(name, dtype, value):
        # Rank-0 constant tensor.
        return oh.make_tensor(name, dtype, [], [value])

    def axis(name, index):
        # One-element int64 tensor, as expected for axes / K inputs.
        return oh.make_tensor(name, TensorProto.INT64, [1], [index])

    inits = [
        oh.make_tensor(
            '_gn_thetas', TensorProto.FLOAT,
            list(thetas.shape), thetas.flatten().tolist(),
        ),
        scalar('_gn_beta', TensorProto.FLOAT, beta),
        scalar('_gn_eps', TensorProto.FLOAT, 1e-10),
        scalar('_gn_one', TensorProto.FLOAT, 1.0),
        scalar('_gn_half', TensorProto.FLOAT, 0.5),
        scalar('_gn_neg_inv_2s2', TensorProto.FLOAT, -1.0 / two_sigma_sq),
        scalar('_gn_neg_inv_gamma', TensorProto.FLOAT,
               -1.0 / (gamma + 1e-10)),
        scalar('_gn_target', TensorProto.INT32, target),
        scalar('_gn_non_target', TensorProto.INT32, non_target),
        axis('_gn_axis0', 0),
        axis('_gn_axis1', 1),
        axis('_gn_K', K),
        oh.make_tensor(
            '_gn_range_K', TensorProto.INT64, [K], rank_values.tolist(),
        ),
    ]

    int64_zero = oh.make_tensor('', TensorProto.INT64, [1], [0])

    # (op_type, inputs, outputs, attributes), in execution order.
    plan = [
        # --- Gaussian weights: softmax(-d / (2*sigma^2)) over prototypes ---
        ('Mul', [dist_name, '_gn_neg_inv_2s2'], ['_gn_logits'], {}),
        ('Softmax', ['_gn_logits'], ['_gn_gauss'], {'axis': 1}),

        # --- NG rank weights ---
        # Full ascending TopK acts as an argsort of the distances.
        ('TopK', [dist_name, '_gn_K'],
         ['_gn_sorted_vals', '_gn_sorted_idx'],
         {'largest': 0, 'sorted': 1}),
        # Int64 zeros with the shape of the distances, to scatter into.
        ('Shape', [dist_name], ['_gn_dist_shape'], {}),
        ('ConstantOfShape', ['_gn_dist_shape'], ['_gn_zeros'],
         {'value': int64_zero}),
        # Rank values (K,) -> (1, K) -> (n, K)
        ('Unsqueeze', ['_gn_range_K', '_gn_axis0'], ['_gn_range_2d'], {}),
        ('Expand', ['_gn_range_2d', '_gn_dist_shape'],
         ['_gn_range_expanded'], {}),
        # ranks[i, sorted_idx[i, j]] = j
        ('ScatterElements',
         ['_gn_zeros', '_gn_sorted_idx', '_gn_range_expanded'],
         ['_gn_ranks_i64'], {'axis': 1}),
        ('Cast', ['_gn_ranks_i64'], ['_gn_ranks'],
         {'to': TensorProto.FLOAT}),
        # h = exp(-ranks / gamma), written as exp(ranks * (-1 / gamma))
        ('Mul', ['_gn_ranks', '_gn_neg_inv_gamma'], ['_gn_h_logits'], {}),
        ('Exp', ['_gn_h_logits'], ['_gn_h'], {}),
        # h_norm = h / (sum(h) + eps), row-wise
        ('ReduceSum', ['_gn_h', '_gn_axis1'], ['_gn_h_sum'],
         {'keepdims': 1}),
        ('Add', ['_gn_h_sum', '_gn_eps'], ['_gn_h_sum_eps'], {}),
        ('Div', ['_gn_h', '_gn_h_sum_eps'], ['_gn_h_norm'], {}),

        # --- Combined weights, re-normalized row-wise ---
        ('Mul', ['_gn_gauss', '_gn_h_norm'], ['_gn_combined_raw'], {}),
        ('ReduceSum', ['_gn_combined_raw', '_gn_axis1'], ['_gn_comb_sum'],
         {'keepdims': 1}),
        ('Add', ['_gn_comb_sum', '_gn_eps'], ['_gn_comb_sum_eps'], {}),
        ('Div', ['_gn_combined_raw', '_gn_comb_sum_eps'],
         ['_gn_combined'], {}),

        # --- Per-prototype mu_k = (d - theta) / (d + theta + eps) ---
        # thetas (K,) -> (1, K) to broadcast against the (n, K) distances
        ('Unsqueeze', ['_gn_thetas', '_gn_axis0'], ['_gn_thetas_2d'], {}),
        ('Sub', [dist_name, '_gn_thetas_2d'], ['_gn_num'], {}),
        ('Add', [dist_name, '_gn_thetas_2d'], ['_gn_den_raw'], {}),
        ('Add', ['_gn_den_raw', '_gn_eps'], ['_gn_den'], {}),
        ('Div', ['_gn_num', '_gn_den'], ['_gn_mu_k'], {}),

        # weighted_mu = sum(combined * mu_k, axis=1)
        ('Mul', ['_gn_combined', '_gn_mu_k'], ['_gn_w_mu'], {}),
        ('ReduceSum', ['_gn_w_mu', '_gn_axis1'], ['_gn_weighted_mu'],
         {'keepdims': 0}),

        # score = 1 - sigmoid(beta * weighted_mu)
        ('Mul', ['_gn_beta', '_gn_weighted_mu'], ['_gn_beta_mu'], {}),
        ('Sigmoid', ['_gn_beta_mu'], ['_gn_sig'], {}),
        ('Sub', ['_gn_one', '_gn_sig'], ['_gn_score'], {}),

        # Threshold the score at 0.5 to pick the label
        ('GreaterOrEqual', ['_gn_score', '_gn_half'], ['_gn_mask'], {}),
        ('Where', ['_gn_mask', '_gn_target', '_gn_non_target'],
         ['predictions'], {}),
    ]
    nodes = [
        oh.make_node(op_type, ins, outs, **attrs)
        for op_type, ins, outs, attrs in plan
    ]

    return nodes, inits


def _svqocc_onnx(dist_name, model, n_proto):
    """Build the SVQ-OCC decision sub-graph.

    The emitted nodes compute, from the distance tensor ``dist_name``::

        gaussian:   p_k = softmax(-gamma_resp * distances)
        student_t:  p_k = normalize((1 + d / nu) ** (-(nu + 1) / 2))
        otherwise:  p_k = 1 / K                      (uniform)
        heaviside   = sigmoid((thetas - distances) / sigma)
        score       = clip(sum(p_k * heaviside, axis=1), 0, 1)
        predictions = target_label if score >= 0.5 else non_target_label

    Parameters
    ----------
    dist_name : str
        Name of the graph tensor holding the distances.  The per-prototype
        thresholds are broadcast against it and the reduction runs over
        axis 1.
    model : object
        Fitted model.  Reads ``sigma``, ``_target_label``,
        ``_non_target_label``, ``thetas_`` and ``response_type``; also
        ``gamma_resp`` (gaussian response) or ``nu`` (student_t response).
    n_proto : int
        Number of prototypes ``K``; only used by the uniform response.

    Returns
    -------
    nodes : list
        ONNX nodes; the last one writes the ``predictions`` tensor.
    inits : list
        Initializer tensors referenced by ``nodes`` (all prefixed ``_sv_``).
    """
    import onnx.helper as oh
    from onnx import TensorProto

    nodes = []
    inits = []

    def scalar(name, value, dtype=TensorProto.FLOAT):
        # Rank-0 constant initializer.
        return oh.make_tensor(name, dtype, [], [value])

    def emit(op_type, inputs, output, **attrs):
        # Append a single-output node to the sub-graph.
        nodes.append(oh.make_node(op_type, inputs, [output], **attrs))

    sigma = float(model.sigma)
    target = int(model._target_label)
    non_target = int(model._non_target_label)
    thetas = np.asarray(model.thetas_).astype(np.float32)
    response_type = model.response_type

    inits.extend([
        oh.make_tensor('_sv_thetas', TensorProto.FLOAT,
                       list(thetas.shape), thetas.flatten().tolist()),
        # Small epsilon keeps the reciprocal finite when sigma is 0.
        scalar('_sv_inv_sigma', 1.0 / (sigma + 1e-10)),
        scalar('_sv_zero', 0.0),
        scalar('_sv_one', 1.0),
        scalar('_sv_half', 0.5),
        scalar('_sv_target', target, TensorProto.INT32),
        scalar('_sv_non_target', non_target, TensorProto.INT32),
        oh.make_tensor('_sv_axis0', TensorProto.INT64, [1], [0]),
        oh.make_tensor('_sv_axis1', TensorProto.INT64, [1], [1]),
    ])

    # --- Response probability p_k ---
    if response_type == 'gaussian':
        inits.append(scalar('_sv_neg_gamma', -float(model.gamma_resp)))
        emit('Mul', [dist_name, '_sv_neg_gamma'], '_sv_resp_logits')
        emit('Softmax', ['_sv_resp_logits'], '_sv_p_k', axis=1)
    elif response_type == 'student_t':
        nu = float(model.nu)
        inits.extend([
            scalar('_sv_inv_nu', 1.0 / nu),
            scalar('_sv_exponent', -(nu + 1.0) / 2.0),
            scalar('_sv_eps', 1e-10),
        ])
        # base = 1 + d / nu
        emit('Mul', [dist_name, '_sv_inv_nu'], '_sv_d_over_nu')
        emit('Add', ['_sv_one', '_sv_d_over_nu'], '_sv_base')
        # unnormalised density = base ** exponent
        emit('Pow', ['_sv_base', '_sv_exponent'], '_sv_p_unnorm')
        # normalise over axis 1 (eps guards against a zero sum)
        emit('ReduceSum', ['_sv_p_unnorm', '_sv_axis1'], '_sv_p_sum',
             keepdims=1)
        emit('Add', ['_sv_p_sum', '_sv_eps'], '_sv_p_sum_eps')
        emit('Div', ['_sv_p_unnorm', '_sv_p_sum_eps'], '_sv_p_k')
    else:  # uniform
        inits.append(scalar('_sv_inv_K', 1.0 / n_proto))
        # Broadcast 1/K to the shape of the distance tensor.  Expand is
        # used instead of ``distances * 0 + 1/K`` because inf * 0 (and
        # NaN * 0) is NaN, which would poison the score of the whole row
        # even though the uniform response does not depend on distances.
        emit('Shape', [dist_name], '_sv_dist_shape')
        emit('Expand', ['_sv_inv_K', '_sv_dist_shape'], '_sv_p_k')

    # --- Soft Heaviside: sigmoid((thetas - distances) / sigma) ---
    emit('Unsqueeze', ['_sv_thetas', '_sv_axis0'], '_sv_thetas_2d')
    emit('Sub', ['_sv_thetas_2d', dist_name], '_sv_theta_minus_d')
    emit('Mul', ['_sv_theta_minus_d', '_sv_inv_sigma'], '_sv_heav_input')
    emit('Sigmoid', ['_sv_heav_input'], '_sv_heaviside')

    # --- Responsibility = p_k * heaviside ---
    emit('Mul', ['_sv_p_k', '_sv_heaviside'], '_sv_responsibility')

    # --- Score = clip(sum(responsibility, axis=1), 0, 1) ---
    emit('ReduceSum', ['_sv_responsibility', '_sv_axis1'], '_sv_raw_score',
         keepdims=0)
    emit('Clip', ['_sv_raw_score', '_sv_zero', '_sv_one'], '_sv_score')

    # --- Predictions: threshold the score at 0.5 ---
    emit('GreaterOrEqual', ['_sv_score', '_sv_half'], '_sv_mask')
    emit('Where', ['_sv_mask', '_sv_target', '_sv_non_target'],
         'predictions')

    return nodes, inits


def _cbc_onnx(dist_name, model):
    """Build the CBC reasoning decision sub-graph.

    Graph semantics::

        detections  = exp(-distances / (2 * sigma^2))
        A = clip(reasonings[:, :, 0], 0, 1)
        B = clip(reasonings[:, :, 1], 0, 1)
        pk = A, nk = (1 - A) * B
        probs       = (detections @ (pk - nk) + sum(nk)) / (sum(pk + nk) + eps)
        predictions = argmax(probs, axis=1)

    Everything that depends only on ``model.reasonings_`` is folded into
    constant initializers at export time.

    Parameters
    ----------
    dist_name : str
        Name of the graph tensor holding the distances.
    model : object
        Fitted model; reads ``sigma`` and ``reasonings_``.

    Returns
    -------
    nodes : list
        ONNX nodes; the last one writes the ``predictions`` tensor.
    inits : list
        Initializer tensors referenced by ``nodes`` (prefixed ``_cbc_``).
    """
    from onnx import TensorProto
    from onnx.helper import make_node, make_tensor

    def float_tensor(name, array):
        # Float initializer carrying the array's own shape.
        return make_tensor(name, TensorProto.FLOAT,
                           list(array.shape), array.flatten().tolist())

    variance = float(model.sigma) ** 2

    # Reasoning constants, evaluated once in numpy.
    reasoning = np.asarray(model.reasonings_, dtype=np.float32)
    positive = np.clip(reasoning[:, :, 0], 0, 1)
    negative = (1.0 - positive) * np.clip(reasoning[:, :, 1], 0, 1)
    weight = (positive - negative).astype(np.float32)   # (n_comp, n_classes)
    offset = np.sum(negative, axis=0).astype(np.float32)  # (n_classes,)
    normalizer = (
        np.sum(positive + negative, axis=0) + 1e-8
    ).astype(np.float32)

    inits = [
        make_tensor('_cbc_scale', TensorProto.FLOAT, [],
                    [-1.0 / (2.0 * variance)]),
        float_tensor('_cbc_pk_nk', weight),
        float_tensor('_cbc_sum_nk', offset),
        float_tensor('_cbc_denom', normalizer),
    ]

    nodes = [
        # detections = exp(-d / (2 * sigma^2))
        make_node('Mul', [dist_name, '_cbc_scale'], ['_cbc_logits']),
        make_node('Exp', ['_cbc_logits'], ['_cbc_detections']),
        # numerator = detections @ (pk - nk) + sum(nk)
        make_node('MatMul', ['_cbc_detections', '_cbc_pk_nk'],
                  ['_cbc_matmul']),
        make_node('Add', ['_cbc_matmul', '_cbc_sum_nk'],
                  ['_cbc_numerator']),
        # probs = numerator / denominator
        make_node('Div', ['_cbc_numerator', '_cbc_denom'], ['_cbc_probs']),
        # predictions = argmax(probs, axis=1), emitted as INT64
        make_node('ArgMax', ['_cbc_probs'], ['_cbc_preds_raw'],
                  axis=1, keepdims=0),
        make_node('Cast', ['_cbc_preds_raw'], ['predictions'],
                  to=TensorProto.INT64),
    ]

    return nodes, inits


def _plvq_onnx(dist_name, model):
    """Build the PLVQ Gaussian-mixture soft-assignment decision sub-graph.

    Graph semantics::

        logits      = -distances / (2 * sigma^2)
        probs       = softmax(logits, axis=1)
        class_probs = probs @ class_mask      (sum of probs per class)
        predictions = argmax(class_probs, axis=1)

    Parameters
    ----------
    dist_name : str
        Name of the graph tensor holding the distances.
    model : object
        Fitted model; reads ``sigma``, ``prototype_labels_`` and
        ``n_classes_``.  Prototype labels are used directly as column
        indices of the class mask.

    Returns
    -------
    nodes : list
        ONNX nodes; the last one writes the ``predictions`` tensor.
    inits : list
        Initializer tensors referenced by ``nodes`` (prefixed ``_plvq_``).
    """
    from onnx import TensorProto
    from onnx.helper import make_node, make_tensor

    variance = float(model.sigma) ** 2
    labels = np.asarray(model.prototype_labels_, dtype=np.int64)
    n_proto = len(labels)
    n_classes = int(model.n_classes_)

    # One-hot membership: mask[j, c] = 1 iff prototype j carries label c.
    mask = np.zeros((n_proto, n_classes), dtype=np.float32)
    mask[np.arange(n_proto), labels] = 1.0

    inits = [
        make_tensor('_plvq_scale', TensorProto.FLOAT, [],
                    [-1.0 / (2.0 * variance)]),
        make_tensor('_plvq_mask', TensorProto.FLOAT,
                    list(mask.shape), mask.flatten().tolist()),
    ]

    nodes = [
        # logits = -d / (2 * sigma^2)
        make_node('Mul', [dist_name, '_plvq_scale'], ['_plvq_logits']),
        # probs = softmax(logits, axis=1)
        make_node('Softmax', ['_plvq_logits'], ['_plvq_probs'], axis=1),
        # class_probs = probs @ mask  -> (batch, n_classes)
        make_node('MatMul', ['_plvq_probs', '_plvq_mask'],
                  ['_plvq_class_probs']),
        # predictions = argmax(class_probs, axis=1), emitted as INT64
        make_node('ArgMax', ['_plvq_class_probs'], ['_plvq_preds_raw'],
                  axis=1, keepdims=0),
        make_node('Cast', ['_plvq_preds_raw'], ['predictions'],
                  to=TensorProto.INT64),
    ]

    return nodes, inits


# ---------------------------------------------------------------------------
# Model type and distance identification
# ---------------------------------------------------------------------------

def _identify_model_type(model):
    """Classify a fitted model by decision pattern and distance function.

    Detection is purely attribute-based and runs in this order: encoder
    models, CBC, SVQ-OCC, one-class (``thetas_``) models, and finally
    supervised vs. unsupervised depending on ``prototype_labels_``.

    Returns
    -------
    model_type : str
        One of 'supervised', 'unsupervised', 'oc_hard_nearest',
        'oc_gaussian_soft', 'oc_gaussian_ng', 'svqocc', 'cbc', 'plvq'.
    dist_type : str
        One of 'squared_euclidean', 'euclidean', 'manhattan', 'omega',
        'relevance', 'local_omega', 'tangent', 'kernel_per_proto',
        'kernel_relevance', 'kernel_exponential'.
    """
    def is_set(attr):
        # True when the attribute exists and is not None.
        return getattr(model, attr, None) is not None

    # Encoder models are resolved first, before the OC/supervised checks.
    if is_set('backbone_params_'):
        if is_set('components_') and is_set('reasonings_'):
            # ImageCBC: encoder + CBC reasoning
            model_type = 'cbc'
        elif hasattr(model, 'loss_type'):
            # PLVQ: encoder + Gaussian mixture (identified by loss_type)
            model_type = 'plvq'
        else:
            # Siamese*, Image*, LVQMLN: encoder + WTAC
            model_type = 'supervised'
        return model_type, _identify_distance_fn(model, model_type)

    # CBC without encoder: reasonings_ and components_, no backbone_params_
    if is_set('reasonings_') and is_set('components_'):
        return 'cbc', _identify_distance_fn(model, 'cbc')

    if hasattr(model, 'response_type') and hasattr(model, 'gamma_resp'):
        # SVQ-OCC: identified by its response-model attributes
        model_type = 'svqocc'
    elif is_set('thetas_'):
        # One-class models carry learned thresholds; sigma and gamma_
        # select the decision pattern.
        if not hasattr(model, 'sigma'):
            model_type = 'oc_hard_nearest'
        elif is_set('gamma_'):
            model_type = 'oc_gaussian_ng'
        else:
            model_type = 'oc_gaussian_soft'
    elif is_set('prototype_labels_'):
        model_type = 'supervised'
    else:
        model_type = 'unsupervised'

    return model_type, _identify_distance_fn(model, model_type)


def _identify_distance_fn(model, model_type='supervised') -> str:
    """Identify which distance function a model uses.

    For DK models, distance is detected from kernel-specific attributes
    (sigmas_, omega_hat_).  For OC and SVQ-OCC models, distance is
    detected from learned metric parameters (omega_, omegas_,
    relevances_).  For supervised and unsupervised models, distance is
    detected from the distance_fn attribute.
    """
    # --- Attribute-based detection (OC, SVQ-OCC, and supervised metric models) ---

    # Kernel distance detection (DK models) — must come before Euclidean
    # attribute checks to prevent misidentification
    has_sigmas = hasattr(model, 'sigmas_') and model.sigmas_ is not None
    has_omega_hat = hasattr(model, 'omega_hat_') and model.omega_hat_ is not None

    if has_omega_hat:
        return 'kernel_exponential'

    if has_sigmas:
        has_relevances = (hasattr(model, 'relevances_')
                          and model.relevances_ is not None)
        if has_relevances:
            return 'kernel_relevance'
        return 'kernel_per_proto'

    # Tangent vs local omega: tangent models have subspace_dim
    has_omegas = hasattr(model, 'omegas_') and model.omegas_ is not None
    is_tangent = has_omegas and hasattr(model, 'subspace_dim')

    if is_tangent:
        return 'tangent'
    if has_omegas:
        return 'local_omega'

    # Global omega
    if hasattr(model, 'omega_') and model.omega_ is not None:
        return 'omega'

    # Relevance-weighted (GRLVQ family)
    if hasattr(model, 'relevances_') and model.relevances_ is not None:
        return 'relevance'

    # For OC/SVQ-OCC/CBC/PLVQ models without metric params,
    # default to squared_euclidean
    if model_type in ('oc_hard_nearest', 'oc_gaussian_soft',
                       'oc_gaussian_ng', 'svqocc', 'cbc', 'plvq'):
        return 'squared_euclidean'

    # Encoder models without explicit distance_fn: default to squared_euclidean
    if (hasattr(model, 'backbone_params_')
            and model.backbone_params_ is not None):
        return 'squared_euclidean'

    # --- Function-based detection (supervised/unsupervised models) ---
    from prosemble.core.distance import (
        squared_euclidean_distance_matrix,
        euclidean_distance_matrix,
        manhattan_distance_matrix,
    )

    fn = model.distance_fn

    _known = {
        squared_euclidean_distance_matrix: 'squared_euclidean',
        euclidean_distance_matrix: 'euclidean',
        manhattan_distance_matrix: 'manhattan',
    }

    for known_fn, name in _known.items():
        if fn is known_fn:
            return name

    if hasattr(fn, 'func'):
        from prosemble.core.distance import omega_distance_matrix
        if fn.func is omega_distance_matrix:
            return 'omega'

    fn_name = getattr(fn, '__name__', '') or getattr(fn, '__wrapped__', '')
    if 'squared_euclidean' in str(fn_name):
        return 'squared_euclidean'
    if 'euclidean' in str(fn_name) and 'squared' not in str(fn_name):
        return 'euclidean'
    if 'manhattan' in str(fn_name):
        return 'manhattan'
    if 'omega_distance_matrix' in str(fn_name):
        return 'omega'

    raise NotImplementedError(
        f"ONNX export is not supported for distance function "
        f"'{fn_name or fn}'. Supported: squared_euclidean, euclidean, "
        f"manhattan, omega, relevance, local_omega, tangent, "
        f"kernel_per_proto, kernel_relevance, kernel_exponential."
    )


def _check_model_exportable(model):
    """Reject models whose operations cannot be expressed in ONNX.

    Returns ``None`` when the model passes every check.

    Raises
    ------
    NotImplementedError
        For Riemannian models on the SPD(n) manifold, for
        ``RiemannianSRNG`` itself on a Grassmannian manifold, for
        ``RiemannianNeuralGas``, and for models exposing both ``sigma``
        and ``centroids_`` (kernel fuzzy clustering).
    """
    from prosemble.models.riemannian_srng import RiemannianSRNG

    # Riemannian models: exportability depends on the manifold.
    if isinstance(model, RiemannianSRNG):
        from prosemble.core.manifolds import Grassmannian, SPD

        manifold = model.manifold
        if isinstance(manifold, SPD):
            raise NotImplementedError(
                "ONNX export is not supported for Riemannian models with "
                "SPD(n) manifold. Eigendecomposition (eigh) has no ONNX "
                "equivalent."
            )
        # Only the base class is blocked on Grassmannian; the check is on
        # the exact class name, so subclasses pass through.
        is_base_srng = type(model).__name__ == 'RiemannianSRNG'
        if is_base_srng and isinstance(manifold, Grassmannian):
            raise NotImplementedError(
                "ONNX export is not supported for RiemannianSRNG with "
                "Grassmannian manifold. SVD-based geodesic distance has "
                "no ONNX equivalent."
            )
        # SO(n) for all 4 models, and Grassmannian for SMNG/SLNG/STNG are OK
        return

    # RiemannianNeuralGas is optional: skip the check if it is unavailable.
    try:
        from prosemble.models.riemannian_neural_gas import RiemannianNeuralGas
    except ImportError:
        RiemannianNeuralGas = None
    if RiemannianNeuralGas is not None and isinstance(model, RiemannianNeuralGas):
        raise NotImplementedError(
            "ONNX export is not supported for RiemannianNeuralGas. "
            "Manifold operations have no ONNX equivalent."
        )

    # Kernel fuzzy clustering models
    if hasattr(model, 'sigma') and hasattr(model, 'centroids_'):
        raise NotImplementedError(
            "ONNX export is not supported for kernel fuzzy clustering "
            "models. Kernel distance (Gaussian kernel) has no standard "
            "ONNX equivalent."
        )


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

def export_onnx(
    model: Any,
    batch_size: int = 1,
    opset_version: int = 17,
    path: str | None = None,
    reject_threshold: float | None = None,
):
    """Export a fitted model's predict function to ONNX format.

    Builds an ONNX graph that reproduces the model's ``predict()`` output.
    Supports supervised (WTAC), unsupervised (ArgMin), one-class
    (threshold-based), SVQ-OCC (response model), CBC (reasoning matrices),
    PLVQ (Gaussian mixture), encoder models (MLP/CNN backbone),
    Differentiating Kernel models (Gaussian, relevance-weighted, and
    exponential kernels), and Riemannian models on SO(n)/Grassmannian
    manifolds (88 of 102 models total).

    Parameters
    ----------
    model : SupervisedPrototypeModel or UnsupervisedPrototypeModel
        A fitted prosemble model.
    batch_size : int
        Fixed batch dimension for the input. Use ``-1`` for dynamic
        batch size (ONNX symbolic dimension); any non-positive value is
        treated as dynamic.
    opset_version : int
        ONNX opset version. Default: 17.
    path : str, optional
        If provided, save the ONNX model to this file path.
    reject_threshold : float, optional
        If provided, enables reject option for supervised models.
        Samples with confidence below this threshold are assigned
        prediction -1 (rejected). The exported model produces two
        outputs: ``predictions`` (INT64, with -1 for rejected) and
        ``confidence`` (FLOAT, in [-1, 1]). Only supported for
        supervised models (WTAC decision).

    Returns
    -------
    onnx.ModelProto
        The exported ONNX model.

    Raises
    ------
    NotImplementedError
        If the model's distance function or decision pattern is not
        supported.
    ImportError
        If the ``onnx`` package is not installed.
    ValueError
        If ``reject_threshold`` is used with a non-supervised model.

    Warns
    -----
    UserWarning
        If ``reject_threshold`` is given for a Riemannian model: the
        dedicated Riemannian export path does not receive it, so the
        exported graph has no reject option.
    """
    _check_onnx_installed()

    import onnx
    import onnx.helper as oh
    from onnx import TensorProto

    model._check_fitted()
    _check_model_exportable(model)

    # Riemannian models: use dedicated export path
    from prosemble.models.riemannian_srng import RiemannianSRNG
    if isinstance(model, RiemannianSRNG):
        if reject_threshold is not None:
            # The threshold is not forwarded below; make that visible
            # instead of silently returning a graph without rejection.
            import warnings
            warnings.warn(
                "reject_threshold is not supported by the Riemannian "
                "ONNX export path and is ignored.",
                UserWarning,
                stacklevel=2,
            )
        return _export_riemannian_onnx(model, batch_size, opset_version, path)

    # Identify model type and distance
    model_type, dist_type = _identify_model_type(model)

    # Check for encoder (MLP/CNN backbone); a dict of backbone parameters
    # selects the CNN code path, anything else the MLP one.
    has_encoder = (
        hasattr(model, 'backbone_params_')
        and model.backbone_params_ is not None
    )
    is_cnn = has_encoder and isinstance(model.backbone_params_, dict)

    # Non-positive batch sizes become a symbolic ONNX dimension.
    batch_dim = batch_size if batch_size > 0 else 'batch'

    all_nodes = []
    initializers = []

    def add_float_initializer(name, values):
        # Register ``values`` as a float32 initializer called ``name``.
        array = np.asarray(values).astype(np.float32)
        initializers.append(
            oh.make_tensor(
                name, TensorProto.FLOAT,
                list(array.shape), array.flatten().tolist(),
            ),
        )

    def neg_inv_two_sigma_sq():
        # -1 / (2 * sigma^2) per kernel width, with sigma clamped from
        # below by the model's sigma_min (default 1e-3).
        sigmas = np.asarray(model.sigmas_).astype(np.float32)
        sigma_min = getattr(model, 'sigma_min', 1e-3)
        sigmas = np.maximum(sigmas, sigma_min)
        return (-1.0 / (2.0 * sigmas ** 2)).astype(np.float32)

    if has_encoder:
        # --- Encoder model path ---
        # The graph input is a flat feature vector per sample.
        if is_cnn:
            n_input_features = int(np.prod(model.input_shape))
        else:
            n_input_features = int(
                np.asarray(model.backbone_params_[0][0]).shape[0]
            )
        input_shape = [batch_dim, n_input_features]

        # Build encoder ONNX nodes
        if is_cnn:
            enc_nodes, enc_inits, enc_out = _cnn_encoder_onnx(
                'X', model.backbone_params_, model.input_shape,
                model.activation,
            )
        else:
            enc_nodes, enc_inits, enc_out = _mlp_encoder_onnx(
                'X', model.backbone_params_, model.activation,
            )
        all_nodes.extend(enc_nodes)
        initializers.extend(enc_inits)

        # Pre-compute latent prototypes/components.  Siamese*/Image*
        # models are pushed through the encoder here; for the remaining
        # encoder models (LVQMLN/PLVQ) the stored points are used as-is.
        model_name = type(model).__name__
        proto_need_encoding = (
            model_name.startswith('Siamese')
            or model_name.startswith('Image')
        )

        # CBC models use components_, everything else prototypes_.
        if model_type == 'cbc':
            points = np.asarray(model.components_, dtype=np.float32)
        else:
            points = np.asarray(model.prototypes_, dtype=np.float32)

        if not proto_need_encoding:
            latent_protos = points
        elif is_cnn:
            latent_protos = _cnn_forward_np(
                model.backbone_params_,
                points.reshape(-1, *model.input_shape),
                model.activation,
            )
        else:
            latent_protos = _mlp_forward_np(
                model.backbone_params_, points, model.activation,
            )

        latent_protos = latent_protos.astype(np.float32)
        n_proto = latent_protos.shape[0]
        X_for_distance = enc_out
    else:
        # --- Non-encoder model path ---
        prototypes = np.asarray(model.prototypes_, dtype=np.float32)
        n_proto, n_features = prototypes.shape
        input_shape = [batch_dim, n_features]
        latent_protos = prototypes
        X_for_distance = 'X'

    # --- Prototypes initializer ---
    add_float_initializer('prototypes', latent_protos)

    # Prototype labels for supervised models
    has_labels = (
        model_type == 'supervised'
        and hasattr(model, 'prototype_labels_')
        and model.prototype_labels_ is not None
    )
    if has_labels:
        proto_labels = np.asarray(model.prototype_labels_).astype(np.int64)
        initializers.append(
            oh.make_tensor(
                'proto_labels', TensorProto.INT64,
                list(proto_labels.shape),
                proto_labels.flatten().tolist(),
            ),
        )

    # --- Distance nodes ---
    if dist_type == 'squared_euclidean':
        nodes, extra_inits, dist_out = _squared_euclidean_onnx(
            None, X_for_distance, 'prototypes',
        )
    elif dist_type == 'euclidean':
        nodes, extra_inits, dist_out = _euclidean_onnx(
            None, X_for_distance, 'prototypes',
        )
    elif dist_type == 'manhattan':
        nodes, extra_inits, dist_out = _manhattan_onnx(
            None, X_for_distance, 'prototypes',
        )
    elif dist_type == 'omega':
        add_float_initializer('omega', model.omega_)
        nodes, extra_inits, dist_out = _omega_onnx(
            None, X_for_distance, 'prototypes', 'omega',
        )
    elif dist_type == 'relevance':
        add_float_initializer('relevances', model.relevances_)
        nodes, extra_inits, dist_out = _relevance_weighted_onnx(
            None, X_for_distance, 'prototypes', 'relevances',
        )
    elif dist_type == 'local_omega':
        add_float_initializer('omegas', model.omegas_)
        nodes, extra_inits, dist_out = _local_omega_onnx(
            None, X_for_distance, 'prototypes', 'omegas',
        )
    elif dist_type == 'tangent':
        add_float_initializer('omegas', model.omegas_)
        nodes, extra_inits, dist_out = _tangent_onnx(
            None, X_for_distance, 'prototypes', 'omegas',
        )
    elif dist_type == 'kernel_per_proto':
        add_float_initializer('neg_inv_2sigma_sq', neg_inv_two_sigma_sq())
        nodes, extra_inits, dist_out = _kernel_per_proto_onnx(
            None, X_for_distance, 'prototypes', 'neg_inv_2sigma_sq',
        )
    elif dist_type == 'kernel_relevance':
        add_float_initializer('neg_inv_2sigma_sq', neg_inv_two_sigma_sq())
        add_float_initializer('relevances', model.relevances_)
        nodes, extra_inits, dist_out = _kernel_relevance_onnx(
            None, X_for_distance, 'prototypes',
            'neg_inv_2sigma_sq', 'relevances',
        )
    elif dist_type == 'kernel_exponential':
        add_float_initializer('omega_hat', model.omega_hat_)
        nodes, extra_inits, dist_out = _kernel_exponential_onnx(
            None, X_for_distance, 'prototypes', 'omega_hat',
        )
    else:
        raise NotImplementedError(f"Unknown distance type: {dist_type}")

    all_nodes.extend(nodes)
    initializers.extend(extra_inits)

    # --- Decision / competition nodes ---
    use_rejection = reject_threshold is not None

    if use_rejection:
        if model_type != 'supervised':
            raise ValueError(
                "reject_threshold is only supported for supervised models "
                f"(WTAC decision), got model_type='{model_type}'."
            )
        comp_nodes, comp_inits = _wtac_with_rejection_onnx_nodes(
            dist_out, 'proto_labels', reject_threshold,
        )
    elif model_type == 'supervised':
        comp_nodes, comp_inits = _wtac_onnx_nodes(dist_out, 'proto_labels')
    elif model_type == 'unsupervised':
        comp_nodes, comp_inits = _argmin_onnx_nodes(dist_out)
    elif model_type == 'oc_hard_nearest':
        comp_nodes, comp_inits = _oc_hard_nearest_onnx(dist_out, model)
    elif model_type == 'oc_gaussian_soft':
        comp_nodes, comp_inits = _oc_gaussian_soft_onnx(dist_out, model)
    elif model_type == 'oc_gaussian_ng':
        comp_nodes, comp_inits = _oc_gaussian_ng_onnx(
            dist_out, model, n_proto,
        )
    elif model_type == 'svqocc':
        comp_nodes, comp_inits = _svqocc_onnx(dist_out, model, n_proto)
    elif model_type == 'cbc':
        comp_nodes, comp_inits = _cbc_onnx(dist_out, model)
    elif model_type == 'plvq':
        comp_nodes, comp_inits = _plvq_onnx(dist_out, model)
    else:
        raise NotImplementedError(f"Unknown model type: {model_type}")

    all_nodes.extend(comp_nodes)
    initializers.extend(comp_inits)

    # --- Output dtype ---
    # OC and SVQ-OCC models output int32 (matching JAX astype(jnp.int32))
    if model_type in ('oc_hard_nearest', 'oc_gaussian_soft',
                      'oc_gaussian_ng', 'svqocc'):
        output_dtype = TensorProto.INT32
    else:
        output_dtype = TensorProto.INT64

    # --- Build graph ---
    X_input = oh.make_tensor_value_info('X', TensorProto.FLOAT, input_shape)
    Y_output = oh.make_tensor_value_info(
        'predictions', output_dtype, [batch_dim],
    )

    outputs = [Y_output]
    if use_rejection:
        conf_output = oh.make_tensor_value_info(
            'confidence', TensorProto.FLOAT, [batch_dim],
        )
        outputs.append(conf_output)

    graph = oh.make_graph(
        all_nodes,
        'prosemble_predict',
        [X_input],
        outputs,
        initializer=initializers,
    )

    onnx_model = oh.make_model(graph, opset_imports=[
        oh.make_opsetid('', opset_version),
    ])
    onnx_model.ir_version = 8
    onnx_model.doc_string = (
        f"Prosemble {type(model).__name__} predict function. "
        f"Distance: {dist_type}, decision: {model_type}."
    )

    onnx.checker.check_model(onnx_model)

    if path is not None:
        onnx.save(onnx_model, path)

    return onnx_model