ONNX Runtime Training (ORT)

Prepare the environment

For this application example you will need a Python 3.11.2 environment with onnxruntime-training, numpy, torch, torchvision and daiedge-vlab installed. Use the step-by-step user guide to install and set up the daiedge-vlab package. Use the following odt_requirements.txt file to install the other dependencies:

pip install -r odt_requirements.txt
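
If you do not already have this file, a minimal odt_requirements.txt covering the packages listed above could look like the following (unpinned here for illustration; pin versions to match your setup):

numpy
torch
torchvision
onnxruntime-training
onnx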

Setup file

Create a setup.yaml file and fill in your credentials in order to be able to use the dAIEdgeVLabAPI.

api:
    url: "vlab.daiedge.eu"
    port: "443"
user:
    email: "ABC@abc.ai"
    password: "XYZ"

Prepare the different datasets

From the MNIST dataset we create two sub-datasets:

  • A train dataset (used to train the model)
  • A test dataset (used to test the model)

The following code downloads MNIST, preprocesses the images, and saves each split as flat binary files:
from onnxruntime.training import artifacts
from daiedge_vlab import dAIEdgeVLabAPI, OnDeviceTrainingConfig
import numpy as np
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import onnx
import torch
import torch.nn.functional as F

############################################
# Configuration
############################################

SETUP_FILE = "setup.yaml"

TARGET = 'rpi5'
RUNTIME = 'ort'
MODEL = 'artifacts.zip'

TRAIN_NB_SAMPLES = 10000
TEST_NB_SAMPLES = 2000

RESULT_DIR = "./results"

############################################
# Step 1: Load, Preprocess MNIST, and Save Data
############################################

def prepare_datasets():
    # Load MNIST dataset.
    transform = transforms.ToTensor()
    train_dataset = datasets.MNIST(root=".", train=True, download=True, transform=transform)
    test_dataset  = datasets.MNIST(root=".", train=False, download=True, transform=transform)

    # Convert datasets to numpy arrays
    x_train = np.stack([np.array(img).squeeze() for img, _ in train_dataset])
    y_train = np.array([label for _, label in train_dataset])
    x_test  = np.stack([np.array(img).squeeze() for img, _ in test_dataset])
    y_test  = np.array([label for _, label in test_dataset])

    # Preprocess: cast images to float32 (ToTensor already scaled them to [0,1]) and add a channel dimension.
    x_train = (x_train.astype(np.float32))[..., np.newaxis]
    x_test  = (x_test.astype(np.float32))[..., np.newaxis]

    # Reduce the size of the dataset for faster training (optional).
    x_train = x_train[:TRAIN_NB_SAMPLES]  
    y_train = y_train[:TRAIN_NB_SAMPLES]
    x_test  = x_test[:TEST_NB_SAMPLES]    
    y_test  = y_test[:TEST_NB_SAMPLES]

    # Preprocess labels.
    y_train = y_train.astype(np.int64)
    y_test = y_test.astype(np.int64)

    # Save preprocessed data as binary files.
    x_train.tofile("mnist_train_input.bin")
    y_train.tofile("mnist_train_target.bin")
    x_test.tofile("mnist_test_input.bin")
    y_test.tofile("mnist_test_target.bin")

    print(len(x_train), len(y_train), len(x_test), len(y_test))
    print("Preprocessed train/test data saved as .bin files.")
    return (x_train, y_train), (x_test, y_test)

train_set, test_set = prepare_datasets()
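
As an optional sanity check, the flat binaries can be reloaded with numpy and reshaped; the dtypes must match what was written above:

# Optional sanity check: reload the flat binaries and restore their shapes.
x_check = np.fromfile("mnist_train_input.bin", dtype=np.float32).reshape(-1, 28, 28, 1)
y_check = np.fromfile("mnist_train_target.bin", dtype=np.int64)
assert x_check.shape[0] == y_check.shape[0] == TRAIN_NB_SAMPLES
print(x_check.shape, y_check.shape, x_check.min(), x_check.max())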

Define the model architecture

The following code defines a simple model that will be used for on-device training: a Convolutional Neural Network (CNN) suited to image classification tasks such as MNIST digit recognition.

############################################
# Step 2: Define a Simple CNN Model
############################################

class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)  # input channels = 1, output = 32
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64 * 5 * 5, 128)  # 28x28 → 26x26 → 13x13 → 11x11 → 5x5
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))  # [B, 32, 26, 26]
        x = self.pool1(x)          # [B, 32, 13, 13]
        x = F.relu(self.conv2(x))  # [B, 64, 11, 11]
        x = self.pool2(x)          # [B, 64, 5, 5]
        x = self.flatten(x)        # [B, 64*5*5]
        x = F.relu(self.fc1(x))    # [B, 128]
        x = self.fc2(x)            # [B, 10] raw logits: the CrossEntropyLoss used below applies softmax internally
        return x
    
model = CNNModel()
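
An optional forward pass with a dummy batch confirms that the layer dimensions in the comments above line up:

# Optional: verify the output shape with a dummy MNIST-sized batch.
dummy = torch.randn(4, 1, 28, 28)
print(model(dummy).shape)  # expected: torch.Size([4, 10])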

Config.json

Create a config.json file specifying the information needed by the dAIEdgeVLabAPI.

{
    "learning_parameters" : {
      "batch_size": 4,
      "epochs": 3,
      "loss_function": "sparse_categorical_crossentropy"
    },
    "input": {
      "name": "input",
      "shape": [1, 28, 28],
      "dtype": "float32",
      "train_file": "mnist_train_input.bin",
      "test_file": "mnist_test_input.bin"
    },
    "output": {
      "name": "output",
      "shape": [],
      "dtype": "int32",
      "train_file": "mnist_train_target.bin",
      "test_file": "mnist_test_target.bin"
    }
}
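
Here shape and dtype describe a single sample of the binary files. A small optional check confirms that the declared sample size is consistent with the files written earlier (sample_bytes is just an illustrative helper name):

import os
import numpy as np

# Bytes per sample implied by the config: 1*28*28 float32 values.
sample_bytes = int(np.prod([1, 28, 28])) * np.dtype("float32").itemsize
n_samples = os.path.getsize("mnist_train_input.bin") // sample_bytes
print("train samples on disk:", n_samples)  # expected: TRAIN_NB_SAMPLES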

Convert the Torch Model to ONNX

ONNX Runtime Training requires the model to be in the ONNX format.

############################################
# Step 3: Convert model to ONNX
############################################

import os

# Make sure the output directory exists before exporting.
os.makedirs("temp", exist_ok=True)

torch.onnx.export(
    model,
    torch.randn(1, 1, 28, 28),
    "temp/model.onnx",
    input_names=["input"], output_names=["output"],
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}})
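
It is worth validating the exported graph before generating the training artifacts. The optional sketch below checks the model and compares one forward pass between PyTorch and ONNX Runtime, assuming the onnxruntime inference API is available from your onnxruntime-training install:

import onnxruntime

# Structural check of the exported graph.
onnx.checker.check_model(onnx.load("temp/model.onnx"))

# Numerical parity check between PyTorch and ONNX Runtime.
session = onnxruntime.InferenceSession("temp/model.onnx", providers=["CPUExecutionProvider"])
x = torch.randn(2, 1, 28, 28)
ort_out = session.run(["output"], {"input": x.numpy()})[0]
torch_out = model(x).detach().numpy()
print("max abs diff:", np.abs(ort_out - torch_out).max())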

Generate training artifacts

From the ONNX model, ORT generates artifacts: the files used for training. This is where the trainable weights, the optimizer and the loss function are specified.

############################################
# Step 4: Generate artifacts
############################################

requires_grad = []
onnx_model = onnx.load("temp/model.onnx")

# Collect the weight and bias initializers of every Conv and Gemm node:
# these are the parameters that remain trainable on the device.
for node in reversed(onnx_model.graph.node):
    if node.op_type in ("Conv", "Gemm"):
        _, second, third = node.input
        requires_grad.extend([second, third])

# Every other initializer is frozen.
frozen_params = [
    param.name
    for param in onnx_model.graph.initializer
    if param.name not in requires_grad
]

artifacts.generate_artifacts(
    onnx_model,
    optimizer=artifacts.OptimType.AdamW,
    loss=artifacts.LossType.CrossEntropyLoss,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    artifact_directory="temp",
    additional_output_names=["output"])
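
Before sending the artifacts to the VLab, they can be exercised locally with onnxruntime's on-device training API. The sketch below is a minimal optional smoke test, assuming the generated file names shown in the next step; out should contain the loss together with the additional "output" tensor requested above:

from onnxruntime.training.api import CheckpointState, Module, Optimizer

# Load the generated artifacts.
state = CheckpointState.load_checkpoint("temp/checkpoint")
module = Module("temp/training_model.onnx", state, "temp/eval_model.onnx")
optimizer = Optimizer("temp/optimizer_model.onnx", module)

# One training step on a small batch (NHWC binaries reshape to NCHW since C = 1).
x = train_set[0][:4].reshape(4, 1, 28, 28)
y = train_set[1][:4]
module.train()
out = module(x, y)
optimizer.step()
module.lazy_reset_grad()
print(out)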

Zip artifacts

Since the VLab expects a single file as the model, zip all the generated artifacts into an artifacts.zip file.

############################################
# Step 5: Zip the files
############################################
import os
import zipfile

files_to_zip = ["temp/checkpoint", "temp/eval_model.onnx", "temp/optimizer_model.onnx", "temp/training_model.onnx"]

with zipfile.ZipFile("artifacts.zip", "w") as zipf:
    for file in files_to_zip:
        # Store each entry at the archive root (drop the temp/ prefix).
        zipf.write(file, arcname=os.path.basename(file))
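
An optional quick listing confirms that the four artifacts ended up at the root of the archive:

# Optional: verify the archive contents before uploading.
with zipfile.ZipFile("artifacts.zip") as zipf:
    print(zipf.namelist())  # expected: the four artifact files, no temp/ prefix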

Start benchmarking using VLAB API

Once the required files (setup.yaml, config.json, the dataset binaries and artifacts.zip) are available, start the benchmark using the dAIEdgeVLabAPI.

############################################
# Step 6: Upload the Model and Start the Benchmark.
############################################


# Log in to the dAIEdge VLab API.
api = dAIEdgeVLabAPI("setup.yaml")

# Parse the configuration file for on-device training.
# Make sure to adjust the path to your config file.
config = OnDeviceTrainingConfig("./config.json")

# Start the benchmark with the specified target, runtime, dataset and model.
benchmark_id = api.startOdtBenchmark(
    TARGET,
    RUNTIME,
    MODEL,
    config
)

print(f"Benchmark started with ID: {benchmark_id}")

# Wait for the benchmark to finish and save the results.
r = api.waitBenchmarkResult(benchmark_id, save_path=RESULT_DIR, verbose=True)
print("Benchmark finished:", r)

The benchmark results are also available through the VLab web interface in the history section.