jax.eval_shape#
- jax.eval_shape(fun, *args, **kwargs)[源代码][源代码]#
在不进行任何浮点运算的情况下计算
fun的形状/数据类型。这个实用函数对于执行形状推断很有用。其输入/输出行为由以下定义:
def eval_shape(fun, *args, **kwargs): out = fun(*args, **kwargs) shape_dtype_struct = lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype) return jax.tree_util.tree_map(shape_dtype_struct, out)
但不是直接应用
fun,这可能会很昂贵,它使用 JAX 的抽象解释机制来评估形状,而不进行任何浮点运算。使用
eval_shape()也可以捕捉形状错误,并且会引发与评估fun(*args, **kwargs)相同的形状错误。- 参数:
fun (Callable) – 应评估其输出形状的函数。
*args – 一个数组、标量或(嵌套的)标准 Python 容器(元组、列表、字典、命名元组,即 pytrees)的位置参数元组。由于只访问
shape和dtype属性,因此可以使用jax.ShapeDtypeStruct或其他作为 ndarrays 进行鸭子类型化的容器(但请注意,鸭子类型化的对象不能是命名元组,因为这些对象被视为标准 Python 容器)。**kwargs – 一个包含数组、标量或(嵌套)标准 Python 容器(pytrees)的字典作为关键字参数。与
args中一样,数组值只需具备shape和dtype属性即可。
- 返回:
一个嵌套的 PyTree,包含
jax.ShapeDtypeStruct对象作为叶子。- 返回类型:
out
例如:
>>> import jax >>> import jax.numpy as jnp >>> >>> f = lambda A, x: jnp.tanh(jnp.dot(A, x)) >>> A = jax.ShapeDtypeStruct((2000, 3000), jnp.float32) >>> x = jax.ShapeDtypeStruct((3000, 1000), jnp.float32) >>> out = jax.eval_shape(f, A, x) # no FLOPs performed >>> print(out.shape) (2000, 1000) >>> print(out.dtype) float32
通过
eval_shape()传递的所有参数都将被视为动态的;静态参数可以通过闭包包含,例如使用functools.partial():>>> import jax >>> from jax import lax >>> from functools import partial >>> import jax.numpy as jnp >>> >>> x = jax.ShapeDtypeStruct((1, 1, 28, 28), jnp.float32) >>> kernel = jax.ShapeDtypeStruct((32, 1, 3, 3), jnp.float32) >>> >>> conv_same = partial(lax.conv_general_dilated, window_strides=(1, 1), padding="SAME") >>> out = jax.eval_shape(conv_same, x, kernel) >>> print(out.shape) (1, 32, 28, 28) >>> print(out.dtype) float32