jax.numpy.union1d#
- jax.numpy.union1d(ar1, ar2, *, size=None, fill_value=None)[源代码][源代码]#
计算两个一维数组的并集。
JAX 实现的
numpy.union1d()。由于
union1d的输出大小取决于数据,该函数通常与jit()和其他 JAX 变换不兼容。JAX 版本添加了可选的size参数,必须在jnp.union1d用于此类上下文中时静态指定该参数。- 参数:
ar1 (ArrayLike) – 要联合的第一个元素数组。
ar2 (ArrayLike) – 第二个要合并的元素数组
size (int | None) – 如果指定,则仅返回前
size个排序后的元素。如果元素数量少于size所指示的数量,返回值将用fill_value填充。fill_value (ArrayLike | None) – 当指定
size并且元素数量少于指定数量时,用fill_value填充剩余的条目。默认为最小值。
- 返回:
包含输入数组中所有元素的并集的数组。
- 返回类型:
参见
jax.numpy.intersect1d(): 两个一维数组的集合交集。jax.numpy.setxor1d(): 两个一维数组的集合异或。jax.numpy.setdiff1d(): 两个一维数组的集合差。
示例
计算两个数组的并集:
>>> ar1 = jnp.array([1, 2, 3, 4]) >>> ar2 = jnp.array([3, 4, 5, 6]) >>> jnp.union1d(ar1, ar2) Array([1, 2, 3, 4, 5, 6], dtype=int32)
因为输出形状是动态的,这将在
jit()和其他变换下失败:>>> jax.jit(jnp.union1d)(ar1, ar2) Traceback (most recent call last): ... ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4]. The error occurred while tracing the function union1d at /Users/vanderplas/github/google/jax/jax/_src/numpy/setops.py:101 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1.
为了确保静态已知的输出形状,你可以传递一个静态的
size参数:>>> jit_union1d = jax.jit(jnp.union1d, static_argnames=['size']) >>> jit_union1d(ar1, ar2, size=6) Array([1, 2, 3, 4, 5, 6], dtype=int32)
如果
size太小,联合体将被截断:>>> jit_union1d(ar1, ar2, size=4) Array([1, 2, 3, 4], dtype=int32)
如果
size过大,则输出会用fill_value填充:>>> jit_union1d(ar1, ar2, size=8, fill_value=0) Array([1, 2, 3, 4, 5, 6, 0, 0], dtype=int32)