jax.numpy.poly

目录

jax.numpy.poly#

jax.numpy.poly(seq_of_zeros)[源代码][源代码]#

返回给定根序列的多项式的系数。

JAX 实现的 numpy.poly()

参数:

seq_of_zeros (ArrayLike) – 多项式的根的标量或数组,形状为 (M,)(M, M)

返回:

包含多项式系数的数组。输出的 dtype 总是提升为不精确类型。

返回类型:

Array

备注

jax.numpy.poly()numpy.poly() 不同:

  • 当输入是标量时,np.poly 会引发 TypeError,而 jnp.poly 将标量视为长度为1的数组。

  • 对于复数或方形输入,jnp.poly 总是返回复数系数,而 np.poly 可能根据其值返回实数或复数。

参见

示例

标量输入:

>>> jnp.poly(1)
Array([ 1., -1.], dtype=float32)

包含整数值的输入数组:

>>> x = jnp.array([1, 2, 3])
>>> jnp.poly(x)
Array([ 1., -6., 11., -6.], dtype=float32)

输入带有复共轭的数组:

>>> x = jnp.array([2, 1+2j, 1-2j])
>>> jnp.poly(x)
Array([  1.+0.j,  -4.+0.j,   9.+0.j, -10.+0.j], dtype=complex64)

输入数组作为具有实值输入的方阵:

>>> x = jnp.array([[2, 1, 5],
...                [3, 4, 7],
...                [1, 3, 5]])
>>> jnp.round(jnp.poly(x))
Array([  1.+0.j, -11.-0.j,   9.+0.j, -15.+0.j], dtype=complex64)