ONNX Runtime Training (ORT)

Prepare the environment

For this application example you will need a Python 3.11.2 environment with onnxruntime-training, numpy, torch, torchvision and daiedge-vlab installed. Use the step-by-step user guide to install and set up the daiedge-vlab package. Use the following requirements file (odt_requirements.txt) to install the other dependencies:

pip install -r odt_requirements.txt

Setup file

Create a setup.yaml file and fill in your credentials in order to be able to use the dAIEdgeVLabAPI.

api:
    url: "vlab.daiedge.eu"
    port: "443"
user : 
    email: "ABC@abc.ai"
    password: "XYZ"

Prepare the different datasets

From the MNIST dataset we create two sub-datasets:

  • A train dataset (used to train the model)
  • A test dataset (used to test the model)
from onnxruntime.training import artifacts
from daiedge_vlab import dAIEdgeVLabAPI, OnDeviceTrainingConfig
import numpy as np
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import onnx
import torch
import torch.nn.functional as F

############################################
# Configuration
############################################

# Credentials / endpoint file consumed by dAIEdgeVLabAPI.
SETUP_FILE = "setup.yaml"

# Benchmark selection: target board, runtime and the model archive to upload.
TARGET = "rpi5"
RUNTIME = "ort"
MODEL = "artifacts.zip"

# Number of MNIST samples kept for the train and test subsets.
TRAIN_NB_SAMPLES = 10_000
TEST_NB_SAMPLES = 2_000

# Directory where the benchmark results are saved.
RESULT_DIR = "./results"

# ############################################
# # Step 1: Load, Preprocess MNIST, and Save Data
# ############################################

def _load_subset(dataset, nb_samples):
    """Convert the first ``nb_samples`` items of ``dataset`` to numpy arrays.

    Only the requested samples are read, in a single pass, so the transform
    is not applied to the part of the dataset that would be discarded anyway.

    Args:
        dataset: Indexable dataset yielding ``(image, label)`` pairs.
        nb_samples: Maximum number of samples to keep.

    Returns:
        ``(images, labels)`` where ``images`` is float32 with a trailing
        channel axis and ``labels`` is int64.

    Raises:
        ValueError: If there is no sample to load.
    """
    count = min(nb_samples, len(dataset))
    if count <= 0:
        raise ValueError(
            f"No samples to load (requested {nb_samples}, "
            f"dataset holds {len(dataset)})."
        )

    samples = [dataset[index] for index in range(count)]

    # Drop the size-1 channel axis of each image, then re-append it as the
    # last axis. Because there is a single channel, the flat byte layout
    # written to disk is the same as for the [1, 28, 28] shape declared in
    # config.json.
    images = np.stack([np.array(img).squeeze() for img, _ in samples])
    images = images.astype(np.float32)[..., np.newaxis]

    labels = np.array([label for _, label in samples]).astype(np.int64)
    return images, labels


def prepare_datasets():
    """Download MNIST, keep a subset, and dump it as raw binary files.

    Pixel values are scaled to [0, 1] by ``transforms.ToTensor``; no further
    normalization is applied here. The subset sizes come from
    ``TRAIN_NB_SAMPLES`` and ``TEST_NB_SAMPLES``.

    Writes ``mnist_train_input.bin``, ``mnist_train_target.bin``,
    ``mnist_test_input.bin`` and ``mnist_test_target.bin`` in the current
    directory (float32 inputs, int64 targets).

    Returns:
        ``((x_train, y_train), (x_test, y_test))`` as numpy arrays.
    """
    # Load MNIST dataset.
    transform = transforms.ToTensor()
    train_dataset = datasets.MNIST(root=".", train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root=".", train=False, download=True, transform=transform)

    # Convert only the samples that are actually kept.
    x_train, y_train = _load_subset(train_dataset, TRAIN_NB_SAMPLES)
    x_test, y_test = _load_subset(test_dataset, TEST_NB_SAMPLES)

    # Save preprocessed data as binary files.
    x_train.tofile("mnist_train_input.bin")
    y_train.tofile("mnist_train_target.bin")
    x_test.tofile("mnist_test_input.bin")
    y_test.tofile("mnist_test_target.bin")

    print(len(x_train), len(y_train), len(x_test), len(y_test))
    print("Preprocessed train/test data saved as .bin files.")
    return (x_train, y_train), (x_test, y_test)


train_set, test_set = prepare_datasets()

Define the model architecture

The following code defines a simple model that will be used for on-device training. The model is a Convolutional Neural Network (CNN) suitable for image classification tasks like MNIST digit recognition.

############################################
# Step 2: Define a Simple PyTorch Model.
############################################

