ONNX Export

Prosemble can export fitted models to ONNX format for cross-platform deployment. An exported ONNX model reproduces the predict() output of the original model and runs anywhere ONNX Runtime is available, with no JAX or prosemble dependency at inference time.

88 of 114 models are supported.

Installation

ONNX export requires the onnx package. For inference, install onnxruntime as well:

pip install onnx onnxruntime
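
Before exporting, it can be handy to confirm that both packages are importable in the current environment. The following is a minimal sketch using only the standard library; the helper name missing_packages is hypothetical and not part of prosemble:

```python
from importlib.util import find_spec


def missing_packages(names=("onnx", "onnxruntime")):
    """Return the top-level packages from `names` that cannot be imported."""
    return [name for name in names if find_spec(name) is None]


missing = missing_packages()
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("onnx and onnxruntime are available")
```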

Basic Usage

from prosemble.models import GLVQ
from prosemble.core.onnx_export import export_onnx

# X_train and y_train are assumed to be prepared beforehand
model = GLVQ(n_prototypes_per_class=2, max_iter=100, lr=0.01)
model.fit(X_train, y_train)

# Export to ONNX
onnx_model = export_onnx(model, path='glvq_model.onnx')

The export_onnx function accepts:

  • model — a fitted prosemble model

  • batch_size — fixed batch dimension (default 1; use -1 for dynamic batch size)

  • opset_version — ONNX opset version (default 17)

  • path — optional file path to save the ONNX model

  • reject_threshold — optional float to enable reject option (see below)

Running with ONNX Runtime

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession('glvq_model.onnx')
X_test_np = np.asarray(X_test, dtype=np.float32)

onnx_preds = session.run(None, {'X': X_test_np})[0]

The ONNX model takes a single input named X of shape (batch_size, n_features), passed as a float32 array in the example above, and returns an integer array of predicted labels.
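
Because the graph expects a two-dimensional array, a single sample or a float64 array has to be reshaped or cast before it is handed to session.run. The sketch below shows one way to do that, assuming the float32 input type used in the example above. The helper name prepare_input is hypothetical and not part of prosemble or ONNX Runtime:

```python
import numpy as np


def prepare_input(X, n_features):
    """Return X as a C-contiguous float32 array of shape (batch, n_features)."""
    X = np.asarray(X, dtype=np.float32)
    if X.ndim == 1:
        # Treat a flat vector as a single sample.
        X = X.reshape(1, -1)
    if X.ndim != 2 or X.shape[1] != n_features:
        raise ValueError(
            f"expected shape (batch, {n_features}), got {X.shape}"
        )
    return np.ascontiguousarray(X)
```

For a model exported with a fixed batch_size, the first dimension of the prepared array would additionally have to match that value; with batch_size=-1 any number of rows should be accepted.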

Full Workflow Example

import numpy as np
from prosemble.models import GMLVQ
from prosemble.datasets import load_iris_jax
from prosemble.core.onnx_export import export_onnx
import onnxruntime as ort

# Train
dataset = load_iris_jax()
X, y = dataset.input_data, dataset.target_data
model = GMLVQ(
    n_prototypes_per_class=1, max_iter=100,
    lr=0.01, latent_dim=2,
)
model.fit(X, y)

# Export
onnx_model = export_onnx(model, batch_size=-1, path='gmlvq.onnx')

# Run with ONNX Runtime
session = ort.InferenceSession('gmlvq.onnx')
X_np = np.asarray(X, dtype=np.float32)
onnx_preds = session.run(None, {'X': X_np})[0]

# Verify that ONNX Runtime reproduces the JAX predictions
jax_preds = np.asarray(model.predict(X))
assert np.array_equal(jax_preds, onnx_preds)
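
A bare assert only reports that the two arrays differ. If the check ever fails, a small report of where the predictions disagree is usually more useful. The helper below is a sketch with a hypothetical name, agreement_report, that works on any two label arrays:

```python
import numpy as np


def agreement_report(reference, candidate):
    """Compare two label arrays and report the positions where they differ."""
    ref = np.asarray(reference).ravel()
    cand = np.asarray(candidate).ravel()
    if ref.shape != cand.shape:
        raise ValueError(f"shape mismatch: {ref.shape} vs {cand.shape}")
    mismatches = np.flatnonzero(ref != cand)
    return {
        "n_samples": int(ref.size),
        "n_mismatches": int(mismatches.size),
        "mismatch_indices": mismatches.tolist(),
    }
```

In the workflow above it would be called as agreement_report(jax_preds, onnx_preds).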

Supported Models

| Family | Count | Models |
| --- | --- | --- |
| Supervised LVQ (squared Euclidean) | 9 | GLVQ, GLVQ1, GLVQ21, LVQ1, LVQ21, MedianLVQ, CELVQ, SLVQ, RSLVQ |
| Supervised LVQ (global omega) | 2 | GMLVQ, MRSLVQ |
| Supervised LVQ (local omega) | 2 | LGMLVQ, LMRSLVQ |
| Supervised LVQ (relevance-weighted) | 1 | GRLVQ |
| Supervised LVQ (tangent) | 1 | GTLVQ |
| Supervised NG (squared Euclidean) | 3 | SRNG, CELVQ_NG, RSLVQ_NG |
| Supervised NG (global omega) | 3 | SMNG, MCELVQ_NG, MRSLVQ_NG |
| Supervised NG (local omega) | 3 | SLNG, LCELVQ_NG, LMRSLVQ_NG |
| Supervised NG (tangent) | 2 | STNG, TCELVQ_NG |
| Supervised DK (Gaussian kernel) | 2 | DKGLVQ, DKGLVQ_NG |
| Supervised DK (relevance kernel) | 2 | DKGRLVQ, DKGRLVQ_NG |
| Supervised DK (exponential kernel) | 2 | DKGMLVQ, DKGMLVQ_NG |
| One-class DK (Gaussian kernel) | 2 | OCDKGLVQ, OCDKGLVQ_NG |
| One-class DK (relevance kernel) | 2 | OCDKGRLVQ, OCDKGRLVQ_NG |
| One-class DK (exponential kernel) | 2 | OCDKGMLVQ, OCDKGMLVQ_NG |
| Unsupervised | 4 | NeuralGas, GrowingNeuralGas, KohonenSOM, HeskesSOM |
| Unsupervised DK | 3 | DKNeuralGas, DKKohonenSOM, DKHeskesSOM |
| Fuzzy clustering | 8 | FCM, PCM, FPCM, PFCM, AFCM, HCM, IPCM, IPCM2 |
| One-class GLVQ | 10 | OCGLVQ, OCGLVQ_NG, OCGRLVQ, OCGRLVQ_NG, OCGMLVQ, OCGMLVQ_NG, OCLGMLVQ, OCLGMLVQ_NG, OCGTLVQ, OCGTLVQ_NG |
| One-class RSLVQ | 6 | OCRSLVQ, OCRSLVQ_NG, OCMRSLVQ, OCMRSLVQ_NG, OCLMRSLVQ, OCLMRSLVQ_NG |
| SVQ-OCC | 5 | SVQOCC, SVQOCC_R, SVQOCC_M, SVQOCC_LM, SVQOCC_T |
| MLP encoder + WTAC | 4 | SiameseGLVQ, SiameseGMLVQ, SiameseGTLVQ, LVQMLN |
| CNN encoder + WTAC | 3 | ImageGLVQ, ImageGMLVQ, ImageGTLVQ |
| PLVQ (Gaussian mixture) | 1 | PLVQ |
| CBC (reasoning matrices) | 2 | CBC, ImageCBC |
| Riemannian SO(n) (chordal) | 1 | RiemannianSRNG |
| Riemannian SO(n) (tangent-space metric) | 3 | RiemannianSMNG, RiemannianSLNG, RiemannianSTNG |
| Riemannian Grassmannian (tangent-space metric) | (same 3) | RiemannianSMNG, RiemannianSLNG, RiemannianSTNG (alternate manifold config) |

Encoder Models

Models with MLP or CNN backbones (Siamese, Image, LVQMLN, PLVQ, ImageCBC) are fully supported. The ONNX graph encodes the backbone as standard ops:

  • MLP: MatMul \(\rightarrow\) Add \(\rightarrow\) Activation per layer

  • CNN: Conv (SAME padding) \(\rightarrow\) Activation per layer \(\rightarrow\) GlobalAveragePool \(\rightarrow\) Linear

Prototypes are pre-computed in the latent space at export time, so only the input needs to be encoded at runtime. Supported activations: Sigmoid, ReLU, Tanh, LeakyReLU, SELU.
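
The split between export-time and runtime work can be illustrated with a NumPy sketch. It assumes that prototypes live in input space and are pushed through the same encoder as the data, and that WTAC assigns the label of the nearest latent prototype under squared Euclidean distance. All names here (mlp_encode, wtac_predict, the random weights) are hypothetical stand-ins, not prosemble internals:

```python
import numpy as np


def relu(h):
    return np.maximum(h, 0.0)


def identity(h):
    return h


def mlp_encode(X, layers):
    """Apply MatMul -> Add -> Activation for each (W, b, activation) layer."""
    H = X
    for W, b, activation in layers:
        H = activation(H @ W + b)
    return H


def wtac_predict(Z, latent_prototypes, prototype_labels):
    """Winner-takes-all: label of the nearest prototype in latent space."""
    diff = Z[:, None, :] - latent_prototypes[None, :, :]
    d2 = np.sum(diff ** 2, axis=-1)
    return prototype_labels[np.argmin(d2, axis=1)]


rng = np.random.default_rng(0)
layers = [
    (rng.normal(size=(4, 8)).astype(np.float32), np.zeros(8, np.float32), relu),
    (rng.normal(size=(8, 2)).astype(np.float32), np.zeros(2, np.float32), identity),
]
raw_prototypes = rng.normal(size=(3, 4)).astype(np.float32)
prototype_labels = np.array([0, 1, 2])

# Done once, at export time: prototypes are mapped into the latent space.
latent_prototypes = mlp_encode(raw_prototypes, layers)

# Done at runtime: only the input batch goes through the encoder.
X = rng.normal(size=(5, 4)).astype(np.float32)
preds = wtac_predict(mlp_encode(X, layers), latent_prototypes, prototype_labels)
```

Storing latent_prototypes as a constant means the runtime graph needs one encoder pass per input batch rather than one per batch plus one per prototype set.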

Distance Functions in ONNX

| Distance | ONNX Implementation | Models |
| --- | --- | --- |
| Squared Euclidean | Expansion trick | GLVQ family, NG, NeuralGas, FCM, OC, SVQ-OCC, Encoders, CBC |
| Global Omega | Project then squared Euclidean | GMLVQ, MRSLVQ, SMNG, OCGMLVQ, SVQOCC_M |
| Local Omega | Batched MatMul per prototype | LGMLVQ, SLNG, OCLGMLVQ, SVQOCC_LM |
| Tangent | Batched project-reconstruct | GTLVQ, STNG, OCGTLVQ, SVQOCC_T |
| Relevance-weighted | Element-wise weighted squared diff | GRLVQ, OCGRLVQ, SVQOCC_R |
| Gaussian Kernel (per-prototype \(\sigma\)) | \(2(1 - \exp(-\lVert x-w\rVert^2 / 2\sigma_k^2))\) | DKGLVQ, DKGLVQ_NG, OCDKGLVQ, OCDKGLVQ_NG |
| Relevance Kernel | \(2(1 - \exp(-\sum_j \lambda_j(x_j - w_j)^2 / 2\sigma_k^2))\) | DKGRLVQ, DKGRLVQ_NG, OCDKGRLVQ, OCDKGRLVQ_NG |
| Exponential Kernel | \(\exp(x^\top\hat\Lambda x) + \exp(w^\top\hat\Lambda w) - 2\exp(x^\top\hat\Lambda w)\) | DKGMLVQ, DKGMLVQ_NG, OCDKGMLVQ, OCDKGMLVQ_NG |
| SO(n) Chordal | Broadcast subtract + Frobenius norm | RiemannianSRNG |
| SO(n) Tangent | Skew-symmetric log map + metric adaptation | RiemannianSMNG/SLNG/STNG (SO) |
| Grassmannian Tangent | Projection log map + metric adaptation | RiemannianSMNG/SLNG/STNG (Gr) |
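
The expansion trick rewrites the squared distance as \(\lVert x-w\rVert^2 = \lVert x\rVert^2 - 2x^\top w + \lVert w\rVert^2\), so the whole distance matrix reduces to one matrix product plus two row-wise sums, all of which map onto standard ONNX ops. The NumPy sketch below illustrates three of the rows in the table. It is not the exporter's code: the function names are hypothetical, omega is assumed to have shape (latent_dim, n_features), and the clamp to zero is a choice made for this sketch.

```python
import numpy as np


def squared_euclidean(X, W):
    """Expansion trick: ||x - w||^2 = ||x||^2 - 2 x.w + ||w||^2."""
    x_sq = np.sum(X * X, axis=1, keepdims=True)   # shape (n, 1)
    w_sq = np.sum(W * W, axis=1)[None, :]         # shape (1, k)
    d2 = x_sq - 2.0 * (X @ W.T) + w_sq            # shape (n, k)
    # Rounding can leave tiny negative values; clamp them to zero.
    return np.maximum(d2, 0.0)


def global_omega(X, W, omega):
    """Project inputs and prototypes with omega, then reuse squared Euclidean."""
    return squared_euclidean(X @ omega.T, W @ omega.T)


def gaussian_kernel_distance(X, W, sigma):
    """2 * (1 - exp(-||x - w||^2 / (2 sigma_k^2))), one sigma per prototype."""
    d2 = squared_euclidean(X, W)
    return 2.0 * (1.0 - np.exp(-d2 / (2.0 * sigma[None, :] ** 2)))
```

Compared with broadcasting X[:, None, :] - W[None, :, :], the expanded form avoids materializing an (n, k, n_features) intermediate tensor.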

Not Supported

| Models | Count | Reason |
| --- | --- | --- |
| Kernel fuzzy clustering (KFCM, KPCM, KFPCM, KPFCM, KAFCM, KIPCM, KIPCM2) | 7 | Kernel fuzzy membership requires iterative kernel-distance update incompatible with ONNX |
| RiemannianNeuralGas | 1 | Matrix logarithm via Schur decomposition has no ONNX operator |
| Riemannian models + SPD(n) manifold | (config) | Eigendecomposition (eigh) has no ONNX operator |
| RiemannianSRNG + Grassmannian manifold | (config) | SVD-based geodesic distance has no ONNX operator |
| KNN | 1 | k-nearest-neighbor logic, not prototype-based |
| NPC | 1 | Different predict pattern |
| Utility (KMeansPlusPlus, Kmeans, SOM, BGPC) | 4 | Not prototype model base classes |

Export with Reject Option

Supervised models can be exported with a built-in reject option. When reject_threshold is provided, the ONNX model outputs two tensors:

  • predictions (INT64): class labels with -1 for rejected samples

  • confidence (FLOAT): confidence scores in [-1, 1]

The confidence is the GLVQ relative margin:

\[\text{confidence}(x) = \frac{d^-(x) - d^+(x)}{d^-(x) + d^+(x)}\]

Samples with confidence below the threshold are rejected (prediction = -1).
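
The margin and the rejection rule can be sketched in NumPy from a precomputed distance matrix. This is an illustration under stated assumptions, not the exported graph: \(d^+\) is taken to be the distance to the nearest prototype of the predicted class and \(d^-\) the distance to the nearest prototype of any other class, at least two classes are present, and the function name predict_with_reject is hypothetical.

```python
import numpy as np


def predict_with_reject(distances, prototype_labels, threshold):
    """Relative-margin confidence with rejection.

    distances: (n_samples, n_prototypes) array of non-negative distances.
    prototype_labels: (n_prototypes,) integer class label of each prototype.
    """
    winner = np.argmin(distances, axis=1)
    predicted = prototype_labels[winner]
    d_plus = distances[np.arange(len(distances)), winner]
    # Ignore prototypes that share the predicted label, then take the minimum.
    same_class = prototype_labels[None, :] == predicted[:, None]
    d_minus = np.min(np.where(same_class, np.inf, distances), axis=1)
    confidence = (d_minus - d_plus) / (d_minus + d_plus)
    # Samples whose confidence falls below the threshold get the label -1.
    predictions = np.where(confidence < threshold, -1, predicted)
    return predictions.astype(np.int64), confidence.astype(np.float32)


distances = np.array([[1.0, 3.0], [2.0, 2.5]])
prototype_labels = np.array([0, 1])
predictions, confidence = predict_with_reject(distances, prototype_labels, 0.2)
```

In this toy input the first sample has margin (3 - 1) / (3 + 1) = 0.5 and is accepted, while the second has margin 0.5 / 4.5, which is below 0.2, and is rejected.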

from prosemble.models import GLVQ
from prosemble.core.onnx_export import export_onnx

model = GLVQ(n_prototypes_per_class=2, max_iter=100, lr=0.01)
model.fit(X_train, y_train)

# Export with reject option (threshold = 0.2)
onnx_model = export_onnx(
    model, batch_size=-1, reject_threshold=0.2, path='glvq_reject.onnx',
)

# Run with ONNX Runtime
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession('glvq_reject.onnx')
X_test_np = np.asarray(X_test, dtype=np.float32)
results = session.run(None, {'X': X_test_np})

predictions = results[0]   # class labels or -1
confidence = results[1]    # confidence scores

# Analyze (y_test holds the true labels of the test set)
y_test_np = np.asarray(y_test)
rejected = predictions == -1
print(f"Rejected: {rejected.sum()} / {len(predictions)}")
print(f"Accuracy on accepted: {(predictions[~rejected] == y_test_np[~rejected]).mean():.3f}")

The reject option works with all supervised model distance types: Euclidean, relevance-weighted, omega-projected, local omega, tangent, and all differentiating kernel variants.

It is not supported for one-class or unsupervised models; passing reject_threshold for such a model raises ValueError.