快速开始
===========

.. note::

   阅读本篇前,请确保已按照 :doc:`安装指南 <./install>` 准备好昇腾环境及 ONNX Runtime!

本教程以一个简单的 resnet50 模型为例,讲述如何在 Ascend NPU 上使用 ONNX Runtime 进行模型推理。

环境准备
-----------

安装本教程所依赖的额外必要库。

.. code-block:: shell
   :linenos:

   pip install numpy Pillow onnx

模型准备
-----------

ONNX Runtime 推理需要 ONNX 格式模型作为输入,目前有以下几种主流途径获得 ONNX 模型。

1. 从 `ONNX Model Zoo <https://github.com/onnx/models>`_ 中下载模型。
2. 从 torch、TensorFlow 等框架导出 ONNX 模型。
3. 使用转换工具,完成其他类型到 ONNX 模型的转换。

本教程使用的 resnet50 模型是从 ONNX Model Zoo 中直接下载的,具体见 `下载链接 `_。

.. TODO(review): the target URL of the link above appears to be missing; restore the resnet50 model download URL.

类别标签
-----------

类别标签用于将模型输出的类别索引转换成人类可读的类别名称,具体见 `下载链接 `_。

.. TODO(review): the target URL of the link above appears to be missing; restore the class-label file download URL.

模型推理
-----------

运行前,请将下面代码中的 ``model_path``、``image_path`` 和 ``classes_path`` 修改为实际的文件路径。

.. code-block:: python
   :linenos:

   import os

   import numpy as np
   import onnxruntime as ort
   from PIL import Image


   def preprocess(image_path):
       """Load an image and convert it into a (1, 3, 224, 224) float32 tensor.

       The image is converted to RGB, resized to 224x224, scaled to [0, 1] and
       normalized per channel with the mean/std constants below.
       """
       # convert('RGB') guarantees exactly three channels; grayscale or RGBA
       # images would otherwise break the transpose/normalization below.
       img = Image.open(image_path).convert('RGB')
       img = img.resize((224, 224))
       img = np.array(img).astype(np.float32)
       img = np.transpose(img, (2, 0, 1))  # HWC -> CHW
       img = img / 255.0
       # float32 constants keep the tensor float32: float64 mean/std would
       # upcast it to double, which ONNX Runtime rejects when the model input
       # is declared as float.
       mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(3, 1, 1)
       std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(3, 1, 1)
       img = (img - mean) / std
       img = np.expand_dims(img, axis=0)  # add the batch dimension -> NCHW
       return img


   def inference(model_path, img):
       """Run one forward pass and return the list produced by session.run.

       Execution providers are listed in priority order: CANNExecutionProvider
       (Ascend NPU) first, with CPUExecutionProvider as the fallback.
       """
       options = ort.SessionOptions()
       providers = [
           (
               "CANNExecutionProvider",
               {
                   "device_id": 0,
                   "arena_extend_strategy": "kNextPowerOfTwo",
                   "npu_mem_limit": 2 * 1024 * 1024 * 1024,  # 2 GiB
                   "op_select_impl_mode": "high_performance",
                   "optypelist_for_implmode": "Gelu",
                   "enable_cann_graph": True,
               },
           ),
           "CPUExecutionProvider",
       ]

       session = ort.InferenceSession(model_path, sess_options=options, providers=providers)

       input_name = session.get_inputs()[0].name
       output_name = session.get_outputs()[0].name

       result = session.run([output_name], {input_name: img})

       return result


   def display(classes_path, result):
       """Print the top-1 label and its raw score for a single-image batch.

       ``classes_path`` must hold one label per line, in the same order as
       the model's output scores.
       """
       with open(classes_path, encoding='utf-8') as f:
           labels = [line.strip() for line in f]

       scores = result[0][0]  # first requested output, first image in the batch
       pred_idx = int(np.argmax(scores))
       print(f'Predicted class: {labels[pred_idx]} ({scores[pred_idx]:.4f})')


   if __name__ == '__main__':
       # open(), PIL and ONNX Runtime do not expand "~", so expand it here.
       model_path = os.path.expanduser('~/model/resnet/resnet50.onnx')
       image_path = os.path.expanduser('~/model/resnet/cat.jpg')
       classes_path = os.path.expanduser('~/model/resnet/imagenet_classes.txt')

       img = preprocess(image_path)
       result = inference(model_path, img)
       display(classes_path, result)