cupy.get_array_module#
- cupy.get_array_module(*args)[源代码][源代码]#
返回参数的数组模块。
此函数用于实现 CPU/GPU 通用代码。如果至少有一个参数是
cupy.ndarray对象,则返回cupy模块。- 参数:
args – 用于确定是使用 NumPy 还是 CuPy 的值。
- 返回:
cupy或numpy根据参数的类型返回。- 返回类型:
module
示例
一个 NumPy/CuPy 的通用函数可以写成如下形式
>>> def softplus(x): ... xp = cupy.get_array_module(x) ... return xp.maximum(0, x) + xp.log1p(xp.exp(-abs(x)))