Shortcuts

torch.cuda.jiterator._create_jit_fn

torch.cuda.jiterator._create_jit_fn(code_string, **kwargs)[源代码]

创建一个由jiterator生成的cuda内核用于元素操作。

代码字符串必须是一个有效的CUDA函数,用于描述单个元素的计算。代码字符串必须遵循c++模板模式,如下例所示。此函数将被内联到逐元素内核模板中,并动态编译。编译后的内核将缓存在内存和本地临时目录中。

Jiterator 生成的内核接受非连续张量,并支持广播和类型提升。

Parameters
  • code_string (str) – 要由jiterator编译的CUDA代码字符串。入口函数必须按值返回。

  • kwargs (字典, 可选) – 生成函数的键值参数

Return type

可调用

示例:

code_string = "template  T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }"
jitted_fn = create_jit_fn(code_string, alpha=1.0)
a = torch.rand(3, device='cuda')
b = torch.rand(3, device='cuda')
# 像调用普通python函数一样调用jitted函数
result = jitted_fn(a, b, alpha=3.14)

code_string 也允许定义多个函数,最后一个函数将被视为入口函数。

示例:

code_string = "template  T util_fn(T x, T y) { return ::sin(x) + ::cos(y); }"
code_string += "template  T my_kernel(T x, T y, T val) { return ::min(val, util_fn(x, y)); }"
jitted_fn = create_jit_fn(code_string, val=0.0)
a = torch.rand(3, device='cuda')
b = torch.rand(3, device='cuda')
# 像调用普通python函数一样调用jitted函数
result = jitted_fn(a, b)  # 使用默认值val=0.0

Jiterator 可以与 Python 注册一起使用,以覆盖操作符的 CUDA 内核。 以下示例是使用 relu 覆盖 gelu 的 CUDA 内核。

示例:

code_string = "template  T my_gelu(T a) { return a > 0 ? a : 0; }"
my_gelu = create_jit_fn(code_string)
my_lib = torch.library.Library("aten", "IMPL")
my_lib.impl('aten::gelu', my_gelu, "CUDA")
# torch.nn.GELU 和 torch.nn.function.gelu 现在已被覆盖
a = torch.rand(3, device='cuda')
torch.allclose(torch.nn.functional.gelu(a), torch.nn.functional.relu(a))

警告

此API目前处于测试阶段,可能会在未来的版本中进行更改。

警告

此API仅支持最多8个输入和1个输出

警告

所有输入张量必须位于CUDA设备上