# 自动调优 (Autotune) 在本节中,我们将展示使用 Triton 的 autotune 方法以自动选择最优的 kernel 配置参数。 说明: 当前Triton-Ascend autotune支持block size、multibuffer,未来还会持续增加autotune可调项。与社区相比,Triton-Ascend支持bishengir在MLIR编译过程中失败后继续测试下一组config,而非中断autotune。 ```Python import torch, torch_npu import triton import triton.language as tl def test_triton_autotune(): # 返回一组不同的 kernel 配置,用于 autotune 测试 def get_autotune_config(): return [ triton.Config({'XS': 1 * 128, 'multibuffer': True}), triton.Config({'XS': 12 * 1024, 'multibuffer': True}), triton.Config({'XS': 12 * 1024, 'multibuffer': False}), triton.Config({'XS': 8 * 1024, 'multibuffer': True}), ] # 使用 @autotune 装饰器自动选择最优的 kernel 配置 @triton.autotune( configs=get_autotune_config(), # 配置列表 key=['XS', 'multibuffer'], # 选择输入变量作为选择依据 ) @triton.jit def triton_calc_kernel( out_ptr0, in_ptr0, in_ptr1, numel, XS: tl.constexpr # 块大小,用于控制每个线程块处理多少数据 ): pid = tl.program_id(0) # 获取当前 program 的 ID idx = pid * XS + tl.arange(0, XS) # 当前线程块处理的 index 范围 msk = idx < numel # 避免越界的掩码 # 重复执行一些计算以模拟负载(并测试性能)/ Repeat computation to simulate load (for perf test) for i in range(10000): tmp0 = tl.load(in_ptr0 + idx, mask=msk, other=0.0) # 加载 x0 tmp1 = tl.load(in_ptr1 + idx, mask=msk, other=0.0) # 加载 x1 tmp2 = tl.math.exp(tmp0) + tmp1 + i # 计算 tl.store(out_ptr0 + idx, tmp2, mask=msk) # 存储到输出 # Triton 调用函数,自动使用 autotuned kernel def triton_calc_func(x0, x1): n = x0.numel() y0 = torch.empty_like(x0) grid = lambda meta: (triton.cdiv(n, meta["XS"]), 1, 1) # 计算 grid 大小 triton_calc_kernel[grid](y0, x0, x1, n) return y0 # 使用 PyTorch 作为参考实现进行对比 def torch_calc_func(x0, x1): return torch.exp(x0) + x1 + 10000 - 1 DEV = "npu" # 使用 NPU 作为设备 DTYPE = torch.float32 N = 192 * 1024 # 输入长度 x0 = torch.randn((N,), dtype=DTYPE, device=DEV) # 随机输入 x0 x1 = torch.randn((N,), dtype=DTYPE, device=DEV) # 随机输入 x1 torch_ref = torch_calc_func(x0, x1) # 得到参考结果 triton_cal = triton_calc_func(x0, x1) # 运行 Triton kernel torch.testing.assert_close(triton_cal, torch_ref) # 验证输出是否一致 if __name__ == "__main__": test_triton_autotune() print("success: test_triton_autotune") # 输出成功标志 / Print success message ```