Layer Normalization

In this section, we will use Triton to write a high-performance layer normalization kernel that runs faster than the PyTorch implementation.
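
For reference, layer normalization normalizes each row of the input over its last dimension and then applies a per-feature affine transform: y = (x - mean) / sqrt(var + eps) * w + b. Below is a minimal PyTorch sketch of that computation, which the fused kernel in this section performs in a single launch; the function name layer_norm_reference is illustrative and not part of the kernel code.

import torch

def layer_norm_reference(x, weight, bias, eps=1e-5):
    # Normalize each row over the last dimension (population variance)
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    x_hat = (x - mean) / torch.sqrt(var + eps)
    # Per-feature scale and shift
    return x_hat * weight + bias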

Compute Kernel

import torch
import triton
import triton.language as tl
import torch_npu

@triton.jit
def _layer_norm_fwd_fused(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride,  # how much to increase the pointer when moving by 1 row
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute
    row = tl.program_id(0)
    Y += row * stride
    X += row * stride
    # Compute mean
    mean = 0
    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
        _mean += a
    mean = tl.sum(_mean, axis=0) / N
    # Compute variance
    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
        x = tl.where(cols < N, x - mean, 0.)
        _var += x * x
    var = tl.sum(_var, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    # Write mean / rstd
    tl.store(Mean + row, mean)
    tl.store(Rstd + row, rstd)
    # Normalize and apply the linear transformation
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        w = tl.load(W + cols, mask=mask)
        b = tl.load(B + cols, mask=mask)
        x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
        x_hat = (x - mean) * rstd
        y = x_hat * w + b
        # Write output
        tl.store(Y + cols, y, mask=mask)

A custom LayerNorm implementation using Triton

@torch.inference_mode()
def layer_norm(x, normalized_shape, weight, bias, eps=1e-5):
    # Allocate an output tensor with the same shape and dtype as the input
    y = torch.empty_like(x)
    
    # Flatten x to 2D of shape [-1, feature_dim] so the last dimension is normalized row by row
    x_arg = x.reshape(-1, x.shape[-1])
    M, N = x_arg.shape
    
    mean = torch.empty((M, ), dtype=torch.float32, device=x.device)
    rstd = torch.empty((M, ), dtype=torch.float32, device=x.device)

    BLOCK_SIZE = 1024

    # Enqueue the kernel with a 1D launch grid of M programs, one per row
    _layer_norm_fwd_fused[(M, )](
        x_arg, y, weight, bias, mean, rstd,  # inputs, outputs and intermediate buffers
        x_arg.stride(0), N, eps,
        BLOCK_SIZE=BLOCK_SIZE)
    # Return the normalized output
    return y
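
# Note (illustrative sketch, not part of the original tutorial): the wrapper above
# hardcodes BLOCK_SIZE = 1024. Since tl.arange requires a power-of-2 length, an
# alternative is to derive the block size from N with triton.next_power_of_2 and
# cap it so intermediate buffers still fit in on-chip memory; MAX_FUSED_SIZE below
# is an assumed cap, not a value taken from this tutorial.
def choose_block_size(N, MAX_FUSED_SIZE=8192):
    # Round N up to the next power of 2, then cap the result
    return min(MAX_FUSED_SIZE, triton.next_power_of_2(N))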

# Forward-pass test: run the custom layer norm and compare it with the PyTorch reference
def _layer_norm(M, N, dtype, eps=1e-5, device='npu'):
    # Create data
    x_shape = (M, N)
    w_shape = (x_shape[-1], )
    weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)
    bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device)
    x.requires_grad_(True)
    # Forward pass
    y_tri = layer_norm(x, w_shape, weight, bias, eps)
    y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)
    # Check that the Triton output matches the PyTorch reference within tolerance
    assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0)
    print(f"y_tri: {y_tri}")
    print(f"y_ref: {y_ref}")
    print(f"Layer Normalization {M},{N} {dtype} PASSED!")

# Run the tests
if __name__ == '__main__':
    _layer_norm(128, 128, torch.float16)
    _layer_norm(128, 128, torch.bfloat16)
    _layer_norm(128, 128, torch.float32)
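
    # Usage note (illustrative addition, not from the original tutorial): because
    # layer_norm reshapes its input to [-1, x.shape[-1]] before launching the
    # kernel, it also works directly on higher-rank activations such as
    # (batch, seq_len, hidden). The shapes below are arbitrary examples.
    x3d = torch.randn(4, 64, 128, dtype=torch.float16, device='npu')
    w3d = torch.rand(128, dtype=torch.float16, device='npu')
    b3d = torch.rand(128, dtype=torch.float16, device='npu')
    y3d = layer_norm(x3d, (128,), w3d, b3d)  # internally flattened to (4 * 64, 128)
    assert y3d.shape == x3d.shape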

Results

y_tri: tensor([[ 0.2512,  0.0647,  0.8389,  ...,  2.3652,  1.5039,  1.1904],
        [ 1.0908,  1.5391,  0.2269,  ...,  1.6846,  1.0996,  0.9614],
        [-0.2974,  0.5918,  0.3225,  ...,  2.2891, -0.8418,  0.6885],
        ...,
        [ 0.5225, -0.0068,  0.4968,  ..., -1.1221,  1.7422,  0.6143],
        [ 0.4463,  1.2441,  0.2224,  ...,  2.2969, -0.3311,  0.6177],
        [-0.0113,  0.8423,  0.3696,  ...,  1.3838,  1.2471,  0.8750]],
       device='npu:0', dtype=torch.float16)
y_ref: tensor([[ 0.2512,  0.0647,  0.8389,  ...,  2.3652,  1.5039,  1.1904],
        [ 1.0908,  1.5391,  0.2269,  ...,  1.6846,  1.0996,  0.9614],
        [-0.2974,  0.5918,  0.3225,  ...,  2.2891, -0.8418,  0.6885],
        ...,
        [ 0.5225, -0.0068,  0.4968,  ..., -1.1221,  1.7422,  0.6143],
        [ 0.4463,  1.2441,  0.2224,  ...,  2.2969, -0.3311,  0.6177],
        [-0.0113,  0.8423,  0.3696,  ...,  1.3838,  1.2471,  0.8750]],
       device='npu:0', dtype=torch.float16, grad_fn=<NativeLayerNormBackward0>)
Layer Normalization 128,128 torch.float16 PASSED!
y_tri: tensor([[-0.4180,  0.9648,  0.8633,  ...,  0.7656,  0.8438,  0.3633],
        [ 0.4453,  0.5352,  0.9102,  ...,  1.1875, -0.0562,  0.5391],
        [ 1.3125,  0.9961,  0.9219,  ...,  0.9688,  0.0025,  0.5156],
        ...,
        [-0.1426,  0.6289,  0.9609,  ...,  0.9648, -0.1260, -0.1270],
        [ 1.1641,  0.6680,  0.8281,  ...,  0.9258,  0.9062,  0.1768],
        [-0.2129,  0.7109,  0.9141,  ...,  0.7891, -0.0767,  0.5156]],
       device='npu:0', dtype=torch.bfloat16)
y_ref: tensor([[-0.4180,  0.9648,  0.8633,  ...,  0.7656,  0.8438,  0.3633],
        [ 0.4453,  0.5352,  0.9102,  ...,  1.1875, -0.0562,  0.5391],
        [ 1.3125,  0.9961,  0.9219,  ...,  0.9688,  0.0025,  0.5156],
        ...,
        [-0.1426,  0.6289,  0.9609,  ...,  0.9648, -0.1260, -0.1270],
        [ 1.1641,  0.6680,  0.8281,  ...,  0.9258,  0.9062,  0.1768],
        [-0.2129,  0.7109,  0.9141,  ...,  0.7891, -0.0767,  0.5156]],
       device='npu:0', dtype=torch.bfloat16, grad_fn=<NativeLayerNormBackward0>)
Layer Normalization 128,128 torch.bfloat16 PASSED!
y_tri: tensor([[-0.2980,  0.2922,  0.6481,  ...,  0.9786,  0.7304,  0.8982],
        [ 1.5911,  0.0474,  0.6518,  ...,  0.8013,  0.2435,  1.3748],
        [ 1.3024,  0.6265,  0.6473,  ...,  0.8423,  0.0984, -1.1839],
        ...,
        [-0.2195,  0.1359,  0.6461,  ...,  0.8319,  1.0899,  1.5015],
        [ 0.6371,  0.3687,  0.6530,  ...,  0.9359,  0.0818,  0.6499],
        [ 0.1178,  0.3639,  0.6475,  ...,  0.7221,  0.4622,  1.4510]],
       device='npu:0')
y_ref: tensor([[-0.2980,  0.2922,  0.6481,  ...,  0.9786,  0.7304,  0.8982],
        [ 1.5911,  0.0474,  0.6518,  ...,  0.8013,  0.2435,  1.3748],
        [ 1.3024,  0.6265,  0.6473,  ...,  0.8423,  0.0984, -1.1839],
        ...,
        [-0.2195,  0.1359,  0.6461,  ...,  0.8319,  1.0899,  1.5015],
        [ 0.6371,  0.3687,  0.6530,  ...,  0.9359,  0.0818,  0.6499],
        [ 0.1178,  0.3639,  0.6475,  ...,  0.7221,  0.4622,  1.4510]],
       device='npu:0', grad_fn=<NativeLayerNormBackward0>)
Layer Normalization 128,128 torch.float32 PASSED!

The three lines "Layer Normalization 128,128 torch.float16 PASSED!", "Layer Normalization 128,128 torch.bfloat16 PASSED!" and "Layer Normalization 128,128 torch.float32 PASSED!" show that, for float16, bfloat16 and float32, the output of the Triton kernel matches the PyTorch reference within the test's tolerance (atol=1e-2).
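
The test above only checks correctness; the speed claim at the top of the section still needs to be measured. A minimal benchmarking sketch using triton.testing.do_bench is given below, assuming the layer_norm wrapper defined earlier is in scope; the shapes M=4096 and N=8192 are illustrative, and the actual speedup (if any) depends on the hardware and the Triton backend.

import torch
import torch_npu
import triton

def bench_layer_norm(M=4096, N=8192, dtype=torch.float16, device='npu'):
    # Compare median runtimes (in milliseconds) of the Triton kernel and the PyTorch built-in
    x = torch.randn(M, N, dtype=dtype, device=device)
    weight = torch.rand(N, dtype=dtype, device=device)
    bias = torch.rand(N, dtype=dtype, device=device)
    ms_triton = triton.testing.do_bench(lambda: layer_norm(x, (N,), weight, bias))
    ms_torch = triton.testing.do_bench(
        lambda: torch.nn.functional.layer_norm(x, (N,), weight, bias))
    print(f"Triton: {ms_triton:.3f} ms, PyTorch: {ms_torch:.3f} ms")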