跳到主要内容

ONNX 小模型

概念

要使用PyTorch在MNIST数据集上训练一个模型并导出为ONNX格式,可以按照以下步骤进行。我们将创建一个简单的卷积神经网络(CNN),在MNIST数据集上训练,并最终将模型导出为ONNX格式。

安装依赖

确保你安装了以下依赖:

pip install torch torchvision onnx
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

定义简单的CNN模型

class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
self.fc1 = nn.Linear(64 * 7 * 7, 128)
self.fc2 = nn.Linear(128, 10)
self.pool = nn.MaxPool2d(2, 2)
self.relu = nn.ReLU()

def forward(self, x):
x = self.pool(self.relu(self.conv1(x)))
x = self.pool(self.relu(self.conv2(x)))
x = x.view(-1, 64 * 7 * 7) # 展平张量以进行全连接层处理
x = self.relu(self.fc1(x))
x = self.fc2(x)
return x

加载MNIST数据集

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

实例化模型、定义损失函数和优化器

model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

训练模型

def train_model(model, train_loader, criterion, optimizer, num_epochs=5):
model.train()
for epoch in range(num_epochs):
running_loss = 0.0
for inputs, labels in train_loader:
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f'Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}')

开始训练

train_model(model, train_loader, criterion, optimizer, num_epochs=5)

导出模型为ONNX格式

dummy_input = torch.randn(1, 1, 28, 28)  # MNIST图像尺寸是1x28x28
onnx_path = "mnist_model.onnx"
torch.onnx.export(model, dummy_input, onnx_path, verbose=True, input_names=['input'], output_names=['output'])

print(f"ONNX模型已导出到 {onnx_path}")
  • SimpleCNN:定义了一个简单的卷积神经网络,包含两层卷积层和两层全连接层。
  • train_model:用于训练模型,使用交叉熵损失函数和Adam优化器。
  • torch.onnx.export:在训练完模型后,我们用一个假输入(大小为1x1x28x28的随机张量)来将模型导出为ONNX格式。input_names 和 output_names 用于指定ONNX模型的输入和输出的名字。

准备 Triton Server 环境

首先,确保你的系统上已经安装了 Triton Server,并且支持 ONNX 模型的推理。如果没有,按照官方文档进行安装。 在 Triton 中,每个模型需要放置在一个特定的文件夹结构中。假设模型名称为 mnist_model,模型结构应该如下:

models/
└── mnist_model/
├── 1/
│ └── model.onnx
└── config.pbtxt

1/ 文件夹代表模型的版本号(可以是任意数字,代表不同版本的模型)。 model.onnx 是从 PyTorch 导出的 ONNX 模型文件。 config.pbtxt 是 Triton 模型的配置文件。

配置文件 config.pbtxt

创建模型的配置文件 config.pbtxt,内容如下:

name: "mnist_model"
platform: "onnxruntime_onnx"
max_batch_size: 64
input: [
{
name: "input"
data_type: TYPE_FP32
dims: [ 1, 28, 28 ]
}
]
output: [
{
name: "output"
data_type: TYPE_FP32
dims: [ 10 ]
}
]

name: 模型名称,保持与文件夹名称一致。 input 和 output: 定义输入、输出的名字、数据类型以及维度,确保与模型匹配。

启动 Triton Inference Server

假设 Triton Server 已安装,可以通过以下命令启动,并指定 models 文件夹:

tritonserver --model-repository=./models

Triton Server 会自动加载 models 文件夹中的模型。

用 Python 进行 API 调用推理

我们可以通过 Triton 的 HTTP/REST API 进行模型推理。首先,安装 Triton 的客户端库:

pip install tritonclient[http]

接下来,编写 Python 代码来发送推理请求:

import numpy as np
import tritonclient.http as httpclient
from PIL import Image
import torchvision.transforms as transforms

初始化Triton客户端

triton_client = httpclient.InferenceServerClient(url="localhost:8000")

预处理函数:将图像转换为模型所需格式

def preprocess_image(image_path):
image = Image.open(image_path).convert('L') # 转换为灰度图像
transform = transforms.Compose([
transforms.Resize((28, 28)),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
image = transform(image).unsqueeze(0) # 增加批处理维度 (1, 1, 28, 28)
return image.numpy()

读取并预处理图像

image_path = "path_to_mnist_image.png"
input_data = preprocess_image(image_path)

设置输入输出

inputs = httpclient.InferInput("input", input_data.shape, "FP32")
inputs.set_data_from_numpy(input_data)

outputs = httpclient.InferRequestedOutput("output")

发送推理请求

response = triton_client.infer(model_name="mnist_model", inputs=[inputs], outputs=[outputs])

获取并解析推理结果

output_data = response.as_numpy("output")
predicted_class = np.argmax(output_data)
print(f"Predicted class: {predicted_class}")

解释

  • Triton客户端: 使用 tritonclient.http.InferenceServerClient 与 Triton Inference Server 进行通信。
  • 图像预处理: 使用 PIL 和 torchvision.transforms 进行图像处理,将输入图像转换为模型所需的 [1, 1, 28, 28] 形状的 NumPy 数组。
  • 推理请求: 构建输入 (InferInput),并通过 infer 函数发送请求。服务器返回的输出数据用 InferRequestedOutput 获取。
  • 结果解析: 使用 np.argmax 从输出向量中获取预测的数字类别。