jax.numpy.gradient

目录

jax.numpy.gradient#

jax.numpy.gradient(f, *varargs, axis=None, edge_order=None)[源代码][源代码]#

返回一个 N 维数组的梯度。

LAX-backend 实现的 numpy.gradient()

原始文档字符串如下。

梯度是使用在内部点上的二阶精确中心差分计算的,并且在边界处使用一阶或二阶精确的一侧(前向或后向)差分。因此,返回的梯度与输入数组具有相同的形状。

参数:
  • f (array_like) – 一个包含标量函数样本的N维数组。

  • varargs (list of scalar or array, optional) – f 值之间的间距。默认情况下,所有维度的间距为单位间距。间距可以通过以下方式指定:1. 单个标量来指定所有维度的采样距离。2. N 个标量来指定每个维度的恒定采样距离。即 dx, dy, dz, … 3. N 个数组来指定 F 的每个维度上值的坐标。数组的长度必须与相应维度的大小匹配。4. 任何 N 个标量/数组的组合,具有 2 和 3 的含义。如果给出了 axis,则 varargs 的数量必须等于轴的数量。默认值:1。(见下面的示例)。

  • axis (None or int or tuple of ints, optional) – 梯度仅沿给定的轴或轴计算。默认(axis = None)是计算输入数组所有轴的梯度。轴可以是负数,在这种情况下,它从最后一个轴计数到第一个轴。

  • edge_order (int | None)

返回:

梯度 – 一个由 ndarray 组成的元组(如果只有一个维度,则为单个 ndarray),对应于 f 对每个维度的导数。每个导数与 f 具有相同的形状。

返回类型:

ndarray or tuple of ndarray

引用