类型提升语义#
本文档描述了 JAX 的类型提升规则——即,每对类型 jax.numpy.promote_types() 的结果。关于设计以下描述的类型提升语义的考虑背景,请参见 JAX 类型提升语义的设计。
JAX 的类型提升行为是通过以下类型提升格来确定的:
例如,在以下情况下:
b1表示np.bool_,i2表示np.int16,u4表示np.uint32,bf表示np.bfloat16,f2表示np.float16,c8表示np.complex64,i*表示 Pythonint或弱类型int,f*表示 Pythonfloat或弱类型float,以及c*表示 Pythoncomplex或弱类型complex。
(关于弱类型的更多信息,请参见下面的 弱类型)。
任意两种类型之间的提升由它们在这个格上的 连接 决定,这生成了以下二进制提升表:
| b1 | u1 | u2 | u4 | u8 | i1 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i* | f* | c* | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| b1 | b1 | u1 | u2 | u4 | u8 | i1 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i* | f* | c* |
| u1 | u1 | u1 | u2 | u4 | u8 | i2 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | u1 | f* | c* |
| u2 | u2 | u2 | u2 | u4 | u8 | i4 | i4 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | u2 | f* | c* |
| u4 | u4 | u4 | u4 | u4 | u8 | i8 | i8 | i8 | i8 | bf | f2 | f4 | f8 | c8 | c16 | u4 | f* | c* |
| u8 | u8 | u8 | u8 | u8 | u8 | f* | f* | f* | f* | bf | f2 | f4 | f8 | c8 | c16 | u8 | f* | c* |
| i1 | i1 | i2 | i4 | i8 | f* | i1 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i1 | f* | c* |
| i2 | i2 | i2 | i4 | i8 | f* | i2 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i2 | f* | c* |
| i4 | i4 | i4 | i4 | i8 | f* | i4 | i4 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i4 | f* | c* |
| i8 | i8 | i8 | i8 | i8 | f* | i8 | i8 | i8 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i8 | f* | c* |
| bf | bf | bf | bf | bf | bf | bf | bf | bf | bf | bf | f4 | f4 | f8 | c8 | c16 | bf | bf | c8 |
| f2 | f2 | f2 | f2 | f2 | f2 | f2 | f2 | f2 | f2 | f4 | f2 | f4 | f8 | c8 | c16 | f2 | f2 | c8 |
| f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f8 | c8 | c16 | f4 | f4 | c8 |
| f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | c16 | c16 | f8 | f8 | c16 |
| c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c16 | c8 | c16 | c8 | c8 | c8 |
| c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 |
| i* | i* | u1 | u2 | u4 | u8 | i1 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i* | f* | c* |
| f* | f* | f* | f* | f* | f* | f* | f* | f* | f* | bf | f2 | f4 | f8 | c8 | c16 | f* | f* | c* |
| c* | c* | c* | c* | c* | c* | c* | c* | c* | c* | c8 | c8 | c8 | c16 | c8 | c16 | c* | c* | c* |
Jax 的类型提升规则与 NumPy 的规则不同,如上表中绿色背景的单元格所示,这些差异由 numpy.promote_types() 给出。主要差异分为三类:
当将一个弱类型值提升为同一类别的类型化 JAX 值时,JAX 总是倾向于 JAX 值的精度。例如,
jnp.int16(1) + 1将返回int16,而不是像在 NumPy 中那样提升为int64。请注意,这仅适用于 Python 标量值;如果常量是 NumPy 数组,则使用上述类型提升规则。例如,jnp.int16(1) + np.array(1)将返回int64。当将整数或布尔类型与浮点数或复数类型进行比较时,JAX 总是倾向于浮点数或复数类型。
JAX 支持 bfloat16 非标准的 16 位浮点类型 (
jax.numpy.bfloat16),这对于神经网络训练非常有用。唯一值得注意的提升行为是关于 IEEE-754float16,其中bfloat16提升为float32。
NumPy 和 JAX 之间的差异源于这样一个事实:加速器设备,如 GPU 和 TPU,要么在使用 64 位浮点类型(GPU)时会付出显著的性能代价,要么根本不支持 64 位浮点类型(TPU)。经典的 NumPy 的类型提升规则过于倾向于提升到 64 位类型,这对于一个设计为在加速器上运行的系统来说是有问题的。
JAX 使用更适合现代加速器设备的浮点提升规则,并且对浮点类型的提升不那么激进。JAX 用于浮点类型的提升规则与 PyTorch 使用的规则相似。
Python 操作符分派的影响#
请记住,像 + 这样的 Python 运算符会根据被加的两个值的 Python 类型进行调度。这意味着,例如,np.int16(1) + 1 将使用 NumPy 规则进行提升,而 jnp.int16(1) + 1 将使用 JAX 规则进行提升。当这两种提升类型结合时,可能会导致潜在的令人困惑的非结合性提升语义;例如 np.int16(1) + 1 + jnp.int16(1)。
JAX 中的弱类型值#
弱类型 值在 JAX 中大多数情况下可以被认为是具有与 Python 标量(如以下整数标量 2)等效的提升行为。
>>> x = jnp.arange(5, dtype='int8')
>>> 2 * x
Array([0, 2, 4, 6, 8], dtype=int8)
JAX 的弱类型框架旨在防止 JAX 值与没有明确用户指定类型的值(如 Python 标量字面量)之间的二元操作中出现不必要的类型提升。例如,如果 2 不被视为弱类型,上述表达式将导致隐式类型提升:
>>> jnp.int32(2) * x
Array([0, 2, 4, 6, 8], dtype=int32)
在JAX中使用时,Python标量有时会被提升为 DeviceArray 对象,例如在JIT编译期间。为了在这种情况下保持所需的提升语义, DeviceArray 对象携带一个 weak_type 标志,该标志可以在数组的字符串表示中看到:
>>> jnp.asarray(2)
Array(2, dtype=int32, weak_type=True)
如果 dtype 被显式指定,它将生成一个标准的强类型数组值:
>>> jnp.asarray(2, dtype='int32')
Array(2, dtype=int32)
严格的 dtype 提升#
在某些情况下,禁用隐式类型提升行为并要求所有提升都显式进行可能会有用。这可以通过将 jax_numpy_dtype_promotion 标志设置为 'strict' 来在 JAX 中实现。在局部,可以使用上下文管理器来实现:
>>> x = jnp.float32(1)
>>> y = jnp.int32(1)
>>> with jax.numpy_dtype_promotion('strict'):
... z = x + y
...
Traceback (most recent call last):
TypePromotionError: Input dtypes ('float32', 'int32') have no available implicit
dtype promotion path when jax_numpy_dtype_promotion=strict. Try explicitly casting
inputs to the desired output type, or set jax_numpy_dtype_promotion=standard.
为了方便起见,严格推广模式仍然允许安全的弱类型推广,因此您仍然可以编写混合 JAX 数组和 Python 标量的代码:
>>> with jax.numpy_dtype_promotion('strict'):
... z = x + 1
>>> print(z)
2.0
如果你更倾向于全局设置配置,你可以使用标准的配置更新来实现:
jax.config.update('jax_numpy_dtype_promotion', 'strict')
要恢复默认的标准类型提升,请将此配置设置为 'standard':
jax.config.update('jax_numpy_dtype_promotion', 'standard')