Triton Overview
Triton Op Support Overview
| Category | Triton Op | int8 | int16 | int32 | uint32 | int64 | fp16 | fp32 | bf16 | bool |
|---|---|---|---|---|---|---|---|---|---|---|
| Creation Ops | arange | ✓ | ✓ | ✓ | × | × | × | × | × | × |
| | cat | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | × |
| | full | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | ✓ |
| | zeros | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | ✓ |
| | zeros_like | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | ✓ |
| | cast | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | ✓ |
| Shape Manipulation Ops | broadcast | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | ✓ |
| | broadcast_to | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | ✓ |
| | expand_dims | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | ✓ |
| | interleave | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | × |
| | join | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | × |
| | permute | ✓ | ✓ | ✓ | × | × | ✓ | ✓ | ✓ | × |
| | ravel | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | ✓ |
| | reshape | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | ✓ |
| | split | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | × |
| | trans | ✓ | ✓ | ✓ | × | × | ✓ | ✓ | ✓ | × |
| | view | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | ✓ |
| Linear Algebra Ops | dot | × | × | × | × | × | ✓ | ✓ | ✓ | × |
| | dot_scaled | × | × | × | × | × | × | × | × | × |
| Memory/Pointer Ops | load | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | ✓ |
| | store | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | ✓ |
| | make_block_ptr | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | × |
| | advance | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | × |
| Indexing Ops | flip | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | ✓ |
| | where | ✓ | ✓ | ✓ | × | × | ✓ | ✓ | ✓ | × |
| | swizzle2d | ✓ | ✓ | ✓ | × | ✓ | × | × | × | × |
| Math Ops | add | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | × |
| | sub | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | × |
| | mul | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | × |
| | div | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | × |
| | floordiv(//) | ✓ | ✓ | ✓ | × | ✓ | × | × | × | × |
| | mod | ✓ | ✓ | ✓ | × | × | × | × | × | × |
| | neg | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | × |
| | invert(!) | ✓ | ✓ | ✓ | × | ✓ | × | × | × | ✓ |
| | and(&) | ✓ | ✓ | ✓ | × | ✓ | × | × | × | ✓ |
| | or(\|) | ✓ | ✓ | ✓ | × | ✓ | × | × | × | ✓ |
| | xor(^) | ✓ | ✓ | ✓ | × | ✓ | × | × | × | ✓ |
| | not(~) | ✓ | ✓ | ✓ | × | ✓ | × | × | × | ✓ |
| | lshift(<<) | ✓ | ✓ | ✓ | × | ✓ | × | × | × | × |
| | rshift(>>) | ✓ | ✓ | ✓ | × | ✓ | × | × | × | × |
| | gt | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | × |
| | ge | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | × |
| | lt | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | × |
| | le | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | × |
| | eq | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | × |
| | ne | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | × |
| | logical and | × | × | × | × | × | × | × | × | ✓ |
| | logical or | × | × | × | × | × | × | × | × | ✓ |
| | abs | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | × |
| | cdiv | ✓ | ✓ | ✓ | × | ✓ | × | × | × | × |
| | ceil | × | × | × | × | × | ✓ | ✓ | ✓ | × |
| | clamp | × | × | × | × | × | ✓ | ✓ | ✓ | × |
| | cos | × | × | × | × | × | ✓ | ✓ | ✓ | × |
| | div_rn | × | × | × | × | × | ✓ | ✓ | ✓ | × |
| | erf | × | × | × | × | × | ✓ | ✓ | ✓ | × |
| | exp | × | × | × | × | × | ✓ | ✓ | ✓ | × |
| | exp2 | × | × | × | × | × | ✓ | ✓ | ✓ | × |
| | fdiv | × | × | × | × | × | ✓ | ✓ | ✓ | × |
| | floor | × | × | × | × | × | ✓ | ✓ | ✓ | × |
| | fma | × | × | × | × | × | ✓ | ✓ | ✓ | × |
| | log | × | × | × | × | × | ✓ | ✓ | ✓ | × |
| | log2 | × | × | × | × | × | ✓ | ✓ | ✓ | × |
| | maximum | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | × |
| | minimum | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | × |
| | rsqrt | × | × | × | × | × | ✓ | ✓ | ✓ | × |
| | sigmoid | × | × | × | × | × | ✓ | ✓ | ✓ | × |
| | sin | × | × | × | × | × | ✓ | ✓ | ✓ | × |
| | softmax | × | × | × | × | × | ✓ | ✓ | ✓ | × |
| | sqrt | × | × | × | × | × | ✓ | ✓ | ✓ | × |
| | sqrt_rn | × | × | × | × | × | ✓ | ✓ | ✓ | × |
| | umulhi | × | × | ✓ | × | × | × | × | × | × |
| Reduction Ops | argmax | ✓ | ✓ | ✓ | × | × | ✓ | ✓ | ✓ | × |
| | argmin | ✓ | ✓ | ✓ | × | × | ✓ | ✓ | ✓ | × |
| | max | ✓ | ✓ | ✓ | × | × | ✓ | ✓ | ✓ | × |
| | min | ✓ | ✓ | ✓ | × | × | ✓ | ✓ | ✓ | × |
| | reduce | ✓ | ✓ | ✓ | × | × | ✓ | ✓ | ✓ | × |
| | sum | ✓ | ✓ | ✓ | × | × | ✓ | ✓ | ✓ | × |
| | xor_sum | ✓ | ✓ | ✓ | × | × | × | × | × | × |
| Scan/Sort Ops | associative_scan | × | × | × | × | × | × | × | × | × |
| | cumprod | × | × | × | × | × | × | × | × | × |
| | cumsum | × | × | × | × | × | × | × | × | × |
| | histogram | × | × | × | × | × | × | × | × | × |
| | sort | × | × | × | × | × | × | × | × | × |
| | gather | × | × | × | × | × | ✓ | ✓ | ✓ | × |
| Atomic Ops | atomic_add | ✓ | ✓ | ✓ | × | × | ✓ | ✓ | ✓ | × |
| | atomic_and | × | × | × | × | × | × | × | × | × |
| | atomic_cas | × | × | × | × | × | × | × | × | × |
| | atomic_max | ✓ | ✓ | ✓ | × | × | ✓ | ✓ | ✓ | × |
| | atomic_min | ✓ | ✓ | ✓ | × | × | ✓ | ✓ | ✓ | × |
| | atomic_or | × | × | × | × | × | × | × | × | × |
| | atomic_xchg | × | × | × | × | × | × | × | × | × |
| | atomic_xor | × | × | × | × | × | × | × | × | × |
| Random Number Generation | randint4x | × | × | ✓ | × | × | × | × | × | × |
| | randint | × | × | ✓ | × | × | × | × | × | × |
| | rand | × | × | × | × | × | × | ✓ | × | × |
| | randn | × | × | × | × | × | × | ✓ | × | × |
| Iterators | range | ✓ | ✓ | ✓ | × | ✓ | × | × | × | × |
| | static_range | ✓ | ✓ | ✓ | × | ✓ | × | × | × | × |
| Inline Assembly | inline_asm_elementwise | × | × | × | × | × | × | × | × | × |
| Compiler Hint Ops | debug_barrier | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | ✓ |
| | max_constancy | × | × | × | × | × | × | × | × | × |
| | max_contiguous | × | × | × | × | × | × | × | × | × |
| | multiple_of | × | × | × | × | × | × | × | × | × |
| Debug Ops | static_print | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | ✓ |
| | static_assert | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | ✓ | ✓ |
| | device_print | ✓ | ✓ | ✓ | × | ✓ | ✓ | ✓ | × | ✓ |
| | device_assert | × | × | × | × | × | × | × | × | × |

In the table, ✓ means supported and × means not supported.
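If one kernel is meant to handle several dtypes, a small lookup built from the support table can pick a fallback path, such as casting to fp32, before launch. The sketch below is plain Python that runs on the host. `DTYPES`, `SUPPORT`, `is_supported` and `supported_dtypes` are illustrative names, not part of Triton. Only a few rows are copied here, so extend `SUPPORT` with the rows you rely on.

```python
# Column order of the support table (one mark per dtype).
DTYPES = ("int8", "int16", "int32", "uint32", "int64", "fp16", "fp32", "bf16", "bool")

# A few rows copied by hand from the table; each string holds one mark per dtype in DTYPES order.
SUPPORT = {
    "dot":          "×××××✓✓✓×",
    "where":        "✓✓✓××✓✓✓×",
    "exp":          "×××××✓✓✓×",
    "sum":          "✓✓✓××✓✓✓×",
    "atomic_add":   "✓✓✓××✓✓✓×",
    "device_print": "✓✓✓×✓✓✓×✓",
}


def is_supported(op: str, dtype: str) -> bool:
    """Return True if the copied row marks `op` as supported for `dtype`."""
    return SUPPORT[op][DTYPES.index(dtype)] == "✓"


def supported_dtypes(op: str) -> list[str]:
    """List the dtypes whose column holds a ✓ for `op`."""
    return [dtype for dtype, mark in zip(DTYPES, SUPPORT[op]) if mark == "✓"]
```

For example, `supported_dtypes("dot")` gives `["fp16", "fp32", "bf16"]`, so an int8 input would need a cast or a different code path before it reaches `tl.dot`. Storing each row as a string of marks keeps the dict visually aligned with the table, which makes copying errors easier to spot.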
Constraints
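dot: for the two inputs A[batch (optional), M, K] and B[batch (optional), K, N], M and N must be multiples of 16, and K must be aligned to 32 bytes (that is, K × element size must be a multiple of 32 B).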
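The padding arithmetic for this constraint can be sketched as follows. The sketch assumes that "aligned" means "rounded up to a multiple of" and that the 32-byte rule applies to K × element size. `padded_dot_shape` and `is_dot_aligned` are hypothetical helpers. The element sizes are the standard 2 bytes for fp16/bf16 and 4 bytes for fp32, which are the three dtypes the table marks as supported for dot.

```python
# Element sizes in bytes for the dtypes the support table marks as valid for dot.
DOT_ITEMSIZE = {"fp16": 2, "bf16": 2, "fp32": 4}


def round_up(x: int, multiple: int) -> int:
    """Smallest multiple of `multiple` that is >= x."""
    return -(-x // multiple) * multiple


def padded_dot_shape(m: int, k: int, n: int, dtype: str) -> tuple[int, int, int]:
    """Round (M, K, N) up to the dot alignment: M and N to multiples of 16,
    K so that K * itemsize is a multiple of 32 bytes."""
    itemsize = DOT_ITEMSIZE[dtype]
    k_multiple = 32 // itemsize  # 16 elements for fp16/bf16, 8 for fp32
    return round_up(m, 16), round_up(k, k_multiple), round_up(n, 16)


def is_dot_aligned(m: int, k: int, n: int, dtype: str) -> bool:
    """True if the shape already meets the alignment, i.e. no padding is needed."""
    return padded_dot_shape(m, k, n, dtype) == (m, k, n)
```

For example, an fp16 problem with M=100, K=70, N=33 would be padded to (112, 80, 48). Zero-padding A and B along K leaves A @ B unchanged. The extra rows added by padding M and the extra columns added by padding N can be sliced off the result.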
gather: for `triton.language.gather(x, index, axis)`, where x has n dimensions, only axis = n - 1 (the last axis) is currently supported.
permute: `triton.language.permute(x, dims)` does not support dims=[2, 1, 0].
trans: `triton.language.trans(x, dims)` does not support dims=[2, 1, 0].
device_print: requires two environment variables: TRITON_DEVICE_PRINT=1 and TRITON_ENABLE_TASKQUEUE=0. Setting TRITON_ENABLE_TASKQUEUE=0 can make programs run unstably, so use it only temporarily.
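One way to keep TRITON_ENABLE_TASKQUEUE=0 temporary is to scope both variables with a context manager. This sketch relies on an assumption the note does not confirm: that the runtime reads the variables while kernels are compiled or launched inside the `with` block. If the variables are read only at import or process start, set them from the shell for a single debugging run instead. Triton also caches compiled kernels, so a kernel compiled before entering the block may be reused without printing. `device_print_env` is a hypothetical helper name.

```python
import os
from contextlib import contextmanager


@contextmanager
def device_print_env():
    """Set the two variables the device_print note requires, then restore the
    previous values (or remove the variables) when the block exits."""
    overrides = {"TRITON_DEVICE_PRINT": "1", "TRITON_ENABLE_TASKQUEUE": "0"}
    saved = {key: os.environ.get(key) for key in overrides}
    os.environ.update(overrides)
    try:
        yield
    finally:
        for key, old in saved.items():
            if old is None:
                os.environ.pop(key, None)  # variable did not exist before
            else:
                os.environ[key] = old
```

A debugging session would wrap only the launch of the kernel that calls `tl.device_print`, leaving the task queue enabled for everything else in the program.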
atomic_add, atomic_max, atomic_min: scalar memory accesses are not supported, including accesses through a tensor of length 1.
permute, trans: more generally, transposing non-adjacent axes is not supported, for example (0, 1, 2) -> (2, 1, 0), which swaps axes 0 and 2.
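Any permutation can be written as a sequence of swaps between neighbouring axes, which is the same idea as bubble sort. A non-adjacent permutation such as (2, 1, 0) can therefore in principle be expressed as a chain of adjacent ones. The sketch below computes that chain in plain Python on the host. It assumes, as the note implies, that adjacent swaps such as (1, 0, 2) and (0, 2, 1) are supported. `adjacent_swap_plan` and `permute_shape` are hypothetical helper names.

```python
def adjacent_swap_plan(dims):
    """Split permute(x, dims) into a chain of permutes, each swapping two
    neighbouring axes. Returns the dims tuples to apply in order."""
    n = len(dims)
    if sorted(dims) != list(range(n)):
        raise ValueError(f"dims must be a permutation of 0..{n - 1}, got {dims}")
    current = list(range(n))  # current[i] = original axis now at position i
    plan = []
    for i, wanted in enumerate(dims):
        j = current.index(wanted)
        # Move the wanted axis left one neighbour swap at a time.
        while j > i:
            step = list(range(n))
            step[j - 1], step[j] = step[j], step[j - 1]
            plan.append(tuple(step))
            current[j - 1], current[j] = current[j], current[j - 1]
            j -= 1
    return plan


def permute_shape(shape, dims):
    """Shape after permute(x, dims): output axis k takes input axis dims[k]."""
    return tuple(shape[d] for d in dims)
```

For dims=(2, 1, 0), the plan is (0, 2, 1), (1, 0, 2), (0, 2, 1). Because permute dims must be compile-time constants, these would be written into the kernel as three literal calls: `x = tl.permute(x, (0, 2, 1))`, then `(1, 0, 2)`, then `(0, 2, 1)`. Each extra permute may add data movement, so measure performance before relying on this approach.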
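All ops: because int8 receives special handling, it occupies more on-chip memory, and compilation can fail with a "ub overflow" error. Adjusting the tiling usually resolves this.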
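One simple way to adjust tiling is to halve the block size until an estimated on-chip footprint fits a budget. The footprint model used here (block × bytes per element × live buffers) is an assumption, and you must supply all three inputs yourself. The note gives no figures for how much extra space int8 needs or how large the on-chip buffer is. `shrink_block` is a hypothetical helper.

```python
def shrink_block(block: int, bytes_per_elem: float, live_buffers: int,
                 budget_bytes: int, min_block: int = 16) -> int:
    """Halve a power-of-two block size until the estimated footprint
    block * bytes_per_elem * live_buffers fits within budget_bytes,
    stopping at min_block. All cost inputs are caller-supplied estimates;
    nothing here queries the real on-chip buffer size."""
    while block > min_block and block * bytes_per_elem * live_buffers > budget_bytes:
        block //= 2
    return block
```

The returned value would be passed as the kernel's block-size `tl.constexpr` argument. If the estimate is too rough to trust, a simpler alternative is to retry compilation with the next smaller block size until the overflow disappears.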