jax.lax.转换元素类型

jax.lax.转换元素类型#

jax.lax.convert_element_type(operand, new_dtype)[源代码][源代码]#

逐元素类型转换。

封装了 XLA 的 ConvertElementType 操作符,该操作符执行从一种类型到另一种类型的逐元素转换。类似于 C++ 中的 static_cast

参数:
  • operand (ArrayLike) – 要转换的数组或标量值。

  • new_dtype (DTypeLike | dtypes.ExtendedDType) – 表示目标类型的 NumPy dtype。

返回:

一个与 operand 形状相同的数组,逐元素转换为 new_dtype

返回类型:

Array