快速开始

备注

阅读本篇前,请确保已按照 安装指南 准备好昇腾环境及 ONNX Runtime!

本教程以一个简单的 resnet50 模型为例,讲述如何在 Ascend NPU上使用 ONNX Runtime 进行模型推理。

环境准备

安装本教程所依赖的额外必要库。

1pip install numpy Pillow onnx

模型准备

ONNX Runtime 推理需要 ONNX 格式模型作为输入,目前有以下几种主流途径获得 ONNX 模型。

  1. ONNX Model Zoo 中下载模型。

  2. 从 torch、TensorFlow 等框架导出 ONNX 模型。

  3. 使用转换工具,完成其他类型到 ONNX 模型的转换。

本教程使用的 resnet50 模型是从 ONNX Model Zoo 中直接下载的,具体的 下载链接

类别标签

类别标签用于将输出权重转换成人类可读的类别信息,具体的 下载链接

模型推理

 1import onnxruntime as ort
 2import numpy as np
 3import onnx
 4from PIL import Image
 5
 6def preprocess(image_path):
 7    img = Image.open(image_path)
 8    img = img.resize((224, 224))
 9    img = np.array(img).astype(np.float32)
10
11    img = np.transpose(img, (2, 0, 1))
12    img = img / 255.0
13    mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
14    std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
15    img = (img - mean) / std
16    img = np.expand_dims(img, axis=0)
17    return img
18
19def inference(model_path, img):
20    options = ort.SessionOptions()
21    providers = [
22        (
23            "CANNExecutionProvider",
24            {
25                "device_id": 0,
26                "arena_extend_strategy": "kNextPowerOfTwo",
27                "npu_mem_limit": 2 * 1024 * 1024 * 1024,
28                "op_select_impl_mode": "high_performance",
29                "optypelist_for_implmode": "Gelu",
30                "enable_cann_graph": True
31            },
32        ),
33        "CPUExecutionProvider",
34    ]
35
36    session = ort.InferenceSession(model_path, sess_options=options, providers=providers)
37    input_name = session.get_inputs()[0].name
38    output_name = session.get_outputs()[0].name
39
40    result = session.run([output_name], {input_name: img})
41    return result
42
43def display(classes_path, result):
44    with open(classes_path) as f:
45        labels = [line.strip() for line in f.readlines()]
46
47    pred_idx = np.argmax(result)
48    print(f'Predicted class: {labels[pred_idx]} ({result[0][0][pred_idx]:.4f})')
49
50if __name__ == '__main__':
51    model_path = '~/model/resnet/resnet50.onnx'
52    image_path = '~/model/resnet/cat.jpg'
53    classes_path = '~/model/resnet/imagenet_classes.txt'
54
55    img = preprocess(image_path)
56    result = inference(model_path, img)
57    display(classes_path, result)