jax.numpy.full#
- jax.numpy.full(shape, fill_value, dtype=None, *, device=None)[源代码][源代码]#
创建一个充满指定值的数组。
JAX 实现的
numpy.full()。- 参数:
- 返回:
指定形状和数据类型的数组,如果指定,则在指定设备上。
- 返回类型:
示例
>>> jnp.full(4, 2, dtype=float) Array([2., 2., 2., 2.], dtype=float32) >>> jnp.full((2, 3), 0, dtype=bool) Array([[False, False, False], [False, False, False]], dtype=bool)
fill_value 也可以是一个数组,该数组会被广播到指定的形状:
>>> jnp.full((2, 3), fill_value=jnp.arange(3)) Array([[0, 1, 2], [0, 1, 2]], dtype=int32)