Shortcuts

torch.cuda.jiterator 的源代码

import re
from typing import Callable, List

import torch
from torch import Tensor

__all__: List[str] = []


class _CodeParser:
    def __init__(self, code_string: str):
        optional_ws = r"\s*"
        required_ws = r"\s+"
        template_params = r"(?P\<.+\>)"
        return_type = r"(?P\w+)"
        function_name = r"(?P\w+)"
        function_params = r"(?P\(.+\))"
        function_body = r"(?P\{.+\})"

        pattern = (
            optional_ws
            + "template"
            + optional_ws
            + template_params
            + optional_ws
            + return_type
            + required_ws
            + function_name
            + optional_ws
            + function_params
            + optional_ws
            + function_body
            + optional_ws
        )

        result = re.match(
            pattern, code_string, re.DOTALL
        )  # DOTALL 用于匹配多行

        if result is None:
            raise Exception(
                f"无法解析代码,请检查正确性:\n {code_string}"
            )

        self.template_params = result["template_params"]
        self.return_type = result["return_type"]
        self.function_name = result["function_name"]
        self.function_params = result["function_params"]
        self.function_body = result["function_body"]


class _JittedFunction:
    def __init__(
        self, code_string: str, return_by_ref: bool, num_outputs: int, **kwargs
    ):
        self.code_string = code_string

        assert (
            return_by_ref or num_outputs == 1
        ), "按值返回仅适用于单个输出。"
        self.return_by_ref = return_by_ref
        self.num_outputs = num_outputs

        parsed_code = _CodeParser(code_string)
        self.kernel_name = parsed_code.function_name

        self.kwargs_dict = kwargs
        self.is_cuda_available = torch.cuda.is_available()

    def __call__(self, *tensors: Tensor, **kwargs):
        # Jiterator 遵循 torch.cuda 的延迟初始化行为
        # 推迟检查 cuda 的可用性到函数调用时
        assert (
            self.is_cuda_available
        ), "Jiterator 仅支持 CUDA 和 ROCm GPU,没有可用的 GPU。"

        assert len(tensors) <= 8, "jiterator 仅支持最多 8 个张量输入。"

        expanded_kwargs = self.kwargs_dict.copy()
        for key, value in kwargs.items():
            if key in self.kwargs_dict:
                expanded_kwargs[key] = value
            else:
                raise KeyError(f"{key} 未在函数定义中声明")

        return torch._C._cuda_jiterator_compile_and_launch_kernel(
            self.code_string,
            self.kernel_name,
            self.return_by_ref,
            self.num_outputs,
            tensors,
            expanded_kwargs,
        )


[docs]def _create_jit_fn(code_string: str, **kwargs) -> Callable: """ 为元素操作创建一个由 jiterator 生成的 cuda 内核。 代码字符串必须是描述单个元素计算的有效 CUDA 函数。代码 字符串必须遵循 c++ 模板模式,如下例所示。此函数将被内联 到元素操作内核模板中,并动态编译。编译的内核将被缓存到内存中,以及 本地临时目录中。 由 jiterator 生成的内核接受非连续张量,并支持广播和类型提升。 参数: code_string (str): 由 jiterator 编译的 CUDA 代码字符串。入口函数必须按值返回。 kwargs (Dict, 可选): 生成函数的键值参数 示例:: code_string = "template T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }" jitted_fn = create_jit_fn(code_string, alpha=1.0) <span