jax.numpy.gradient#
- jax.numpy.gradient(f, *varargs, axis=None, edge_order=None)[源代码][源代码]#
返回一个 N 维数组的梯度。
LAX-backend 实现的
numpy.gradient()
。原始文档字符串如下。
梯度是使用在内部点上的二阶精确中心差分计算的,并且在边界处使用一阶或二阶精确的一侧(前向或后向)差分。因此,返回的梯度与输入数组具有相同的形状。
- 参数:
f (array_like) – 一个包含标量函数样本的N维数组。
varargs (list of scalar or array, optional) – f 值之间的间距。默认情况下,所有维度的间距为单位间距。间距可以通过以下方式指定:1. 单个标量来指定所有维度的采样距离。2. N 个标量来指定每个维度的恒定采样距离。即 dx, dy, dz, … 3. N 个数组来指定 F 的每个维度上值的坐标。数组的长度必须与相应维度的大小匹配。4. 任何 N 个标量/数组的组合,具有 2 和 3 的含义。如果给出了 axis,则 varargs 的数量必须等于轴的数量。默认值:1。(见下面的示例)。
axis (None or int or tuple of ints, optional) – 梯度仅沿给定的轴或轴计算。默认(axis = None)是计算输入数组所有轴的梯度。轴可以是负数,在这种情况下,它从最后一个轴计数到第一个轴。
edge_order (int | None)
- 返回:
梯度 – 一个由 ndarray 组成的元组(如果只有一个维度,则为单个 ndarray),对应于 f 对每个维度的导数。每个导数与 f 具有相同的形状。
- 返回类型:
ndarray or tuple of ndarray
引用