jax.scipy.stats.mode#
- jax.scipy.stats.mode(a, axis=0, nan_policy='propagate', keepdims=False)[源代码][源代码]#
计算数组沿轴的模式(最常见的值)。
JAX 实现的
scipy.stats.mode()。- 参数:
- 返回:
一个数组的元组,
(mode, count)。mode是众数值的数组,count是每个值在输入数组中出现的次数。- 返回类型:
ModeResult
示例
>>> x = jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3]) >>> mode, count = jax.scipy.stats.mode(x) >>> mode, count (Array(4, dtype=int32), Array(3, dtype=int32))
对于多维数组,
jax.scipy.stats.mode计算沿axis=0的mode及其对应的count:>>> x1 = jnp.array([[1, 2, 1, 3, 2, 1], ... [3, 1, 3, 2, 1, 3], ... [1, 2, 2, 3, 1, 2]]) >>> mode, count = jax.scipy.stats.mode(x1) >>> mode, count (Array([1, 2, 1, 3, 1, 1], dtype=int32), Array([2, 2, 1, 2, 2, 1], dtype=int32))
如果
axis=1,mode和count将沿着axis 1计算。>>> mode, count = jax.scipy.stats.mode(x1, axis=1) >>> mode, count (Array([1, 3, 2], dtype=int32), Array([3, 3, 3], dtype=int32))
默认情况下,
jax.scipy.stats.mode会减少结果的维度。要使结果的维度与输入数组相同,必须将参数keepdims设置为True。>>> mode, count = jax.scipy.stats.mode(x1, axis=1, keepdims=True) >>> mode, count (Array([[1], [3], [2]], dtype=int32), Array([[3], [3], [3]], dtype=int32))