# Layer Normalization

In this section, we will use Triton to write a high-performance layer normalization kernel that runs faster than the PyTorch implementation. Layer normalization normalizes each row of the input to zero mean and unit variance, then applies a per-feature affine transform with learnable weight and bias parameters.

## Compute Kernel

```Python
import pytest
import torch
import triton
import triton.language as tl
import torch_npu


@triton.jit
def _layer_norm_fwd_fused(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride,  # how many elements to advance the pointer to move by one row
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute
    row = tl.program_id(0)
    Y += row * stride
    X += row * stride
    # Compute the mean
    mean = 0
    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
        _mean += a
    mean = tl.sum(_mean, axis=0) / N
    # Compute the variance
    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
        x = tl.where(cols < N, x - mean, 0.)
        _var += x * x
    var = tl.sum(_var, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    # Write mean / rstd
    tl.store(Mean + row, mean)
    tl.store(Rstd + row, rstd)
    # Normalize and apply the linear transformation
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        w = tl.load(W + cols, mask=mask)
        b = tl.load(B + cols, mask=mask)
        x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
        x_hat = (x - mean) * rstd
        y = x_hat * w + b
        # Write the output
        tl.store(Y + cols, y, mask=mask)
```

A LayerNorm implementation built on the custom Triton kernel:

```Python
@torch.inference_mode()
def layer_norm(x, normalized_shape, weight, bias, eps=1e-5):
    # Allocate an output tensor with the same shape and dtype as the input
    y = torch.empty_like(x)
    # Flatten the input to 2D [-1, feature_dim] so the kernel normalizes the last dimension
    x_arg = x.reshape(-1, x.shape[-1])
    M, N = x_arg.shape
    mean = torch.empty((M, ), dtype=torch.float32, device=x.device)
    rstd = torch.empty((M, ), dtype=torch.float32, device=x.device)
    BLOCK_SIZE = 1024
    # Enqueue the kernel: one program per row, launch grid = (M,)
    _layer_norm_fwd_fused[(M, )](
        x_arg, y, weight, bias, mean, rstd,  # inputs, outputs and intermediates
        x_arg.stride(0), N, eps,
        BLOCK_SIZE=BLOCK_SIZE)
    # Return the normalized output
    return y


# Run the forward pass and compare against the PyTorch reference
def _layer_norm(M, N, dtype, eps=1e-5, device='npu'):
    # Create the data
    x_shape = (M, N)
    w_shape = (x_shape[-1], )
    weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)
    bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device)
    dy = .1 * torch.randn_like(x)
    x.requires_grad_(True)
    # Forward pass
    y_tri = layer_norm(x, w_shape, weight, bias, eps)
    y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)
    # Check that the two results are close
    assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0)
    print(f"y_tri: {y_tri}")
    print(f"y_ref: {y_ref}")
    print(f"Layer Normalization {M},{N} {dtype} PASSED!")


# Run the tests
if __name__ == '__main__':
    _layer_norm(128, 128, torch.float16)
    _layer_norm(128, 128, torch.bfloat16)
    _layer_norm(128, 128, torch.float32)
```

Results:

```bash
y_tri: tensor([[ 0.2512, 0.0647, 0.8389, ..., 2.3652, 1.5039, 1.1904],
        [ 1.0908, 1.5391, 0.2269, ..., 1.6846, 1.0996, 0.9614],
        [-0.2974, 0.5918, 0.3225, ..., 2.2891, -0.8418, 0.6885],
        ...,
        [ 0.5225, -0.0068, 0.4968, ..., -1.1221, 1.7422, 0.6143],
        [ 0.4463, 1.2441, 0.2224, ..., 2.2969, -0.3311, 0.6177],
        [-0.0113, 0.8423, 0.3696, ..., 1.3838, 1.2471, 0.8750]], device='npu:0', dtype=torch.float16)
y_ref: tensor([[ 0.2512, 0.0647, 0.8389, ..., 2.3652, 1.5039, 1.1904],
        [ 1.0908, 1.5391, 0.2269, ..., 1.6846, 1.0996, 0.9614],
        [-0.2974, 0.5918, 0.3225, ..., 2.2891, -0.8418, 0.6885],
        ...,
        [ 0.5225, -0.0068, 0.4968, ..., -1.1221, 1.7422, 0.6143],
        [ 0.4463, 1.2441, 0.2224, ..., 2.2969, -0.3311, 0.6177],
        [-0.0113, 0.8423, 0.3696, ..., 1.3838, 1.2471, 0.8750]], device='npu:0', dtype=torch.float16, grad_fn=)
Layer Normalization 128,128 torch.float16 PASSED!
y_tri: tensor([[-0.4180, 0.9648, 0.8633, ..., 0.7656, 0.8438, 0.3633],
        [ 0.4453, 0.5352, 0.9102, ..., 1.1875, -0.0562, 0.5391],
        [ 1.3125, 0.9961, 0.9219, ..., 0.9688, 0.0025, 0.5156],
        ...,
        [-0.1426, 0.6289, 0.9609, ..., 0.9648, -0.1260, -0.1270],
        [ 1.1641, 0.6680, 0.8281, ..., 0.9258, 0.9062, 0.1768],
        [-0.2129, 0.7109, 0.9141, ..., 0.7891, -0.0767, 0.5156]], device='npu:0', dtype=torch.bfloat16)
y_ref: tensor([[-0.4180, 0.9648, 0.8633, ..., 0.7656, 0.8438, 0.3633],
        [ 0.4453, 0.5352, 0.9102, ..., 1.1875, -0.0562, 0.5391],
        [ 1.3125, 0.9961, 0.9219, ..., 0.9688, 0.0025, 0.5156],
        ...,
        [-0.1426, 0.6289, 0.9609, ..., 0.9648, -0.1260, -0.1270],
        [ 1.1641, 0.6680, 0.8281, ..., 0.9258, 0.9062, 0.1768],
        [-0.2129, 0.7109, 0.9141, ..., 0.7891, -0.0767, 0.5156]], device='npu:0', dtype=torch.bfloat16, grad_fn=)
Layer Normalization 128,128 torch.bfloat16 PASSED!
y_tri: tensor([[-0.2980, 0.2922, 0.6481, ..., 0.9786, 0.7304, 0.8982],
        [ 1.5911, 0.0474, 0.6518, ..., 0.8013, 0.2435, 1.3748],
        [ 1.3024, 0.6265, 0.6473, ..., 0.8423, 0.0984, -1.1839],
        ...,
        [-0.2195, 0.1359, 0.6461, ..., 0.8319, 1.0899, 1.5015],
        [ 0.6371, 0.3687, 0.6530, ..., 0.9359, 0.0818, 0.6499],
        [ 0.1178, 0.3639, 0.6475, ..., 0.7221, 0.4622, 1.4510]], device='npu:0')
y_ref: tensor([[-0.2980, 0.2922, 0.6481, ..., 0.9786, 0.7304, 0.8982],
        [ 1.5911, 0.0474, 0.6518, ..., 0.8013, 0.2435, 1.3748],
        [ 1.3024, 0.6265, 0.6473, ..., 0.8423, 0.0984, -1.1839],
        ...,
        [-0.2195, 0.1359, 0.6461, ..., 0.8319, 1.0899, 1.5015],
        [ 0.6371, 0.3687, 0.6530, ..., 0.9359, 0.0818, 0.6499],
        [ 0.1178, 0.3639, 0.6475, ..., 0.7221, 0.4622, 1.4510]], device='npu:0', grad_fn=)
Layer Normalization 128,128 torch.float32 PASSED!
```

The messages "Layer Normalization 128,128 torch.float16 PASSED!", "Layer Normalization 128,128 torch.bfloat16 PASSED!", and "Layer Normalization 128,128 torch.float32 PASSED!" show that, for float16, bfloat16, and float32 inputs, the Triton kernel and the PyTorch reference produce matching outputs within the tolerance used by the check (atol=1e-2).
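The introduction claims the Triton kernel can run faster than the PyTorch implementation. Below is a minimal benchmark sketch, not part of the original test script, that compares the two with `triton.testing.do_bench`. It assumes `do_bench` is usable on the torch_npu backend and that it runs in the same file as the `layer_norm` wrapper above; the problem size `(4096, 8192)` is an illustrative choice.

```Python
# Minimal benchmark sketch (assumption: triton.testing.do_bench works on the
# torch_npu backend; the problem size is an illustrative choice).
import torch
import triton
import torch_npu


def benchmark_layer_norm(M=4096, N=8192, dtype=torch.float16, device='npu'):
    x = torch.randn((M, N), dtype=dtype, device=device)
    weight = torch.rand((N, ), dtype=dtype, device=device)
    bias = torch.rand((N, ), dtype=dtype, device=device)

    # do_bench returns the measured time in milliseconds per call;
    # layer_norm is the Triton-backed wrapper defined above.
    ms_triton = triton.testing.do_bench(lambda: layer_norm(x, (N, ), weight, bias))
    ms_torch = triton.testing.do_bench(
        lambda: torch.nn.functional.layer_norm(x, (N, ), weight, bias))

    # Rough effective bandwidth: one read of x plus one write of y
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    print(f"Triton: {ms_triton:.3f} ms ({gbps(ms_triton):.1f} GB/s)")
    print(f"Torch : {ms_torch:.3f} ms ({gbps(ms_torch):.1f} GB/s)")


if __name__ == '__main__':
    benchmark_layer_norm()
```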
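Since the kernel file already imports `pytest`, the correctness check can also be driven as a parametrized test instead of the `__main__` loop. The sketch below is an illustrative wrapper around the `_layer_norm` helper above; the extra shape `(64, 1024)` is an assumed example, not taken from the original script.

```Python
# Illustrative pytest wrapper around the _layer_norm correctness check above;
# the (64, 1024) shape is an assumed extra example.
import pytest
import torch


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("M, N", [(128, 128), (64, 1024)])
def test_layer_norm(M, N, dtype):
    _layer_norm(M, N, dtype)  # asserts torch.allclose(y_tri, y_ref, atol=1e-2)
```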