ONNX 小模型
概念
要使用PyTorch在MNIST数据集上训练一个模型并导出为ONNX格式,可以按照以下步骤进行。我们将创建一个简单的卷积神经网络(CNN),在MNIST数据集上训练,并最终将模型导出为ONNX格式。
安装依赖
确保你安装了以下依赖:
pip install torch torchvision onnx
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
定义简单的CNN模型
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
self.fc1 = nn.Linear(64 * 7 * 7, 128)
self.fc2 = nn.Linear(128, 10)
self.pool = nn.MaxPool2d(2, 2)
self.relu = nn.ReLU()
def forward(self, x):
x = self.pool(self.relu(self.conv1(x)))
x = self.pool(self.relu(self.conv2(x)))
x = x.view(-1, 64 * 7 * 7) # 展平张量以进行全连接层处理
x = self.relu(self.fc1(x))
x = self.fc2(x)
return x
加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
实例化模型、定义损失函数和优化器
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
训练模型
def train_model(model, train_loader, criterion, optimizer, num_epochs=5):
model.train()
for epoch in range(num_epochs):
running_loss = 0.0
for inputs, labels in train_loader:
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f'Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}')
开始训练
train_model(model, train_loader, criterion, optimizer, num_epochs=5)
导出模型为ONNX格式
dummy_input = torch.randn(1, 1, 28, 28) # MNIST图像尺寸是1x28x28
onnx_path = "mnist_model.onnx"
torch.onnx.export(model, dummy_input, onnx_path, verbose=True, input_names=['input'], output_names=['output'])
print(f"ONNX模型已导出到 {onnx_path}")
- SimpleCNN:定义了一个简单的卷积神经网络,包含两层卷积层和两层全连接层。
- train_model:用于训练模型,使用交叉熵损失函数和Adam优化器。
- torch.onnx.export:在 训练完模型后,我们用一个假输入(大小为1x1x28x28的随机张量)来将模型导出为ONNX格式。input_names 和 output_names 用于指定ONNX模型的输入和输出的名字。
准备 Triton Server 环境
首先,确保你的系统上已经安装了 Triton Server,并且支持 ONNX 模型的推理。如果没有,按照官方文档进行安装。 在 Triton 中,每个模型需要放置在一个特定的文件夹结构中。假设模型名称为 mnist_model,模型结构应该如下:
models/
└── mnist_model/
├── 1/
│ └── model.onnx
└── config.pbtxt
1/ 文件夹代表模型的版本号(可以是任意数字,代表不同版本的模型)。 model.onnx 是从 PyTorch 导出的 ONNX 模型文件。 config.pbtxt 是 Triton 模型的配置文件。
配置文件 config.pbtxt
创建模型的配置文件 config.pbtxt,内容如下:
name: "mnist_model"
platform: "onnxruntime_onnx"
max_batch_size: 64
input: [
{
name: "input"
data_type: TYPE_FP32
dims: [ 1, 28, 28 ]
}
]
output: [
{
name: "output"
data_type: TYPE_FP32
dims: [ 10 ]
}
]
name: 模型名称,保持与文件夹名称一致。 input 和 output: 定义输入、输出的名字、数据类型以及维度,确保与模型匹配。
启动 Triton Inference Server
假设 Triton Server 已安装,可以通过以下命令启动,并指定 models 文件夹:
tritonserver --model-repository=./models
Triton Server 会自动加载 models 文件夹中的模型。
用 Python 进行 API 调用推理
我们可以通过 Triton 的 HTTP/REST API 进行模型推理。首先,安装 Triton 的客户端库:
pip install tritonclient[http]
接下来,编写 Python 代码来发送推理请求:
import numpy as np
import tritonclient.http as httpclient
from PIL import Image
import torchvision.transforms as transforms
初始化Triton客户端
triton_client = httpclient.InferenceServerClient(url="localhost:8000")
预处理函数:将图像转换为模型所需格式
def preprocess_image(image_path):
image = Image.open(image_path).convert('L') # 转换为灰度图像
transform = transforms.Compose([
transforms.Resize((28, 28)),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
image = transform(image).unsqueeze(0) # 增加批处理维度 (1, 1, 28, 28)
return image.numpy()
读取并预处理图像
image_path = "path_to_mnist_image.png"
input_data = preprocess_image(image_path)
设置输入输出
inputs = httpclient.InferInput("input", input_data.shape, "FP32")
inputs.set_data_from_numpy(input_data)
outputs = httpclient.InferRequestedOutput("output")
发送推理请求
response = triton_client.infer(model_name="mnist_model", inputs=[inputs], outputs=[outputs])
获取并解析推理结果
output_data = response.as_numpy("output")
predicted_class = np.argmax(output_data)
print(f"Predicted class: {predicted_class}")
解释
- Triton客户端: 使用 tritonclient.http.InferenceServerClient 与 Triton Inference Server 进行通信。
- 图像预处理: 使用 PIL 和 torchvision.transforms 进行图像处理,将输入图像转换为模型所需的 [1, 1, 28, 28] 形状的 NumPy 数组。
- 推理请求: 构建输入 (InferInput),并通过 infer 函数发送请求。服务器返回的输出数据用 InferRequestedOutput 获取。
- 结果解析: 使用 np.argmax 从输出向量中获取预测的数字类别。