jax.Array#
- class jax.Array#
JAX 的数组基类
jax.Array是用于 JAX 数组和追踪器实例检查和类型注解的公共接口。它的主要应用在于实例检查和类型注解;例如:x = jnp.arange(5) isinstance(x, jax.Array) # returns True both inside and outside traced functions. def f(x: Array) -> Array: # type annotations are valid for traced and non-traced types. return x
jax.Array不应直接用于数组的创建;相反,您应该使用jax.numpy中提供的数组创建例程,例如jax.numpy.array()、jax.numpy.zeros()、jax.numpy.ones()、jax.numpy.full()、jax.numpy.arange()等。- __init__()#
方法
__init__()addressable_data(index)返回特定索引处可寻址数据的数组。
all([axis, out, keepdims, where])测试沿给定轴的所有数组元素是否评估为 True。
any([axis, out, keepdims, where])测试沿给定轴的任何数组元素是否评估为 True。
argmax([axis, out, keepdims])返回最大值的索引。
argmin([axis, out, keepdims])返回最小值的索引。
argpartition(kth[, axis])返回部分排序数组的索引。
argsort([axis, kind, order, stable, descending])返回对数组进行排序的索引。
astype(dtype[, copy, device])复制数组并转换为指定的数据类型。
choose(choices[, out, mode])从多个数组的元素中构建一个数组。
clip([min, max])返回一个数组,其值被限制在指定范围内。
compress(condition[, axis, out, size, ...])返回沿给定轴的此数组的选择切片。
conj()返回数组的复共轭。
返回数组的复共轭。
copy()返回数组的副本。
异步地将
Array复制到主机。cumprod([axis, dtype, out])返回数组的累积乘积。
cumsum([axis, dtype, out])返回数组的累计和。
diagonal([offset, axis1, axis2])返回数组中指定的对角线。
dot(b, *[, precision, preferred_element_type])计算两个数组的点积。
flatten([order])将数组展平为1维形状。
item(*args)将数组的一个元素复制到一个标准的 Python 标量并返回它。
max([axis, out, keepdims, initial, where])返回沿给定轴的数组元素的最大值。
mean([axis, dtype, out, keepdims, where])返回沿给定轴的数组元素的平均值。
min([axis, out, keepdims, initial, where])返回沿给定轴的数组元素的最小值。
nonzero(*[, fill_value, size])返回数组中非零元素的索引。
prod([axis, dtype, out, keepdims, initial, ...])返回数组元素在给定轴上的乘积。
ptp([axis, out, keepdims])返回给定轴上的峰峰值范围。
ravel([order])将数组展平为1维形状。
repeat(repeats[, axis, total_repeat_length])从重复元素构建数组。
reshape(*args[, order])返回一个包含相同数据但形状不同的新数组。
round([decimals, out])将数组元素四舍五入到给定的小数位。
searchsorted(v[, side, sorter, method])在一个已排序的数组中执行二分查找。
sort([axis, kind, order, stable, descending])返回数组的排序副本。
squeeze([axis])从数组中移除一个或多个长度为1的轴。
std([axis, dtype, out, ddof, keepdims, ...])计算沿给定轴的标准差。
sum([axis, dtype, out, keepdims, initial, ...])数组元素在给定轴上的总和。
swapaxes(axis1, axis2)交换数组的两个轴。
take(indices[, axis, out, mode, ...])从数组中提取元素。
to_device(device, *[, stream])返回指定设备上的数组副本
trace([offset, axis1, axis2, dtype, out])返回对角线上的元素之和。
transpose(*args)返回数组的一个转置副本。
var([axis, dtype, out, ddof, keepdims, ...])计算沿指定轴的方差。
view([dtype, type])返回数组的按位复制,视作新的数据类型。
属性
计算全轴数组转置。
可寻址分片列表。
用于索引更新功能的辅助属性。
与数组API兼容的设备属性。
数组的 数据类型 (
numpy.dtype)。使用
flatten()代替。全局分片列表。
返回数组的虚部。
这个数组是完全可寻址的吗?
这个数组是完全复制的吗?
一个数组元素的字节长度。
计算(批量)矩阵转置。
数组元素消耗的总字节数。
数组的维度数量。
返回数组的实部。
数组的形状。
数组的切片。
数组中元素的总数。