class CNNModel(nn.Module):
    """Small CNN for single-channel 28x28 images with 10 output classes.

    ``forward`` returns raw, unnormalized class scores (logits). The training
    artifacts are generated with ``LossType.CrossEntropyLoss``, which performs
    the softmax normalization itself, so applying softmax here as well would
    normalize twice and flatten the gradients. ``argmax`` over the logits
    gives the same predicted class as ``argmax`` over the probabilities.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)  # input channels = 1, output = 32
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64 * 5 * 5, 128)  # 28x28 → 26x26 → 13x13 → 11x11 → 5x5
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a ``[B, 1, 28, 28]`` batch to ``[B, 10]`` class logits."""
        x = F.relu(self.conv1(x))  # [B, 32, 26, 26]
        x = self.pool1(x)          # [B, 32, 13, 13]
        x = F.relu(self.conv2(x))  # [B, 64, 11, 11]
        x = self.pool2(x)          # [B, 64, 5, 5]
        x = self.flatten(x)        # [B, 64*5*5]
        x = F.relu(self.fc1(x))    # [B, 128]
        return self.fc2(x)         # [B, 10] logits, no softmax (see class docstring)


model = CNNModel()

Config.json

Create a config.json file specifying the information needed by the dAIEdgeVLabAPI.

{
    "learning_parameters" : {
      "batch_size": 4,
      "epochs": 3,
      "loss_function": "sparse_categorical_crossentropy"
    },
    "input": {
      "name": "input",
      "shape": [1, 28, 28],
      "dtype": "float32",
      "train_file": "mnist_train_input.bin",
      "test_file": "mnist_test_input.bin"
    },
    "output": {
      "name": "output",
      "shape": [],
      "dtype": "int64",
      "train_file": "mnist_train_target.bin",
      "test_file": "mnist_test_target.bin"
    }
}

Convert Torch Model to Onnx

ONNX Runtime Training needs the model in the ONNX format.

############################################
# Step 3: Convert model to ONNX
############################################

import os

# The export below and the later artifact generation both write into
# "temp/", which nothing else in this script creates.
os.makedirs("temp", exist_ok=True)

# Export with a dummy [1, 1, 28, 28] input; axis 0 of the input and output is
# declared dynamic so the exported graph accepts any batch size.
torch.onnx.export(
    model,
    torch.randn(1, 1, 28, 28),
    "temp/model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
)

Generate training artifacts

From the ONNX model, ORT generates artifacts, i.e. the files used for training. This is where the trainable weights, the optimizer and the loss function are specified.

############################################
# Step 4: Generate artifacts
############################################

onnx_model = onnx.load("temp/model.onnx")

# Names of all parameters stored in the graph.
initializer_names = {init.name for init in onnx_model.graph.initializer}

# Collect the parameters of every Conv / Gemm node as trainable. The first
# input of these nodes is the incoming activation; the remaining inputs are
# the weight and, when present, the bias. Slicing instead of unpacking three
# names keeps this working for nodes exported without a bias, and the
# initializer check skips anything that is not a stored parameter.
requires_grad = []
for node in reversed(onnx_model.graph.node):
    if node.op_type in ("Conv", "Gemm"):
        requires_grad.extend(
            name for name in node.input[1:] if name in initializer_names
        )

# Every remaining initializer is frozen.
trainable_names = set(requires_grad)
frozen_params = [
    param.name
    for param in onnx_model.graph.initializer
    if param.name not in trainable_names
]

# Write the training artifacts into "temp". "output" is requested as an
# additional output name alongside the AdamW optimizer and cross-entropy loss.
artifacts.generate_artifacts(
    onnx_model,
    optimizer=artifacts.OptimType.AdamW,
    loss=artifacts.LossType.CrossEntropyLoss,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    artifact_directory="temp",
    additional_output_names=["output"],
)

Zip artifacts

Since the VLab expects a single file as the model, zip all generated artifacts into an artifacts.zip file.

############################################
# Step 5: Zip the files
############################################
import zipfile

# Artifact files to bundle; each is stored in the archive under the same
# relative path it has on disk.
files_to_zip = [
    "temp/checkpoint",
    "temp/eval_model.onnx",
    "temp/optimizer_model.onnx",
    "temp/training_model.onnx",
]

with zipfile.ZipFile("artifacts.zip", mode="w") as archive:
    for artifact_path in files_to_zip:
        archive.write(artifact_path)

Start benchmarking using VLAB API

Once the required files (setup.yaml, config.json, the dataset .bin files and artifacts.zip) are available, start the benchmark using the dAIEdgeVLabAPI.

############################################
# Step 6: Upload the Model and Start the Benchmark.
############################################


# Log in to the dAIEdge VLab API using the credentials file named in the
# configuration section.
api = dAIEdgeVLabAPI(SETUP_FILE)

# Parse the configuration file for on-device training.
# Make sure to adjust the path to your config file.
config = OnDeviceTrainingConfig("./config.json")

# Start the benchmark with the specified target, runtime, model and config.
# The result is kept in `benchmark_id` rather than `id` so the builtin
# `id()` is not shadowed.
benchmark_id = api.startOdtBenchmark(
    TARGET,
    RUNTIME,
    MODEL,
    config,
)
print(f"Benchmark started with ID: {benchmark_id}")

# Wait for the benchmark to finish and save the results in RESULT_DIR.
r = api.waitBenchmarkResult(benchmark_id, save_path=RESULT_DIR, verbose=True)
print("Benchmark finished:", r)

