ONNX Export¶
Prosemble can export fitted models to ONNX format for
cross-platform deployment. An exported ONNX model reproduces the predict()
output of the original model and runs anywhere ONNX Runtime is available —
no JAX or prosemble dependency needed at inference time.
88 of 114 models are supported.
Installation¶
ONNX export requires the onnx package. For inference, install
onnxruntime as well:
```bash
pip install onnx onnxruntime
```
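Before exporting, it can help to confirm that both packages are importable in the current environment. The snippet below is a small sketch that uses only the standard library; `missing_packages` is a hypothetical helper name and is not part of prosemble.

```python
from importlib.util import find_spec


def missing_packages(names=("onnx", "onnxruntime")):
    """Return the subset of top-level package `names` that cannot be imported."""
    return [name for name in names if find_spec(name) is None]


missing = missing_packages()
if missing:
    print("Install first: pip install " + " ".join(missing))
else:
    print("onnx and onnxruntime are available")
```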
Basic Usage¶
```python
from prosemble.models import GLVQ
from prosemble.core.onnx_export import export_onnx

model = GLVQ(n_prototypes_per_class=2, max_iter=100, lr=0.01)
model.fit(X_train, y_train)

# Export to ONNX
onnx_model = export_onnx(model, path='glvq_model.onnx')
```
The export_onnx function accepts:
- model: a fitted prosemble model
- batch_size: fixed batch dimension (default 1; use -1 for dynamic batch size)
- opset_version: ONNX opset version (default 17)
- path: optional file path to save the ONNX model
- reject_threshold: optional float to enable reject option (see below)
Running with ONNX Runtime¶
```python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession('glvq_model.onnx')
X_test_np = np.asarray(X_test, dtype=np.float32)
onnx_preds = session.run(None, {'X': X_test_np})[0]
```
The ONNX model takes a single input X of shape (batch_size, n_features)
and returns an integer array of predicted labels.
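A model exported with a fixed batch_size generally has to be called with exactly that many rows per run. One way to handle an arbitrary number of samples is to split the input into chunks of that size and pad the last chunk. The sketch below uses plain Python lists and a stand-in `run_batch` callable so that it runs without any dependencies; in practice `run_batch` would wrap something like `session.run(None, {'X': chunk})[0]` on a float32 NumPy array. `predict_in_fixed_batches` and `fake_run_batch` are hypothetical names, not prosemble functions.

```python
def predict_in_fixed_batches(rows, batch_size, run_batch):
    """Run `run_batch` on chunks of exactly `batch_size` rows.

    The last chunk is padded by repeating its final row, and the
    predictions for the padding rows are discarded.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    predictions = []
    for start in range(0, len(rows), batch_size):
        chunk = list(rows[start:start + batch_size])
        n_real = len(chunk)
        # Pad so the chunk matches the fixed batch dimension.
        chunk += [chunk[-1]] * (batch_size - n_real)
        predictions.extend(run_batch(chunk)[:n_real])
    return predictions


# Stand-in for an inference session: label 1 if the first feature is positive.
def fake_run_batch(chunk):
    assert len(chunk) == 4  # this pretend model only accepts batches of 4
    return [int(row[0] > 0) for row in chunk]


samples = [[0.5, 1.0], [-1.0, 2.0], [3.0, 0.0], [-0.2, 0.1], [1.5, 1.5]]
labels = predict_in_fixed_batches(samples, batch_size=4, run_batch=fake_run_batch)
print(labels)  # [1, 0, 1, 0, 1]
```

Exporting with batch_size=-1 avoids the padding step entirely, at the cost of a dynamic batch dimension in the graph.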
Full Workflow Example¶
```python
import numpy as np
from prosemble.models import GMLVQ
from prosemble.datasets import load_iris_jax
from prosemble.core.onnx_export import export_onnx
import onnxruntime as ort

# Train
dataset = load_iris_jax()
X, y = dataset.input_data, dataset.target_data
model = GMLVQ(
    n_prototypes_per_class=1, max_iter=100,
    lr=0.01, latent_dim=2,
)
model.fit(X, y)

# Export
onnx_model = export_onnx(model, batch_size=-1, path='gmlvq.onnx')

# Run with ONNX Runtime
session = ort.InferenceSession('gmlvq.onnx')
X_np = np.asarray(X, dtype=np.float32)
onnx_preds = session.run(None, {'X': X_np})[0]

# Verify
jax_preds = model.predict(X)
assert np.array_equal(jax_preds, onnx_preds)
```
Supported Models¶
| Family | Count | Models |
|---|---|---|
| Supervised LVQ (squared Euclidean) | 9 | GLVQ, GLVQ1, GLVQ21, LVQ1, LVQ21, MedianLVQ, CELVQ, SLVQ, RSLVQ |
| Supervised LVQ (global omega) | 2 | GMLVQ, MRSLVQ |
| Supervised LVQ (local omega) | 2 | LGMLVQ, LMRSLVQ |
| Supervised LVQ (relevance-weighted) | 1 | GRLVQ |
| Supervised LVQ (tangent) | 1 | GTLVQ |
| Supervised NG (squared Euclidean) | 3 | SRNG, CELVQ_NG, RSLVQ_NG |
| Supervised NG (global omega) | 3 | SMNG, MCELVQ_NG, MRSLVQ_NG |
| Supervised NG (local omega) | 3 | SLNG, LCELVQ_NG, LMRSLVQ_NG |
| Supervised NG (tangent) | 2 | STNG, TCELVQ_NG |
| Supervised DK (Gaussian kernel) | 2 | DKGLVQ, DKGLVQ_NG |
| Supervised DK (relevance kernel) | 2 | DKGRLVQ, DKGRLVQ_NG |
| Supervised DK (exponential kernel) | 2 | DKGMLVQ, DKGMLVQ_NG |
| One-class DK (Gaussian kernel) | 2 | OCDKGLVQ, OCDKGLVQ_NG |
| One-class DK (relevance kernel) | 2 | OCDKGRLVQ, OCDKGRLVQ_NG |
| One-class DK (exponential kernel) | 2 | OCDKGMLVQ, OCDKGMLVQ_NG |
| Unsupervised | 4 | NeuralGas, GrowingNeuralGas, KohonenSOM, HeskesSOM |
| Unsupervised DK | 3 | DKNeuralGas, DKKohonenSOM, DKHeskesSOM |
| Fuzzy clustering | 8 | FCM, PCM, FPCM, PFCM, AFCM, HCM, IPCM, IPCM2 |
| One-class GLVQ | 10 | OCGLVQ, OCGLVQ_NG, OCGRLVQ, OCGRLVQ_NG, OCGMLVQ, OCGMLVQ_NG, OCLGMLVQ, OCLGMLVQ_NG, OCGTLVQ, OCGTLVQ_NG |
| One-class RSLVQ | 6 | OCRSLVQ, OCRSLVQ_NG, OCMRSLVQ, OCMRSLVQ_NG, OCLMRSLVQ, OCLMRSLVQ_NG |
| SVQ-OCC | 5 | SVQOCC, SVQOCC_R, SVQOCC_M, SVQOCC_LM, SVQOCC_T |
| MLP encoder + WTAC | 4 | SiameseGLVQ, SiameseGMLVQ, SiameseGTLVQ, LVQMLN |
| CNN encoder + WTAC | 3 | ImageGLVQ, ImageGMLVQ, ImageGTLVQ |
| PLVQ (Gaussian mixture) | 1 | PLVQ |
| CBC (reasoning matrices) | 2 | CBC, ImageCBC |
| Riemannian SO(n) (chordal) | 1 | RiemannianSRNG |
| Riemannian SO(n) (tangent-space metric) | 3 | RiemannianSMNG, RiemannianSLNG, RiemannianSTNG |
| Riemannian Grassmannian (tangent-space metric) | (same 3) | RiemannianSMNG, RiemannianSLNG, RiemannianSTNG (alternate manifold config) |
Encoder Models¶
Models with MLP or CNN backbones (Siamese, Image, LVQMLN, PLVQ, ImageCBC) are fully supported. The ONNX graph encodes the backbone as standard ops:
- MLP: MatMul \(\rightarrow\) Add \(\rightarrow\) Activation per layer
- CNN: Conv (SAME padding) \(\rightarrow\) Activation per layer \(\rightarrow\) GlobalAveragePool \(\rightarrow\) Linear
Prototypes are pre-computed in the latent space at export time, so only the input needs to be encoded at runtime. Supported activations: Sigmoid, ReLU, Tanh, LeakyReLU, SELU.
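The MLP path can be pictured with the plain-Python sketch below: each layer is a matrix product, a bias add, and an activation, and the encoded input is then compared with prototypes that were already encoded ahead of time. The weights, prototypes, and function names are made-up illustrations rather than prosemble internals, and the final nearest-prototype step assumes a squared Euclidean winner-takes-all rule.

```python
import math


def matmul_add(x, weights, bias):
    """MatMul followed by Add: x (n_in,) @ weights (n_in, n_out) + bias."""
    return [
        sum(x[i] * weights[i][j] for i in range(len(x))) + bias[j]
        for j in range(len(bias))
    ]


def encode(x, layers):
    """Apply MatMul -> Add -> Activation for every layer in turn."""
    for weights, bias, activation in layers:
        x = [activation(v) for v in matmul_add(x, weights, bias)]
    return x


def relu(v):
    return max(0.0, v)


# Hypothetical 2 -> 2 -> 2 encoder (ReLU layer, then Tanh layer).
layers = [
    ([[1.0, -1.0], [0.5, 2.0]], [0.0, 0.1], relu),
    ([[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0], math.tanh),
]

# Latent prototypes, fixed ahead of time instead of being encoded per query.
latent_prototypes = [[0.9, 0.0], [0.0, 0.9]]
prototype_labels = [0, 1]


def predict(x):
    """Encode the input, then return the label of the closest latent prototype."""
    z = encode(x, layers)
    dists = [sum((a - b) ** 2 for a, b in zip(z, p)) for p in latent_prototypes]
    return prototype_labels[dists.index(min(dists))]


print(predict([2.0, 0.0]), predict([0.0, 1.0]))  # 0 1
```

Only the input goes through `encode` at prediction time, which mirrors the pre-computed prototypes described above.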
Distance Functions in ONNX¶
| Distance | ONNX Implementation | Models |
|---|---|---|
| Squared Euclidean | Expansion trick | GLVQ family, NG, NeuralGas, FCM, OC, SVQ-OCC, Encoders, CBC |
| Global Omega | Project then squared Euclidean | GMLVQ, MRSLVQ, SMNG, OCGMLVQ, SVQOCC_M |
| Local Omega | Batched MatMul per prototype | LGMLVQ, SLNG, OCLGMLVQ, SVQOCC_LM |
| Tangent | Batched project-reconstruct | GTLVQ, STNG, OCGTLVQ, SVQOCC_T |
| Relevance-weighted | Element-wise weighted squared diff | GRLVQ, OCGRLVQ, SVQOCC_R |
| Gaussian Kernel (per-prototype \(\sigma\)) | \(2(1 - \exp(-\lVert x-w\rVert^2 / 2\sigma_k^2))\) | DKGLVQ, DKGLVQ_NG, OCDKGLVQ, OCDKGLVQ_NG |
| Relevance Kernel | \(2(1 - \exp(-\sum_j \lambda_j(x_j - w_j)^2 / 2\sigma_k^2))\) | DKGRLVQ, DKGRLVQ_NG, OCDKGRLVQ, OCDKGRLVQ_NG |
| Exponential Kernel | \(\exp(x^\top\hat\Lambda x) + \exp(w^\top\hat\Lambda w) - 2\exp(x^\top\hat\Lambda w)\) | DKGMLVQ, DKGMLVQ_NG, OCDKGMLVQ, OCDKGMLVQ_NG |
| SO(n) Chordal | Broadcast subtract + Frobenius norm | RiemannianSRNG |
| SO(n) Tangent | Skew-symmetric log map + metric adaptation | RiemannianSMNG/SLNG/STNG (SO) |
| Grassmannian Tangent | Projection log map + metric adaptation | RiemannianSMNG/SLNG/STNG (Gr) |
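Three of the rows above can be illustrated in a few lines of plain Python. The sketch assumes that "expansion trick" refers to the identity \(\lVert x-w\rVert^2 = \lVert x\rVert^2 - 2x^\top w + \lVert w\rVert^2\), that the global omega projection is applied as a row vector times a matrix of shape (n_features, latent_dim), and that the Gaussian kernel denominator reads as \(2\sigma_k^2\). The function names and the example omega are hypothetical, and the code does not claim to match how prosemble builds the ONNX graph.

```python
import math


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


def sq_euclidean_expanded(x, w):
    """||x - w||^2 computed as ||x||^2 - 2 x.w + ||w||^2."""
    return dot(x, x) - 2.0 * dot(x, w) + dot(w, w)


def project(v, omega):
    """Row vector times matrix: v (n_features,) @ omega (n_features, latent_dim)."""
    return [dot(v, column) for column in zip(*omega)]


def global_omega_distance(x, w, omega):
    """Project both points with the shared omega, then squared Euclidean."""
    return sq_euclidean_expanded(project(x, omega), project(w, omega))


def gaussian_kernel_distance(x, w, sigma):
    """2 * (1 - exp(-||x - w||^2 / (2 sigma^2))) for one prototype's sigma."""
    return 2.0 * (1.0 - math.exp(-sq_euclidean_expanded(x, w) / (2.0 * sigma ** 2)))


x, w = [1.0, 2.0], [3.0, 0.0]
omega = [[1.0], [1.0]]  # hypothetical 2 -> 1 projection

print(sq_euclidean_expanded(x, w))         # 8.0
print(global_omega_distance(x, w, omega))  # 0.0
print(gaussian_kernel_distance(x, w, sigma=2.0))
```

The expanded form is convenient for batches because the cross term becomes a single matrix product between inputs and prototypes.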
Not Supported¶
| Models | Count | Reason |
|---|---|---|
| Kernel fuzzy clustering (KFCM, KPCM, KFPCM, KPFCM, KAFCM, KIPCM, KIPCM2) | 7 | Kernel fuzzy membership requires iterative kernel-distance update incompatible with ONNX |
| RiemannianNeuralGas | 1 | Matrix logarithm via Schur decomposition has no ONNX operator |
| Riemannian models + SPD(n) manifold | (config) | Eigendecomposition (eigh) has no ONNX operator |
| RiemannianSRNG + Grassmannian manifold | (config) | SVD-based geodesic distance has no ONNX operator |
| KNN | 1 | k-nearest-neighbor logic, not prototype-based |
| NPC | 1 | Different predict pattern |
| Utility (KMeansPlusPlus, Kmeans, SOM, BGPC) | 4 | Not prototype model base classes |
Export with Reject Option¶
Supervised models can be exported with a built-in reject option. When
reject_threshold is provided, the ONNX model outputs two tensors:
- predictions (INT64): class labels, with -1 for rejected samples
- confidence (FLOAT): confidence scores in [-1, 1]
The confidence is the GLVQ relative margin. Samples with confidence below the threshold are rejected (prediction = -1).
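As an illustration of the reject rule, the sketch below assumes one common form of the relative margin at inference time, \((d_2 - d_1) / (d_2 + d_1)\), where \(d_1\) is the distance to the winning prototype and \(d_2\) is the distance to the closest prototype of any other class. This is an assumption for illustration only; the exact definition used by prosemble may differ, and the function names are hypothetical. With this form the value is non-negative because \(d_1 \le d_2\), while \([-1, 1]\) is the general bound of the ratio.

```python
def relative_margin(distances, prototype_labels):
    """Return (winning label, confidence) for one sample.

    distances[k] is the distance from the sample to prototype k.
    Assumes prototypes from at least two classes.
    """
    best = min(range(len(distances)), key=distances.__getitem__)
    label = prototype_labels[best]
    d_win = distances[best]
    # Closest prototype that belongs to any other class.
    d_other = min(
        d for d, lab in zip(distances, prototype_labels) if lab != label
    )
    total = d_win + d_other
    confidence = (d_other - d_win) / total if total > 0 else 0.0
    return label, confidence


def predict_with_reject(distances, prototype_labels, threshold):
    """Return (-1, confidence) when the confidence is below the threshold."""
    label, confidence = relative_margin(distances, prototype_labels)
    return (label if confidence >= threshold else -1), confidence


labels = [0, 0, 1, 1]
print(predict_with_reject([1.0, 4.0, 9.0, 16.0], labels, threshold=0.2))  # (0, 0.8)
print(predict_with_reject([4.0, 6.0, 5.0, 9.0], labels, threshold=0.2))   # rejected
```

The second sample sits almost midway between the two classes, so its margin falls below 0.2 and the prediction becomes -1.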
```python
from prosemble.models import GLVQ
from prosemble.core.onnx_export import export_onnx

model = GLVQ(n_prototypes_per_class=2, max_iter=100, lr=0.01)
model.fit(X_train, y_train)

# Export with reject option (threshold = 0.2)
onnx_model = export_onnx(
    model, batch_size=-1, reject_threshold=0.2, path='glvq_reject.onnx',
)

# Run with ONNX Runtime
import onnxruntime as ort
import numpy as np

session = ort.InferenceSession('glvq_reject.onnx')
X_test_np = np.asarray(X_test, dtype=np.float32)  # ONNX Runtime expects float32 input
results = session.run(None, {'X': X_test_np})
predictions = results[0]   # class labels or -1
confidence = results[1]    # confidence scores

# Analyze
rejected = predictions == -1
print(f"Rejected: {rejected.sum()} / {len(predictions)}")
print(f"Accuracy on accepted: {(predictions[~rejected] == y_test[~rejected]).mean():.3f}")
```
The reject option works with all supervised model distance types: Euclidean, relevance-weighted, omega-projected, local omega, tangent, and all differentiating kernel variants.
It is not supported for one-class or unsupervised models; passing reject_threshold for such a model raises ValueError.