torch.cuda.jiterator 的源代码
import re
from typing import Callable, List
import torch
from torch import Tensor
__all__: List[str] = []
class _CodeParser:
def __init__(self, code_string: str):
optional_ws = r"\s*"
required_ws = r"\s+"
template_params = r"(?P\<.+\>)"
return_type = r"(?P\w+)"
function_name = r"(?P\w+)"
function_params = r"(?P\(.+\))"
function_body = r"(?P\{.+\})"
pattern = (
optional_ws
+ "template"
+ optional_ws
+ template_params
+ optional_ws
+ return_type
+ required_ws
+ function_name
+ optional_ws
+ function_params
+ optional_ws
+ function_body
+ optional_ws
)
result = re.match(
pattern, code_string, re.DOTALL
) # DOTALL 用于匹配多行
if result is None:
raise Exception(
f"无法解析代码,请检查正确性:\n {code_string}"
)
self.template_params = result["template_params"]
self.return_type = result["return_type"]
self.function_name = result["function_name"]
self.function_params = result["function_params"]
self.function_body = result["function_body"]
class _JittedFunction:
def __init__(
self, code_string: str, return_by_ref: bool, num_outputs: int, **kwargs
):
self.code_string = code_string
assert (
return_by_ref or num_outputs == 1
), "按值返回仅适用于单个输出。"
self.return_by_ref = return_by_ref
self.num_outputs = num_outputs
parsed_code = _CodeParser(code_string)
self.kernel_name = parsed_code.function_name
self.kwargs_dict = kwargs
self.is_cuda_available = torch.cuda.is_available()
def __call__(self, *tensors: Tensor, **kwargs):
# Jiterator 遵循 torch.cuda 的延迟初始化行为
# 推迟检查 cuda 的可用性到函数调用时
assert (
self.is_cuda_available
), "Jiterator 仅支持 CUDA 和 ROCm GPU,没有可用的 GPU。"
assert len(tensors) <= 8, "jiterator 仅支持最多 8 个张量输入。"
expanded_kwargs = self.kwargs_dict.copy()
for key, value in kwargs.items():
if key in self.kwargs_dict:
expanded_kwargs[key] = value
else:
raise KeyError(f"{key} 未在函数定义中声明")
return torch._C._cuda_jiterator_compile_and_launch_kernel(
self.code_string,
self.kernel_name,
self.return_by_ref,
self.num_outputs,
tensors,
expanded_kwargs,
)
[docs]def _create_jit_fn(code_string: str, **kwargs) -> Callable:
"""
为元素操作创建一个由 jiterator 生成的 cuda 内核。
代码字符串必须是描述单个元素计算的有效 CUDA 函数。代码
字符串必须遵循 c++ 模板模式,如下例所示。此函数将被内联
到元素操作内核模板中,并动态编译。编译的内核将被缓存到内存中,以及
本地临时目录中。
由 jiterator 生成的内核接受非连续张量,并支持广播和类型提升。
参数:
code_string (str): 由 jiterator 编译的 CUDA 代码字符串。入口函数必须按值返回。
kwargs (Dict, 可选): 生成函数的键值参数
示例::
code_string = "template T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }"
jitted_fn = create_jit_fn(code_string, alpha=1.0)
<span