jax.ShapeDtypeStruct#
- class jax.ShapeDtypeStruct(shape, dtype, named_shape=None, sharding=None, weak_type=False)[源代码][源代码]#
用于存储数组的形状、数据类型和其他静态属性的容器。
ShapeDtypeStruct通常与jax.eval_shape()一起使用。- 参数:
shape – 一个表示数组形状的整数序列
dtype – 一个类似数据类型的对象
sharding – (可选) 一个
jax.Sharding对象
方法
__init__(shape, dtype[, named_shape, ...])属性
shapedtypeshardingweak_typelayoutnamed_shapendimsize