ONNX Export
===========

Prosemble can export fitted models to ONNX format for cross-platform deployment. An exported ONNX model reproduces the ``predict()`` output of the original model and runs anywhere ONNX Runtime is available, with no JAX or prosemble dependency needed at inference time.

**88 of 114 models** are supported.

Installation
------------

ONNX export requires the ``onnx`` package. For inference, install ``onnxruntime`` as well:

.. code-block:: bash

   pip install onnx onnxruntime

Basic Usage
-----------

.. code-block:: python

   from prosemble.models import GLVQ
   from prosemble.core.onnx_export import export_onnx

   # X_train, y_train: your training features and labels
   model = GLVQ(n_prototypes_per_class=2, max_iter=100, lr=0.01)
   model.fit(X_train, y_train)

   # Export to ONNX
   onnx_model = export_onnx(model, path='glvq_model.onnx')

The ``export_onnx`` function accepts:

- ``model``: a fitted prosemble model
- ``batch_size``: fixed batch dimension (default ``1``; use ``-1`` for dynamic batch size)
- ``opset_version``: ONNX opset version (default ``17``)
- ``path``: optional file path to save the ONNX model
- ``reject_threshold``: optional float to enable the reject option (see below)

Running with ONNX Runtime
--------------------------

.. code-block:: python

   import numpy as np
   import onnxruntime as ort

   session = ort.InferenceSession('glvq_model.onnx')

   # ONNX Runtime expects float32 NumPy input
   X_test_np = np.asarray(X_test, dtype=np.float32)
   onnx_preds = session.run(None, {'X': X_test_np})[0]

The ONNX model takes a single input ``X`` of shape ``(batch_size, n_features)`` and returns an integer array of predicted labels.

Full Workflow Example
---------------------
.. code-block:: python

   import numpy as np
   import onnxruntime as ort

   from prosemble.models import GMLVQ
   from prosemble.datasets import load_iris_jax
   from prosemble.core.onnx_export import export_onnx

   # Train
   dataset = load_iris_jax()
   X, y = dataset.input_data, dataset.target_data

   model = GMLVQ(
       n_prototypes_per_class=1,
       max_iter=100,
       lr=0.01,
       latent_dim=2,
   )
   model.fit(X, y)

   # Export
   onnx_model = export_onnx(model, batch_size=-1, path='gmlvq.onnx')

   # Run with ONNX Runtime
   session = ort.InferenceSession('gmlvq.onnx')
   X_np = np.asarray(X, dtype=np.float32)
   onnx_preds = session.run(None, {'X': X_np})[0]

   # Verify
   jax_preds = model.predict(X)
   assert np.array_equal(jax_preds, onnx_preds)

Supported Models
----------------

.. list-table::
   :header-rows: 1
   :widths: 35 15 50

   * - Family
     - Count
     - Models
   * - Supervised LVQ (squared Euclidean)
     - 9
     - GLVQ, GLVQ1, GLVQ21, LVQ1, LVQ21, MedianLVQ, CELVQ, SLVQ, RSLVQ
   * - Supervised LVQ (global omega)
     - 2
     - GMLVQ, MRSLVQ
   * - Supervised LVQ (local omega)
     - 2
     - LGMLVQ, LMRSLVQ
   * - Supervised LVQ (relevance-weighted)
     - 1
     - GRLVQ
   * - Supervised LVQ (tangent)
     - 1
     - GTLVQ
   * - Supervised NG (squared Euclidean)
     - 3
     - SRNG, CELVQ_NG, RSLVQ_NG
   * - Supervised NG (global omega)
     - 3
     - SMNG, MCELVQ_NG, MRSLVQ_NG
   * - Supervised NG (local omega)
     - 3
     - SLNG, LCELVQ_NG, LMRSLVQ_NG
   * - Supervised NG (tangent)
     - 2
     - STNG, TCELVQ_NG
   * - Supervised DK (Gaussian kernel)
     - 2
     - DKGLVQ, DKGLVQ_NG
   * - Supervised DK (relevance kernel)
     - 2
     - DKGRLVQ, DKGRLVQ_NG
   * - Supervised DK (exponential kernel)
     - 2
     - DKGMLVQ, DKGMLVQ_NG
   * - One-class DK (Gaussian kernel)
     - 2
     - OCDKGLVQ, OCDKGLVQ_NG
   * - One-class DK (relevance kernel)
     - 2
     - OCDKGRLVQ, OCDKGRLVQ_NG
   * - One-class DK (exponential kernel)
     - 2
     - OCDKGMLVQ, OCDKGMLVQ_NG
   * - Unsupervised
     - 4
     - NeuralGas, GrowingNeuralGas, KohonenSOM, HeskesSOM
   * - Unsupervised DK
     - 3
     - DKNeuralGas, DKKohonenSOM, DKHeskesSOM
   * - Fuzzy clustering
     - 8
     - FCM, PCM, FPCM, PFCM, AFCM, HCM, IPCM, IPCM2
   * - One-class GLVQ
     - 10
     - OCGLVQ, OCGLVQ_NG, OCGRLVQ, OCGRLVQ_NG, OCGMLVQ, OCGMLVQ_NG, OCLGMLVQ, OCLGMLVQ_NG, OCGTLVQ, OCGTLVQ_NG
   * - One-class RSLVQ
     - 6
     - OCRSLVQ, OCRSLVQ_NG, OCMRSLVQ, OCMRSLVQ_NG, OCLMRSLVQ, OCLMRSLVQ_NG
   * - SVQ-OCC
     - 5
     - SVQOCC, SVQOCC_R, SVQOCC_M, SVQOCC_LM, SVQOCC_T
   * - MLP encoder + WTAC
     - 4
     - SiameseGLVQ, SiameseGMLVQ, SiameseGTLVQ, LVQMLN
   * - CNN encoder + WTAC
     - 3
     - ImageGLVQ, ImageGMLVQ, ImageGTLVQ
   * - PLVQ (Gaussian mixture)
     - 1
     - PLVQ
   * - CBC (reasoning matrices)
     - 2
     - CBC, ImageCBC
   * - Riemannian SO(n) (chordal)
     - 1
     - RiemannianSRNG
   * - Riemannian SO(n) (tangent-space metric)
     - 3
     - RiemannianSMNG, RiemannianSLNG, RiemannianSTNG
   * - Riemannian Grassmannian (tangent-space metric)
     - (same 3)
     - RiemannianSMNG, RiemannianSLNG, RiemannianSTNG (alternate manifold config)

Encoder Models
--------------

Models with MLP or CNN backbones (Siamese, Image, LVQMLN, PLVQ, ImageCBC) are fully supported. The ONNX graph encodes the backbone as standard ops:

- **MLP**: ``MatMul`` :math:`\rightarrow` ``Add`` :math:`\rightarrow` ``Activation`` per layer
- **CNN**: ``Conv`` (SAME padding) :math:`\rightarrow` ``Activation`` per layer :math:`\rightarrow` ``GlobalAveragePool`` :math:`\rightarrow` ``Linear``

Prototypes are pre-computed in the latent space at export time, so only the input needs to be encoded at runtime.

Supported activations: Sigmoid, ReLU, Tanh, LeakyReLU, SELU.

Distance Functions in ONNX
--------------------------
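The squared Euclidean entry in the table of this section is listed as an "expansion trick". This presumably refers to the identity :math:`\|x - w\|^2 = \|x\|^2 - 2\,x^\top w + \|w\|^2`, which turns pairwise distances into a single matrix product plus two broadcast additions. The following is a minimal NumPy sketch of that identity, not the exporter's actual graph; the function name ``squared_euclidean_expanded`` and the random data are illustrative assumptions.

.. code-block:: python

   import numpy as np

   def squared_euclidean_expanded(X, W):
       """Pairwise squared distances via ||x||^2 - 2 x.w + ||w||^2.

       X: (n_samples, n_features), W: (n_prototypes, n_features).
       Hypothetical helper for illustration only.
       """
       x_sq = np.sum(X * X, axis=1, keepdims=True)   # (n_samples, 1)
       w_sq = np.sum(W * W, axis=1)[np.newaxis, :]   # (1, n_prototypes)
       cross = X @ W.T                               # (n_samples, n_prototypes)
       d = x_sq - 2.0 * cross + w_sq
       # Rounding can push tiny distances slightly below zero; clamp them.
       return np.maximum(d, 0.0)

   rng = np.random.default_rng(0)
   X = rng.normal(size=(5, 3)).astype(np.float32)
   W = rng.normal(size=(4, 3)).astype(np.float32)

   D = squared_euclidean_expanded(X, W)

   # Reference: explicit broadcast subtraction, which needs an (n, k, d) intermediate
   D_ref = ((X[:, None, :] - W[None, :, :]) ** 2).sum(axis=-1)

   # Winner-takes-all: index of the closest prototype per sample
   nearest = D.argmin(axis=1)

The expanded form avoids materialising the three-dimensional difference tensor, which is one plausible reason to prefer it in a static graph.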
.. list-table::
   :header-rows: 1
   :widths: 30 40 30

   * - Distance
     - ONNX Implementation
     - Models
   * - Squared Euclidean
     - Expansion trick
     - GLVQ family, NG, NeuralGas, FCM, OC, SVQ-OCC, Encoders, CBC
   * - Global Omega
     - Project then squared Euclidean
     - GMLVQ, MRSLVQ, SMNG, OCGMLVQ, SVQOCC_M
   * - Local Omega
     - Batched MatMul per prototype
     - LGMLVQ, SLNG, OCLGMLVQ, SVQOCC_LM
   * - Tangent
     - Batched project-reconstruct
     - GTLVQ, STNG, OCGTLVQ, SVQOCC_T
   * - Relevance-weighted
     - Element-wise weighted squared diff
     - GRLVQ, OCGRLVQ, SVQOCC_R
   * - Gaussian Kernel (per-prototype :math:`\sigma`)
     - :math:`2(1 - \exp(-\|x-w\|^2 / 2\sigma_k^2))`
     - DKGLVQ, DKGLVQ_NG, OCDKGLVQ, OCDKGLVQ_NG
   * - Relevance Kernel
     - :math:`2(1 - \exp(-\sum_j \lambda_j(x_j - w_j)^2 / 2\sigma_k^2))`
     - DKGRLVQ, DKGRLVQ_NG, OCDKGRLVQ, OCDKGRLVQ_NG
   * - Exponential Kernel
     - :math:`\exp(x^\top\hat\Lambda x) + \exp(w^\top\hat\Lambda w) - 2\exp(x^\top\hat\Lambda w)`
     - DKGMLVQ, DKGMLVQ_NG, OCDKGMLVQ, OCDKGMLVQ_NG
   * - SO(n) Chordal
     - Broadcast subtract + Frobenius norm
     - RiemannianSRNG
   * - SO(n) Tangent
     - Skew-symmetric log map + metric adaptation
     - RiemannianSMNG/SLNG/STNG (SO)
   * - Grassmannian Tangent
     - Projection log map + metric adaptation
     - RiemannianSMNG/SLNG/STNG (Gr)

Not Supported
-------------
.. list-table::
   :header-rows: 1
   :widths: 30 20 50

   * - Models
     - Count
     - Reason
   * - Kernel fuzzy clustering (KFCM, KPCM, KFPCM, KPFCM, KAFCM, KIPCM, KIPCM2)
     - 7
     - Kernel fuzzy membership requires iterative kernel-distance update incompatible with ONNX
   * - RiemannianNeuralGas
     - 1
     - Matrix logarithm via Schur decomposition has no ONNX operator
   * - Riemannian models + SPD(n) manifold
     - (config)
     - Eigendecomposition (eigh) has no ONNX operator
   * - RiemannianSRNG + Grassmannian manifold
     - (config)
     - SVD-based geodesic distance has no ONNX operator
   * - KNN
     - 1
     - k-nearest-neighbor logic, not prototype-based
   * - NPC
     - 1
     - Different predict pattern
   * - Utility (KMeansPlusPlus, Kmeans, SOM, BGPC)
     - 4
     - Not prototype model base classes

Export with Reject Option
--------------------------

Supervised models can be exported with a built-in reject option. When ``reject_threshold`` is provided, the ONNX model outputs two tensors:

- ``predictions`` (INT64): class labels with ``-1`` for rejected samples
- ``confidence`` (FLOAT): confidence scores in ``[-1, 1]``

The confidence is the GLVQ relative margin:

.. math::

   \text{confidence}(x) = \frac{d^-(x) - d^+(x)}{d^-(x) + d^+(x)}

Samples with confidence below the threshold are rejected (prediction = -1).
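To make the margin formula concrete, here is a small NumPy sketch of the rejection rule. It assumes that :math:`d^+` is the distance to the nearest prototype of the predicted class and :math:`d^-` the distance to the nearest prototype of any other class; the source does not spell this out, and the helper ``reject_by_margin`` is a hypothetical name, not part of prosemble.

.. code-block:: python

   import numpy as np

   def reject_by_margin(d_plus, d_minus, threshold):
       """Relative-margin confidence and rejection mask.

       d_plus:  distances to the closest prototype of the predicted class (assumed)
       d_minus: distances to the closest prototype of any other class (assumed)
       Both distances being zero for a sample would divide by zero;
       this sketch does not guard against that case.
       """
       d_plus = np.asarray(d_plus, dtype=np.float32)
       d_minus = np.asarray(d_minus, dtype=np.float32)
       confidence = (d_minus - d_plus) / (d_minus + d_plus)
       rejected = confidence < threshold
       return confidence, rejected

   # Three samples: clearly separated, slightly separated, exactly on the boundary
   d_plus = [0.1, 0.9, 0.5]
   d_minus = [0.9, 1.1, 0.5]
   confidence, rejected = reject_by_margin(d_plus, d_minus, threshold=0.2)

With a threshold of ``0.2``, only the first sample is accepted; the other two have margins of roughly ``0.1`` and ``0.0`` and would receive the label ``-1``.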
.. code-block:: python

   import numpy as np
   import onnxruntime as ort

   from prosemble.models import GLVQ
   from prosemble.core.onnx_export import export_onnx

   model = GLVQ(n_prototypes_per_class=2, max_iter=100, lr=0.01)
   model.fit(X_train, y_train)

   # Export with reject option (threshold = 0.2)
   onnx_model = export_onnx(
       model,
       batch_size=-1,
       reject_threshold=0.2,
       path='glvq_reject.onnx',
   )

   # Run with ONNX Runtime (float32 NumPy input)
   session = ort.InferenceSession('glvq_reject.onnx')
   X_test_np = np.asarray(X_test, dtype=np.float32)
   results = session.run(None, {'X': X_test_np})
   predictions = results[0]  # class labels, or -1 for rejected samples
   confidence = results[1]   # confidence scores

   # Analyze
   y_test_np = np.asarray(y_test)
   rejected = predictions == -1
   print(f"Rejected: {rejected.sum()} / {len(predictions)}")
   print(f"Accuracy on accepted: {(predictions[~rejected] == y_test_np[~rejected]).mean():.3f}")

The reject option works with all supervised model distance types: Euclidean, relevance-weighted, omega-projected, local omega, tangent, and all differentiating kernel variants. It is not supported for one-class or unsupervised models (``export_onnx`` raises ``ValueError``).