jax.numpy.polydiv

目录

jax.numpy.polydiv#

jax.numpy.polydiv(u, v, *, trim_leading_zeros=False)[源代码][源代码]#

返回多项式除法的商和余数。

JAX 实现的 numpy.polydiv()

参数:
  • u (ArrayLike) – 股息多项式系数的数组。

  • v (ArrayLike) – 除数多项式系数的数组。

  • trim_leading_zeros (bool) – 默认值为 False。如果为 True,则在返回值中去除前导零以匹配 numpy 的结果。但这会阻止函数在编译代码中使用。由于浮点算术误差累积的差异,值被视为零的截止点可能导致 NumPy 和 JAX 之间,甚至不同 JAX 后端之间的结果不一致。结果可能导致当 trim_leading_zeros=True 时输出形状不一致。

返回:

商和余数数组的元组。输出的 dtype 总是提升为不精确类型。

返回类型:

tuple[Array, Array]

备注

jax.numpy.polydiv() 只接受数组作为输入,不像 numpy.polydiv() 那样也接受标量输入。

参见

示例

>>> x1 = jnp.array([5, 7, 9])
>>> x2 = jnp.array([4, 1])
>>> np.polydiv(x1, x2)
(array([1.25  , 1.4375]), array([7.5625]))
>>> jnp.polydiv(x1, x2)
(Array([1.25  , 1.4375], dtype=float32), Array([0.    , 0.    , 7.5625], dtype=float32))

如果 trim_leading_zeros=True ,结果与 np.polydiv 的结果匹配。

>>> jnp.polydiv(x1, x2, trim_leading_zeros=True)
(Array([1.25  , 1.4375], dtype=float32), Array([7.5625], dtype=float32))