Matrix Multiplication
This section demonstrates a Triton kernel implementation of matrix multiplication.
Compute Kernel
import pytest
import torch
import torch_npu
import triton
import triton.language as tl


@triton.jit
def triton_dot_2_None(output_ptr, x_ptr, y_ptr, z_ptr,
                      A: tl.constexpr, B: tl.constexpr,
                      C: tl.constexpr, D: tl.constexpr):
    # Index vectors along each dimension (aidx is unused in this kernel).
    aidx = tl.arange(0, A)
    bidx = tl.arange(0, B)
    cidx = tl.arange(0, C)
    didx = tl.arange(0, D)
    # Declared but unused in this minimal example.
    accumulator = tl.zeros((B, D), dtype=tl.float32)
    # Flat row-major offsets for X (B x C), Y (C x D) and Z (B x D).
    Xidx = bidx[:, None] * C + cidx[None, :]
    Yidx = cidx[:, None] * D + didx[None, :]
    Zidx = bidx[:, None] * D + didx[None, :]
    X = tl.load(x_ptr + Xidx)
    Y = tl.load(y_ptr + Yidx)
    Z = tl.load(z_ptr + Zidx)  # loaded but unused; kept for the test template
    tl.device_print("X: ", X)  # device-side debug print (requires debug=True)
    ret = tl.dot(X, Y)
    # Store the (B, D) product at row-major offsets.
    oidx = bidx[:, None] * D + didx[None, :]
    tl.store(output_ptr + oidx, ret)
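The kernel addresses each operand through flat row-major offsets built by broadcasting two index vectors against each other. As a host-side sanity check (a minimal sketch, not part of the test code), the same broadcasting trick can be reproduced with plain torch:

import torch

B, C = 4, 8
bidx = torch.arange(B)
cidx = torch.arange(C)
# Same broadcast as in the kernel: a (B, C) grid of flat row-major offsets.
Xidx = bidx[:, None] * C + cidx[None, :]
assert torch.equal(Xidx.flatten(), torch.arange(B * C))

Note that the kernel loads the whole (B, C), (C, D) and (B, D) tiles at once, so it assumes the operands fit in on-chip memory; larger matmuls would tile the problem across a launch grid.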
Utility Functions
def torch_dot_None(x0, x1):
    # Reference implementation computed with torch.matmul.
    res = torch.matmul(x0, x1)
    return res
def get_torch_typename(dtype):
    # Map a dtype name string to the corresponding torch dtype object.
    if dtype == 'float32':
        tyname = torch.float32
    elif dtype == 'int32':
        tyname = torch.int32
    elif dtype == 'int64':
        tyname = torch.int64
    elif dtype == 'float16':
        tyname = torch.float16
    elif dtype == 'int16':
        tyname = torch.int16
    elif dtype == 'int8':
        tyname = torch.int8
    elif dtype == 'bool':
        tyname = torch.bool
    elif dtype == 'bfloat16':
        tyname = torch.bfloat16
    else:
        raise ValueError('Invalid parameter "dtype" is found: {}'.format(dtype))
    return tyname
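For reference, the same mapping can be written as a dict lookup. This is a sketch with identical behavior; get_torch_typename_alt and _TYPE_MAP are hypothetical names, not used by the test below:

_TYPE_MAP = {
    'float32': torch.float32, 'float16': torch.float16, 'bfloat16': torch.bfloat16,
    'int32': torch.int32, 'int64': torch.int64, 'int16': torch.int16,
    'int8': torch.int8, 'bool': torch.bool,
}

def get_torch_typename_alt(dtype):
    # Hypothetical alternative to the elif chain above.
    try:
        return _TYPE_MAP[dtype]
    except KeyError:
        raise ValueError('Invalid parameter "dtype" is found: {}'.format(dtype))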
def generate_tensor(shape, dtype):
    # Draw random test data appropriate for the requested dtype.
    if dtype in ('float32', 'float16', 'bfloat16'):
        return torch.randn(size=shape, dtype=getattr(torch, dtype))
    elif dtype in ('int32', 'int64', 'int16'):
        return torch.randint(low=0, high=2000, size=shape, dtype=getattr(torch, dtype))
    elif dtype == 'int8':
        return torch.randint(low=0, high=127, size=shape, dtype=getattr(torch, dtype))
    elif dtype == 'bool':
        return torch.randint(low=0, high=2, size=shape).bool()
    else:
        raise ValueError('Invalid parameter "dtype" is found: {}'.format(dtype))
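For example (shapes chosen arbitrarily for illustration):

x = generate_tensor(shape=(16, 16), dtype='float16')  # values drawn from N(0, 1)
m = generate_tensor(shape=(16, 16), dtype='bool')     # random boolean mask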
def validate_cmp(dtype, y_cal, y_ref):
    # Compare the Triton result against the reference on the NPU device.
    y_cal = y_cal.npu()
    y_ref = y_ref.npu()
    if dtype == 'float16':
        torch.testing.assert_close(y_ref, y_cal, rtol=1e-03, atol=1e-03, equal_nan=True)
    elif dtype == 'bfloat16':
        # bfloat16 is upcast to float32 before comparison.
        torch.testing.assert_close(y_ref.to(torch.float32), y_cal.to(torch.float32), rtol=1e-03, atol=1e-03, equal_nan=True)
    elif dtype == 'float32':
        torch.testing.assert_close(y_ref, y_cal, rtol=1e-04, atol=1e-04, equal_nan=True)
    elif dtype in ('int32', 'int64', 'int16', 'int8', 'bool'):
        # Integer and boolean results must match exactly.
        assert torch.equal(y_cal, y_ref)
    else:
        raise ValueError('Invalid parameter "dtype" is found: {}'.format(dtype))
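A quick illustration of the float16 tolerance (a sketch; the .npu() calls inside validate_cmp require an NPU device):

y_ref = torch.ones(16, dtype=torch.float16)
noise = torch.full((16,), 1e-3, dtype=torch.float16)
validate_cmp('float16', y_ref + noise, y_ref)  # |diff| <= atol + rtol * |ref|, passes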
Parameterized Test
testlist = [
    (3, 16, 16, 16),
]
typelist = ['float16']


@pytest.mark.parametrize('A, B, C, D', testlist)
@pytest.mark.parametrize('sigtype', typelist)
def test_dot_2_None(sigtype, A, B, C, D):
    dtype = get_torch_typename(sigtype)
    x0 = generate_tensor(shape=(B, C), dtype=sigtype).npu()
    x1 = generate_tensor(shape=(C, D), dtype=sigtype).npu()
    if 'int' in sigtype:
        # Integer inputs: compute the reference in float32, then cast back.
        x2 = generate_tensor(shape=(B, D), dtype='int32').npu()
        ans = torch_dot_None(x0.to(torch.float32), x1.to(torch.float32)).to(dtype)
    else:
        x2 = generate_tensor(shape=(B, D), dtype='float32').npu()
        ans = torch_dot_None(x0, x1)
    output = torch.zeros((B, D), dtype=dtype).npu()
    # Launch on a 1x1x1 grid; debug=True enables tl.device_print output.
    triton_dot_2_None[1, 1, 1](output, x0, x1, x2, A, B, C, D, debug=True)
    validate_cmp(sigtype, output, ans)
    print(f"Test matmul with dtype={sigtype}, shape=({A},{B},{C},{D}) PASSED!")


if __name__ == "__main__":
    test_dot_2_None("float16", 3, 16, 16, 16)
Out:
Test matmul with dtype=float16, shape=(3,16,16,16) PASSED!
The log above shows that the Triton output matches the PyTorch reference within the tolerances enforced by validate_cmp.
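To run the full parameterized suite instead of the __main__ entry point, invoke pytest directly (the file name test_matmul.py is an assumption; -s keeps the print output visible):

pytest -s test_matmul.py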