跳至内容

运行时¤

tinygrad 支持多种运行时环境,使您的代码能够在各种设备上扩展运行。默认运行时可以根据可用硬件自动选择,或者您也可以通过环境变量强制指定默认运行时(例如 CPU=1)。

运行时 描述 要求
NV 为NVIDIA GPU提供加速支持 安培/Ada系列GPU
AMD 为AMD GPU提供加速支持 RDNA2/RDNA3/RDNA4系列GPU
QCOM 为QCOM GPU提供加速支持 6xx系列GPU
METAL Utilizes Metal for acceleration on Apple devices M1+ Macs; Metal 3.0+ for bfloat support
CUDA 利用CUDA在NVIDIA GPU上进行加速 支持CUDA的NVIDIA GPU
GPU (OpenCL) 使用GPU上的OpenCL加速计算 兼容OpenCL 2.0的设备
CPU (C Code) Runs on CPU using the clang compiler clang compiler in system PATH
LLVM (LLVM IR) 使用LLVM编译器基础设施在CPU上运行 已安装并可找到llvm库
WEBGPU 使用Dawn WebGPU引擎在GPU上运行(用于Google Chrome浏览器) 需要安装并能够找到Dawn库。可在此here下载二进制文件。

互操作性¤

tinygrad 提供与 OpenCL 和 PyTorch 的互操作性,通过 Tensor.from_blob API 实现框架间高效的张量数据共享。该功能通过直接操作外部内存指针实现零拷贝操作。

重要提示: 当使用外部内存指针与tinygrad张量时,必须确保这些指针在tinygrad张量的整个生命周期内保持有效,以防止内存损坏。

CUDA/METAL PyTorch 互操作性¤

您可以在PyTorch和tinygrad之间无缝使用CUDA/MPS张量而无需数据拷贝:

from tinygrad.dtype import _from_torch_dtype
tensor1 = torch.tensor([1.0, 2.0, 3.0], device=torch.device("cuda"))
tiny_tensor1 = Tensor.from_blob(tensor1.data_ptr(), tensor1.shape, dtype=_from_torch_dtype(tensor1.dtype), device='CUDA')

# 在tinygrad计算之前,需要同步mps以确保数据有效
if data.device.type == "mps": torch.mps.synchronize()
else: torch.cuda.synchronize()

x = (tiny_tensor1 + 1).realize()

QCOM OpenCL 互操作性¤

tinygrad 在 QCOM 后端支持 OpenCL 互操作性。

缓冲区互操作允许直接访问OpenCL内存缓冲区:

# 创建原始opencl缓冲区
cl_buf = cl.clCreateBuffer(cl_context, cl.CL_MEM_READ_WRITE, 0x100, None, status := ctypes.c_int32())

# 提取指针
cl_buf_desc_ptr = to_mv(ctypes.addressof(cl_buf, 8).cast('Q')[0]
rawbuf_ptr = to_mv(cl_buf_desc_ptr, 0x100).cast('Q')[20] # 偏移量0xA0处是原始GPU指针

# 创建tiny张量
tiny = Tensor.from_blob(rawbuf_ptr, (8, 8), dtype=dtypes.int, device='QCOM')

对于图像也是同样的操作:

# 创建cl图像
cl_img = cl.clCreateImage2D(cl_context, cl.CL_MEM_READ_WRITE, cl.cl_image_format(cl.CL_RGBA, cl.CL_FLOAT), w, h, 0, None, status := ctypes.c_int32())

# 提取指针
cl_buf_desc_ptr = to_mv(ctypes.addressof(cl_img), 8).cast('Q')[0]
rawbuf_ptr = to_mv(cl_buf_desc_ptr, 0x100).cast('Q')[20] # 偏移量0xA0处是原始GPU指针

# 创建tiny张量
tiny = Tensor.from_blob(rawbuf_ptr, (h*w*4,), dtype=dtypes.imagef((h,w)), device='QCOM')