triton.language.cumprod

triton.language.cumprod = JITFunction(triton.language.standard:cumprod)

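Returns the cumulative product (cumprod) of the elements in the input tensor along the provided axis. Each output element is the product of all input elements along that axis up to and including its own position.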

Parameters:
  • input (Tensor) – the input values

  • axis (int) – the dimension along which the scan should be done
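The semantics can be illustrated with a plain-Python reference for a 2-D input. This is not Triton code. `cumprod_reference` is a hypothetical helper written only for this illustration, and it is limited to axis 0 or 1 of a nested list.

```python
import operator
from itertools import accumulate


def cumprod_reference(rows, axis):
    """Inclusive cumulative product of a 2-D nested list along `axis` (0 or 1).

    Hypothetical helper for illustration only; it is not part of Triton.
    """
    if axis == 1:
        # Scan each row left to right.
        return [list(accumulate(row, operator.mul)) for row in rows]
    if axis == 0:
        # Scan each column top to bottom: transpose, scan, transpose back.
        cols = [list(accumulate(col, operator.mul)) for col in zip(*rows)]
        return [list(r) for r in zip(*cols)]
    raise ValueError("axis must be 0 or 1 for a 2-D input")


x = [[1, 2, 3],
     [4, 5, 6]]
print(cumprod_reference(x, axis=1))  # [[1, 2, 6], [4, 20, 120]]
print(cumprod_reference(x, axis=0))  # [[1, 2, 3], [4, 10, 18]]
```

With axis=1, each row is scanned independently. With axis=0, each column is scanned independently.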

This function can also be called as a member function on tensor, as x.cumprod(...) instead of cumprod(x, ...).