jax.numpy.poly#
- jax.numpy.poly(seq_of_zeros)[源代码][源代码]#
返回给定根序列的多项式的系数。
JAX 实现的
numpy.poly()
。- 参数:
seq_of_zeros (ArrayLike) – 多项式的根的标量或数组,形状为
(M,)
或(M, M)
。- 返回:
包含多项式系数的数组。输出的 dtype 总是提升为不精确类型。
- 返回类型:
备注
jax.numpy.poly()
与numpy.poly()
不同:当输入是标量时,
np.poly
会引发TypeError
,而jnp.poly
将标量视为长度为1的数组。对于复数或方形输入,
jnp.poly
总是返回复数系数,而np.poly
可能根据其值返回实数或复数。
参见
jax.numpy.polyfit()
: 最小二乘多项式拟合。jax.numpy.polyval()
: 在特定值处计算多项式。jax.numpy.roots()
: 计算给定系数的多项式的根。
示例
标量输入:
>>> jnp.poly(1) Array([ 1., -1.], dtype=float32)
包含整数值的输入数组:
>>> x = jnp.array([1, 2, 3]) >>> jnp.poly(x) Array([ 1., -6., 11., -6.], dtype=float32)
输入带有复共轭的数组:
>>> x = jnp.array([2, 1+2j, 1-2j]) >>> jnp.poly(x) Array([ 1.+0.j, -4.+0.j, 9.+0.j, -10.+0.j], dtype=complex64)
输入数组作为具有实值输入的方阵:
>>> x = jnp.array([[2, 1, 5], ... [3, 4, 7], ... [1, 3, 5]]) >>> jnp.round(jnp.poly(x)) Array([ 1.+0.j, -11.-0.j, 9.+0.j, -15.+0.j], dtype=complex64)