向量相加 (Vector Addition)

在本节中,我们将使用 Triton 编写一个简单的向量相加的程序。 在此过程中,你会学习到:

  • Triton 的基本编程模式。

  • 用于定义 Triton 内核的triton.jit装饰器(decorator)。

计算内核:

import torch
import torch_npu

import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr,  # 指向第一个输入向量的指针。
               y_ptr,  # 指向第二个输入向量的指针。
               output_ptr,  # 指向输出向量的指针。
               n_elements,  # 向量的大小。
               BLOCK_SIZE: tl.constexpr,  # 每个程序应处理的元素数量。
               # 注意:`constexpr` 将标记变量为常量。
               ):
    # 不同的数据由不同的“process”来处理,因此需要分配:
    pid = tl.program_id(axis=0)  # 使用 1D 启动网格,因此轴为 0。
    # 该程序将处理相对初始数据偏移的输入。
    # 例如,如果有一个长度为 256, 块大小为 64 的向量,程序将各自访问 [0:64, 64:128, 128:192, 192:256] 的元素。
    # 注意 offsets 是指针列表:
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # 创建掩码以防止内存操作超出边界访问。
    mask = offsets < n_elements
    # 从 DRAM 加载 x 和 y,如果输入不是块大小的整数倍,则屏蔽掉任何多余的元素。
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # 将 x + y 写回 DRAM。
    tl.store(output_ptr + offsets, output, mask=mask)

创建一个辅助函数用于:

  • 生成 z 张量;

  • 用适当的 grid/block sizes 将上述内核加入队列。

def add(x: torch.Tensor, y: torch.Tensor):
    # 需要预分配输出。
    output = torch.empty_like(x)
    n_elements = output.numel()
    # 启动网格表示并行运行的内核实例的数量。
    # 可以是 Tuple[int],也可以是 Callable(metaparameters) -> Tuple[int]。
    # 在本case中,使用 1D 网格,其中大小是块的数量:
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    # NOTE:
    #  - 每个 torch.tensor 对象都会隐式转换为其第一个元素的指针。
    #  - `triton.jit` 函数可以通过启动网格索引来获得可调用的 NPU 内核。
    #  - 不要忘记以keywords的方式传递meta-parameters。
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    # 返回 z 的句柄。
    return output

使用上述函数计算两个 torch.tensor 对象的 element-wise sum,并测试其正确性:

torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='npu')
y = torch.rand(size, device='npu')
output_torch = x + y
output_triton = add(x, y)
print(output_torch)
print(output_triton)
print(f'The maximum difference between torch and triton is '
      f'{torch.max(torch.abs(output_torch - output_triton))}')

Out:

tensor([0.8329, 1.0024, 1.3639,  ..., 1.0796, 1.0406, 1.5811], device='npu:0')
tensor([0.8329, 1.0024, 1.3639,  ..., 1.0796, 1.0406, 1.5811], device='npu:0')
The maximum difference between torch and triton is 0.0

“The maximum difference between torch and triton is 0.0” 表示Triton和PyTorch的输出结果一致。