triton.language.cast

triton.language.cast(input, dtype: dtype, fp_downcast_rounding: str | None = None, bitcast: bool = False, overflow_mode: str | None = None, _builder=None)

Casts a tensor to the given dtype.

Parameters:
  • input (tensor) – The tensor to cast.

  • dtype (tl.dtype) – The target data type.

  • fp_downcast_rounding (str, optional) – The rounding mode used when downcasting floating-point values. This parameter is only used when input is a floating-point tensor and dtype is a floating-point type with a smaller bit width. Supported values are "rtne" (round to nearest, ties to even) and "rtz" (round towards zero).

  • bitcast (bool, optional) – If True, the tensor is bitcast to the given dtype (its bits are reinterpreted) instead of being numerically cast.

  • overflow_mode (str, optional) – When overflow_mode is not set or is "trunc", overflow is handled by truncation (cut-off). When overflow_mode is "saturate", overflow is handled by using the maximum value of the data type.
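
The difference between a numeric cast and a bitcast, and between the "rtne" and "rtz" rounding modes, may be easier to see on scalar values. The following is a minimal sketch in plain Python that emulates these semantics with the standard library; it does not call Triton and is not how Triton implements cast. All helper names are hypothetical, the float32 to bfloat16 pair is chosen only as an example of a floating-point downcast, and the rounding helpers assume finite inputs.

```python
import struct


def f32_bits(x: float) -> int:
    """Return the IEEE-754 binary32 bit pattern of x as an unsigned integer."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_f32(bits: int) -> float:
    """Interpret the low 32 bits of an integer as an IEEE-754 binary32 value."""
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def numeric_cast_f32_to_i32(x: float) -> int:
    """Emulate a numeric cast: the value is converted, the bits change."""
    return int(x)


def bitcast_f32_to_i32(x: float) -> int:
    """Emulate a bitcast: the 32 bits are kept and reinterpreted as a signed int."""
    return struct.unpack("<i", struct.pack("<f", x))[0]


def f32_to_bf16_rtz(x: float) -> float:
    """Downcast to bfloat16 precision, rounding towards zero.

    bfloat16 keeps the upper 16 bits of a float32, so dropping the lower
    16 bits discards the extra mantissa and moves the magnitude towards zero.
    """
    return bits_f32(f32_bits(x) & 0xFFFF0000)


def f32_to_bf16_rtne(x: float) -> float:
    """Downcast to bfloat16 precision, rounding to nearest with ties to even."""
    bits = f32_bits(x)
    lsb = (bits >> 16) & 1          # lowest bit that survives the downcast
    rounded = bits + 0x7FFF + lsb   # a tie rounds up only if that bit is odd
    return bits_f32(rounded & 0xFFFF0000)


if __name__ == "__main__":
    print(numeric_cast_f32_to_i32(1.0))      # value is preserved
    print(hex(bitcast_f32_to_i32(1.0)))      # bit pattern is preserved
    print(f32_to_bf16_rtz(1.005859375))      # rounds down to 1.0
    print(f32_to_bf16_rtne(1.005859375))     # rounds up to the nearer value
```

With bitcast, 1.0 becomes the integer whose bits match the float's encoding (0x3F800000) rather than the integer 1. With the two rounding modes, the same input can land on different neighbouring representable values, which is why fp_downcast_rounding only matters when the target floating-point type is narrower than the source.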

This function can also be called as a member function on tensor, as x.cast(...) instead of cast(x, ...).