jax.numpy.apply_over_axes

jax.numpy.apply_over_axes#

jax.numpy.apply_over_axes(func, a, axes)[源代码][源代码]#

在多个轴上重复应用函数。

LAX-backend 对 numpy.apply_over_axes() 的实现。

原始文档字符串如下。

func 被调用为 res = func(a, axis),其中 axisaxes 的第一个元素。函数调用的结果 res 必须与 a 具有相同的维度,或者少一个维度。如果 resa 少一个维度,则在 axis 之前插入一个维度。然后,对 axes 中的每个轴重复调用 func,并将 res 作为第一个参数。

参数:
  • func (function) – 此函数必须接受两个参数,func(a, axis)

  • a (array_like) – 输入数组。

  • axes (array_like) – func 应用的轴;元素必须是整数。

返回:

apply_over_axis – 输出数组。维度数量与 a 相同,但形状可能不同。这取决于 func 是否改变了其输出相对于输入的形状。

返回类型:

ndarray