triton.language.device_print¶
- triton.language.device_print(prefix, *args, hex=False)¶
在运行时从设备打印数值。字符串格式化不适用于运行时数值,因此您应该提供想要打印的数值作为参数。第一个值必须是字符串,所有后续值必须是标量或张量。
调用Python内置函数
print
与调用此函数相同,参数要求将匹配此函数(而非print
的正常要求)。tl.device_print("pid", pid) print("pid", pid)
在CUDA上,printf输出是通过一个有限大小的缓冲区进行流式传输的(在一台主机上,我们测得默认大小为6912 KiB,但这个值可能因GPU和CUDA版本而异)。如果您注意到某些printf输出被丢弃,可以通过调用以下方法来增加缓冲区大小
triton.runtime.driver.active.utils.set_printf_fifo_size(size_bytes)
如果在运行使用printfs的内核后尝试更改此值,CUDA可能会报错。此处设置的值可能仅影响当前设备(因此如果您有多个GPU,则需要多次调用)。
- Parameters:
prefix – 在值之前打印的前缀。这必须是一个字符串字面量。
args – 要打印的值。可以是任何张量或标量。
hex – 将所有值以十六进制而非十进制形式打印