Analyze the results

Plot the evolution of the losses.


############################################
# Step 7: Plot the metrics gathered during the on-device training.
############################################

from daiedge_vlab import dAIEdgeVLabAPI
import os
import numpy as np
import matplotlib.pyplot as plt

# Fill your benchmark ID (the value printed when the benchmark was started).
# A None placeholder keeps the script parseable until it is filled in.
BENCHMARK_ID = None

if BENCHMARK_ID is None:
    raise ValueError("BENCHMARK_ID is not set: fill in your benchmark ID.")

api = dAIEdgeVLabAPI("setup.yaml")
r = api.getBenchmarkResult(BENCHMARK_ID)

# Per-epoch records of the report.
train_records = r["report"]["loss_train"]["epochs"]
test_records = r["report"]["loss_test"]["epochs"]

# Each series is paired with its own epoch indices so the train and test
# curves stay correct even if the two lists differ.
epochs = [rec["epoch_index"] for rec in train_records]
tr_loss_vals = [rec["loss"] for rec in train_records]
tr_time_vals = [rec["time"] for rec in train_records]
test_epochs = [rec["epoch_index"] for rec in test_records]
te_loss_vals = [rec["loss"] for rec in test_records]
tr_avrge_time = np.mean(tr_time_vals)

# side-by-side figure
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 5), sharex=True)

# left panel - losses
axes[0].plot(epochs, tr_loss_vals, label="Train Loss")
axes[0].plot(test_epochs, te_loss_vals, label="Test Loss")
axes[0].set_xlabel("Epoch")
axes[0].set_ylabel("Loss")
axes[0].set_title("Loss over Epochs")
axes[0].legend()

# right panel - time per epoch, with the mean as a dashed reference line
axes[1].plot(epochs, tr_time_vals, label="Train Time", color="orange")
axes[1].axhline(tr_avrge_time, color="red", linestyle="--", label="Avg Time")
axes[1].set_xlabel("Epoch")
axes[1].set_ylabel("Time (s)")
axes[1].set_title("Training Time per Epoch")
axes[1].legend()

fig.tight_layout()
fig.savefig("training_metrics.png", dpi=150)

Retrieve the trained model and test it

The resulting trained model is saved in the “RESULT_DIR” folder specified when launching the benchmark. The name of the file is retrieved from the benchmark results. Run inference and calculate the accuracy using the trained model.

############################################
# Step 8: Restore the Model and Evaluate on Test Set.
############################################

from daiedge_vlab import dAIEdgeVLabAPI
import os
import numpy as np
from onnxruntime import InferenceSession

# Fill your benchmark ID (the value printed when the benchmark was started).
# A None placeholder keeps the script parseable until it is filled in.
BENCHMARK_ID = None

if BENCHMARK_ID is None:
    raise ValueError("BENCHMARK_ID is not set: fill in your benchmark ID.")

api = dAIEdgeVLabAPI("setup.yaml")
r = api.getBenchmarkResult(BENCHMARK_ID)

# The result names the trained model file; it is looked up in RESULT_DIR.
model_name = r['model_output']
model_path = os.path.join(RESULT_DIR, model_name)

test_input_file = "mnist_test_input.bin"
test_target_file = "mnist_test_target.bin"

# Verify files exist (the model too, so a missing download fails clearly).
for path in (model_path, test_input_file, test_target_file):
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Data file not found: {path}")

# Expected shapes and dtypes. The input shape carries a leading batch axis of
# 1 so that each element of x_test is directly a one-image batch.
in_shape = (1, 1, 28, 28)
out_shape = ()
in_dtype = np.dtype("float32")
out_dtype = np.dtype("int64")

print(f"Input shape: {in_shape}, dtype: {in_dtype}")
print(f"Output shape: {out_shape}, dtype: {out_dtype}")

# Load the raw binary files and reshape; (-1,) + () is simply (-1,).
x_test = np.fromfile(test_input_file, dtype=in_dtype).reshape((-1,) + in_shape)
y_test = np.fromfile(test_target_file, dtype=out_dtype).reshape((-1,) + out_shape)

if len(x_test) != len(y_test):
    raise ValueError(
        f"Input/target count mismatch: {len(x_test)} inputs, {len(y_test)} targets."
    )
if len(x_test) == 0:
    raise ValueError("The test set is empty; accuracy cannot be computed.")

session = InferenceSession(model_path)
input_name = session.get_inputs()[0].name

correct_predictions = 0
total_samples = 0

for image, label in zip(x_test, y_test):
    # `image` is already a float32 array of shape in_shape (batch of one).
    outputs = session.run(None, {input_name: image})

    # The predicted class is the index of the highest score of the first output.
    predicted_label = int(np.argmax(outputs[0], axis=1)[0])

    if predicted_label == label.item():
        correct_predictions += 1

    total_samples += 1

# Calculate accuracy
accuracy = correct_predictions / total_samples * 100
print(f"Accuracy calculated on {total_samples} images.")
print(f"Accuracy: {accuracy:.2f}%")