IRModule
This tutorial presents the core abstraction of Apache TVM Unity: the IRModule. An IRModule encompasses the entirety of an ML model, integrating the computational graph, tensor programs, and potential calls to external libraries.
import numpy as np
import tvm
from tvm import relax
Create an IRModule
An IRModule can be initialized in several ways. We demonstrate a few of them below.
import torch
from torch import nn
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program
Import from existing models
The most common way to initialize an IRModule is to import from an existing model. Apache TVM Unity supports importing from a variety of frameworks, such as PyTorch and ONNX. This tutorial only demonstrates importing from PyTorch.
# Create a dummy model
class TorchModel(nn.Module):
    def __init__(self):
        super(TorchModel, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x
# Give an example argument to torch.export
example_args = (torch.randn(1, 784, dtype=torch.float32),)
# Convert the model to IRModule
with torch.no_grad():
    exported_program = export(TorchModel().eval(), example_args)
    mod_from_torch = from_exported_program(
        exported_program, keep_params_as_input=True, unwrap_unit_return_tuple=True
    )
mod_from_torch, params_from_torch = relax.frontend.detach_params(mod_from_torch)
# Print the IRModule
mod_from_torch.show()
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((1, 784), dtype="float32"), p_fc1_weight: R.Tensor((256, 784), dtype="float32"), p_fc1_bias: R.Tensor((256,), dtype="float32"), p_fc2_weight: R.Tensor((10, 256), dtype="float32"), p_fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            lv: R.Tensor((784, 256), dtype="float32") = R.permute_dims(p_fc1_weight, axes=None)
            lv1: R.Tensor((1, 256), dtype="float32") = R.matmul(x, lv, out_dtype="float32")
            lv2: R.Tensor((1, 256), dtype="float32") = R.add(lv1, p_fc1_bias)
            lv3: R.Tensor((1, 256), dtype="float32") = R.nn.relu(lv2)
            lv4: R.Tensor((256, 10), dtype="float32") = R.permute_dims(p_fc2_weight, axes=None)
            lv5: R.Tensor((1, 10), dtype="float32") = R.matmul(lv3, lv4, out_dtype="float32")
            lv6: R.Tensor((1, 10), dtype="float32") = R.add(lv5, p_fc2_bias)
            gv: R.Tensor((1, 10), dtype="float32") = lv6
            R.output(gv)
        return gv
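Although this tutorial only walks through the PyTorch path, importing from ONNX follows the same overall pattern. The snippet below is a minimal sketch that is not executed in this tutorial; "model.onnx" is a placeholder path, and it assumes the onnx Python package is installed.
# Sketch: import an ONNX model into an IRModule (not run in this tutorial).
import onnx
from tvm.relax.frontend.onnx import from_onnx

onnx_model = onnx.load("model.onnx")  # "model.onnx" is a placeholder path
mod_from_onnx = from_onnx(onnx_model)  # returns an IRModule
mod_from_onnx.show()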
Write with the Relax NN Module
Apache TVM Unity also provides a set of PyTorch-like APIs to help users write an IRModule directly.
from tvm.relax.frontend import nn
class RelaxModel(nn.Module):
    def __init__(self):
        super(RelaxModel, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x


mod_from_relax, params_from_relax = RelaxModel().export_tvm(
    {"forward": {"x": nn.spec.Tensor((1, 784), "float32")}}
)
mod_from_relax.show()
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims: R.Tensor((784, 256), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            matmul: R.Tensor((1, 256), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((1, 256), dtype="float32") = R.add(matmul, fc1_bias)
            relu: R.Tensor((1, 256), dtype="float32") = R.nn.relu(add)
            permute_dims1: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((1, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
            add1: R.Tensor((1, 10), dtype="float32") = R.add(matmul1, fc2_bias)
            gv: R.Tensor((1, 10), dtype="float32") = add1
            R.output(gv)
        return gv
Create via TVMScript
TVMScript is a Python-based domain-specific language (DSL) for working with IRModules. We can print an IRModule directly in TVMScript syntax, or parse TVMScript to obtain an IRModule.
from tvm.script import ir as I
from tvm.script import relax as R
@I.ir_module
class TVMScriptModule:
    @R.function
    def main(
        x: R.Tensor((1, 784), dtype="float32"),
        fc1_weight: R.Tensor((256, 784), dtype="float32"),
        fc1_bias: R.Tensor((256,), dtype="float32"),
        fc2_weight: R.Tensor((10, 256), dtype="float32"),
        fc2_bias: R.Tensor((10,), dtype="float32"),
    ) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims = R.permute_dims(fc1_weight, axes=None)
            matmul = R.matmul(x, permute_dims, out_dtype="void")
            add = R.add(matmul, fc1_bias)
            relu = R.nn.relu(add)
            permute_dims1 = R.permute_dims(fc2_weight, axes=None)
            matmul1 = R.matmul(relu, permute_dims1, out_dtype="void")
            add1 = R.add(matmul1, fc2_bias)
            gv = add1
            R.output(gv)
        return gv
mod_from_script = TVMScriptModule
mod_from_script.show()
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims: R.Tensor((784, 256), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            matmul: R.Tensor((1, 256), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((1, 256), dtype="float32") = R.add(matmul, fc1_bias)
            relu: R.Tensor((1, 256), dtype="float32") = R.nn.relu(add)
            permute_dims1: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((1, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
            add1: R.Tensor((1, 10), dtype="float32") = R.add(matmul1, fc2_bias)
            gv: R.Tensor((1, 10), dtype="float32") = add1
            R.output(gv)
        return gv
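Because TVMScript is plain Python text, an IRModule can also be round-tripped through its source. The sketch below is an illustration only: .script() returns the TVMScript source as a string (this is the same printing facility used by show()), and tvm.script.from_source as the parsing entry point is an assumption about the parser API.
# Sketch: round-trip an IRModule through its TVMScript text.
script_text = mod_from_script.script()              # TVMScript source as a str
reparsed_mod = tvm.script.from_source(script_text)  # assumed parser entry point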
Attributes of an IRModule
An IRModule is a collection of functions, indexed by GlobalVars.
mod = mod_from_torch
print(mod.get_global_vars())
[I.GlobalVar("main")]
We can access a function in the IRModule by indexing with its GlobalVar or its name.
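For example, both forms of indexing return the same function object (a brief sketch; mod is the module imported from PyTorch above):
# Look up the function by its name ...
print(mod["main"])
# ... or by its GlobalVar; both refer to the same underlying function.
(main_gv,) = mod.get_global_vars()
assert mod[main_gv].same_as(mod["main"])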
# from tvm.script import relax as R
@R.function
def main(x: R.Tensor((1, 784), dtype="float32"), p_fc1_weight: R.Tensor((256, 784), dtype="float32"), p_fc1_bias: R.Tensor((256,), dtype="float32"), p_fc2_weight: R.Tensor((10, 256), dtype="float32"), p_fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
    R.func_attr({"num_input": 1})
    with R.dataflow():
        lv: R.Tensor((784, 256), dtype="float32") = R.permute_dims(p_fc1_weight, axes=None)
        lv1: R.Tensor((1, 256), dtype="float32") = R.matmul(x, lv, out_dtype="float32")
        lv2: R.Tensor((1, 256), dtype="float32") = R.add(lv1, p_fc1_bias)
        lv3: R.Tensor((1, 256), dtype="float32") = R.nn.relu(lv2)
        lv4: R.Tensor((256, 10), dtype="float32") = R.permute_dims(p_fc2_weight, axes=None)
        lv5: R.Tensor((1, 10), dtype="float32") = R.matmul(lv3, lv4, out_dtype="float32")
        lv6: R.Tensor((1, 10), dtype="float32") = R.add(lv5, p_fc2_bias)
        gv: R.Tensor((1, 10), dtype="float32") = lv6
        R.output(gv)
    return gv
Transformations on an IRModule
Transformations are a core component of Apache TVM Unity. A transformation takes in an IRModule and outputs another IRModule. We can apply a sequence of transformations to an IRModule to obtain a new IRModule. This is the common way to optimize a model.
In this getting-started tutorial, we only demonstrate how to apply transformations to an IRModule. For details on each transformation, please refer to the Transformation API Reference.
We first apply the LegalizeOps transformation to the IRModule. This transformation converts the Relax module into a mixed stage, with both Relax and TensorIR functions in the same module. Meanwhile, the Relax operators are converted into call_tir.
mod = mod_from_torch
mod = relax.transform.LegalizeOps()(mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(lv1: T.Buffer((T.int64(1), T.int64(256)), "float32"), p_fc1_bias: T.Buffer((T.int64(256),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(256)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(1), T.int64(256)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(lv1[v_ax0, v_ax1], p_fc1_bias[v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = lv1[v_ax0, v_ax1] + p_fc1_bias[v_ax1]

    @T.prim_func(private=True)
    def add1(lv5: T.Buffer((T.int64(1), T.int64(10)), "float32"), p_fc2_bias: T.Buffer((T.int64(10),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(1), T.int64(10)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(lv5[v_ax0, v_ax1], p_fc2_bias[v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = lv5[v_ax0, v_ax1] + p_fc2_bias[v_ax1]

    @T.prim_func(private=True)
    def matmul(x: T.Buffer((T.int64(1), T.int64(784)), "float32"), lv: T.Buffer((T.int64(784), T.int64(256)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(256)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1, k in T.grid(T.int64(1), T.int64(256), T.int64(784)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(x[v_i0, v_k], lv[v_k, v_i1])
                T.writes(matmul[v_i0, v_i1])
                with T.init():
                    matmul[v_i0, v_i1] = T.float32(0.0)
                matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + x[v_i0, v_k] * lv[v_k, v_i1]

    @T.prim_func(private=True)
    def matmul1(lv3: T.Buffer((T.int64(1), T.int64(256)), "float32"), lv4: T.Buffer((T.int64(256), T.int64(10)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1, k in T.grid(T.int64(1), T.int64(10), T.int64(256)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(lv3[v_i0, v_k], lv4[v_k, v_i1])
                T.writes(matmul[v_i0, v_i1])
                with T.init():
                    matmul[v_i0, v_i1] = T.float32(0.0)
                matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + lv3[v_i0, v_k] * lv4[v_k, v_i1]

    @T.prim_func(private=True)
    def relu(lv2: T.Buffer((T.int64(1), T.int64(256)), "float32"), compute: T.Buffer((T.int64(1), T.int64(256)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1 in T.grid(T.int64(1), T.int64(256)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(lv2[v_i0, v_i1])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.max(lv2[v_i0, v_i1], T.float32(0.0))

    @T.prim_func(private=True)
    def transpose(p_fc1_weight: T.Buffer((T.int64(256), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(256)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(784), T.int64(256)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(p_fc1_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = p_fc1_weight[v_ax1, v_ax0]

    @T.prim_func(private=True)
    def transpose1(p_fc2_weight: T.Buffer((T.int64(10), T.int64(256)), "float32"), T_transpose: T.Buffer((T.int64(256), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(256), T.int64(10)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(p_fc2_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = p_fc2_weight[v_ax1, v_ax0]

    @R.function
    def main(x: R.Tensor((1, 784), dtype="float32"), p_fc1_weight: R.Tensor((256, 784), dtype="float32"), p_fc1_bias: R.Tensor((256,), dtype="float32"), p_fc2_weight: R.Tensor((10, 256), dtype="float32"), p_fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.transpose, (p_fc1_weight,), out_sinfo=R.Tensor((784, 256), dtype="float32"))
            lv1 = R.call_tir(cls.matmul, (x, lv), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            lv2 = R.call_tir(cls.add, (lv1, p_fc1_bias), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            lv3 = R.call_tir(cls.relu, (lv2,), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            lv4 = R.call_tir(cls.transpose1, (p_fc2_weight,), out_sinfo=R.Tensor((256, 10), dtype="float32"))
            lv5 = R.call_tir(cls.matmul1, (lv3, lv4), out_sinfo=R.Tensor((1, 10), dtype="float32"))
            lv6 = R.call_tir(cls.add1, (lv5, p_fc2_bias), out_sinfo=R.Tensor((1, 10), dtype="float32"))
            gv: R.Tensor((1, 10), dtype="float32") = lv6
            R.output(gv)
        return gv
After the transformation, the module contains many more functions. Let us print the global variables again.
print(mod.get_global_vars())
[I.GlobalVar("add"), I.GlobalVar("add1"), I.GlobalVar("main"), I.GlobalVar("matmul"), I.GlobalVar("matmul1"), I.GlobalVar("relu"), I.GlobalVar("transpose"), I.GlobalVar("transpose1")]
Next, Apache TVM Unity provides a default transformation pipeline to simplify the transformation process, which we can apply to the module. The default zero pipeline contains a few fundamental transformations:
LegalizeOps: This transform converts Relax operators into call_tir functions together with the corresponding TensorIR functions. After this transform, the IRModule contains both Relax functions and TensorIR functions.
AnnotateTIROpPattern: This transform annotates the pattern of the TensorIR functions, preparing them for subsequent operator fusion.
FoldConstant: This pass performs constant folding, optimizing operations that involve constants.
FuseOps and FuseTIR: These two passes work together to fuse operators based on the patterns annotated in the previous step (AnnotateTIROpPattern). They transform both Relax functions and TensorIR functions.
Note
Here we apply LegalizeOps twice in the flow. The second application is redundant but harmless.
Every pass can be duplicated in the flow, since we ensure that each pass can handle all legal IRModule inputs. This design helps users construct their own pipelines, as sketched below.
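For illustration, a pipeline assembled by hand from the passes listed above could look like the following. This is only a sketch: the built-in zero pipeline may use additional passes or different options.
# Sketch: compose the listed passes into a custom pipeline.
# The built-in "zero" pipeline may differ in pass options and ordering details.
custom_pipeline = tvm.ir.transform.Sequential(
    [
        relax.transform.LegalizeOps(),
        relax.transform.AnnotateTIROpPattern(),
        relax.transform.FoldConstant(),
        relax.transform.FuseOps(),
        relax.transform.FuseTIR(),
    ]
)
# custom_mod = custom_pipeline(mod_from_torch)
Below, we simply apply the built-in zero pipeline: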
mod = relax.get_pipeline("zero")(mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def fused_matmul1_add1(lv3: T.Buffer((T.int64(1), T.int64(256)), "float32"), lv4: T.Buffer((T.int64(256), T.int64(10)), "float32"), p_fc2_bias: T.Buffer((T.int64(10),), "float32"), T_add_intermediate: T.Buffer((T.int64(1), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(10)))
        for i0, i1, k in T.grid(T.int64(1), T.int64(10), T.int64(256)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(lv3[v_i0, v_k], lv4[v_k, v_i1])
                T.writes(matmul_intermediate[v_i0, v_i1])
                with T.init():
                    matmul_intermediate[v_i0, v_i1] = T.float32(0.0)
                matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + lv3[v_i0, v_k] * lv4[v_k, v_i1]
        for ax0, ax1 in T.grid(T.int64(1), T.int64(10)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(matmul_intermediate[v_ax0, v_ax1], p_fc2_bias[v_ax1])
                T.writes(T_add_intermediate[v_ax0, v_ax1])
                T_add_intermediate[v_ax0, v_ax1] = matmul_intermediate[v_ax0, v_ax1] + p_fc2_bias[v_ax1]

    @T.prim_func(private=True)
    def fused_matmul_add_relu(x: T.Buffer((T.int64(1), T.int64(784)), "float32"), lv: T.Buffer((T.int64(784), T.int64(256)), "float32"), p_fc1_bias: T.Buffer((T.int64(256),), "float32"), compute_intermediate: T.Buffer((T.int64(1), T.int64(256)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(256)))
        T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(256)))
        for i0, i1, k in T.grid(T.int64(1), T.int64(256), T.int64(784)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(x[v_i0, v_k], lv[v_k, v_i1])
                T.writes(matmul_intermediate[v_i0, v_i1])
                with T.init():
                    matmul_intermediate[v_i0, v_i1] = T.float32(0.0)
                matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + x[v_i0, v_k] * lv[v_k, v_i1]
        for ax0, ax1 in T.grid(T.int64(1), T.int64(256)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(matmul_intermediate[v_ax0, v_ax1], p_fc1_bias[v_ax1])
                T.writes(T_add_intermediate[v_ax0, v_ax1])
                T_add_intermediate[v_ax0, v_ax1] = matmul_intermediate[v_ax0, v_ax1] + p_fc1_bias[v_ax1]
        for i0, i1 in T.grid(T.int64(1), T.int64(256)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(T_add_intermediate[v_i0, v_i1])
                T.writes(compute_intermediate[v_i0, v_i1])
                compute_intermediate[v_i0, v_i1] = T.max(T_add_intermediate[v_i0, v_i1], T.float32(0.0))

    @T.prim_func(private=True)
    def transpose(p_fc1_weight: T.Buffer((T.int64(256), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(256)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(784), T.int64(256)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(p_fc1_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = p_fc1_weight[v_ax1, v_ax0]

    @T.prim_func(private=True)
    def transpose1(p_fc2_weight: T.Buffer((T.int64(10), T.int64(256)), "float32"), T_transpose: T.Buffer((T.int64(256), T.int64(10)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(256), T.int64(10)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(p_fc2_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = p_fc2_weight[v_ax1, v_ax0]

    @R.function
    def main(x: R.Tensor((1, 784), dtype="float32"), p_fc1_weight: R.Tensor((256, 784), dtype="float32"), p_fc1_bias: R.Tensor((256,), dtype="float32"), p_fc2_weight: R.Tensor((10, 256), dtype="float32"), p_fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.transpose, (p_fc1_weight,), out_sinfo=R.Tensor((784, 256), dtype="float32"))
            lv_1 = R.call_tir(cls.fused_matmul_add_relu, (x, lv, p_fc1_bias), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            lv4 = R.call_tir(cls.transpose1, (p_fc2_weight,), out_sinfo=R.Tensor((256, 10), dtype="float32"))
            gv = R.call_tir(cls.fused_matmul1_add1, (lv_1, lv4, p_fc2_bias), out_sinfo=R.Tensor((1, 10), dtype="float32"))
            R.output(gv)
        return gv
Universal deployment of the IRModule
After the optimization, we can compile the model into a TVM runtime module. Notably, Apache TVM Unity provides universal deployment capability, which means we can deploy the same IRModule on different backends, including CPU, GPU, and other emerging backends.
Deploy on CPU
We can deploy the IRModule on CPU by specifying the target as llvm.
exec = tvm.compile(mod, target="llvm")
dev = tvm.cpu()
vm = relax.VirtualMachine(exec, dev)
raw_data = np.random.rand(1, 784).astype("float32")
data = tvm.nd.array(raw_data, dev)
cpu_out = vm["main"](data, *params_from_torch["main"]).numpy()
print(cpu_out)
[[ 0.05386125 0.16882402 -0.01117187 -0.0300092 -0.18779516 0.23640567
0.08706295 -0.00564075 -0.09215595 0.11760191]]
Deploy on GPU
Besides the CPU backend, we can also deploy the IRModule on GPU. A GPU requires the program to carry extra information, such as thread bindings and shared memory allocations, so further transformations are needed to generate a GPU program.
We use DLight to generate the GPU program. In this tutorial, we do not go into the details of DLight.
from tvm import dlight as dl
with tvm.target.Target("cuda"):
    gpu_mod = dl.ApplyDefaultSchedule(
        dl.gpu.Matmul(),
        dl.gpu.Fallback(),
    )(mod)
Now we can compile the IRModule for the GPU, similar to what we did for the CPU.
exec = tvm.compile(gpu_mod, target="cuda")
dev = tvm.device("cuda", 0)
vm = relax.VirtualMachine(exec, dev)
# Need to allocate data and params on GPU device
data = tvm.nd.array(raw_data, dev)
gpu_params = [tvm.nd.array(p, dev) for p in params_from_torch["main"]]
gpu_out = vm["main"](data, *gpu_params).numpy()
print(gpu_out)
# Check the correctness of the results
assert np.allclose(cpu_out, gpu_out, atol=1e-3)
[[ 0.05386121 0.1688241 -0.01117178 -0.0300092 -0.18779515 0.23640573
0.08706297 -0.0056407 -0.09215595 0.11760198]]
Deploy on other backends
Apache TVM Unity also supports other backends, such as different kinds of GPUs (Metal, ROCm, Vulkan, and OpenCL), different kinds of CPUs (x86, ARM), and other emerging backends (e.g., WebAssembly). The deployment flow is similar to that of the GPU backend.
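As a concrete sketch, retargeting the flow above to Vulkan would look roughly as follows. This assumes a Vulkan-capable device and driver are available, and that the DLight default schedules used for CUDA are also suitable for this backend.
# Sketch: the same schedule-compile-run flow retargeted to Vulkan.
with tvm.target.Target("vulkan"):
    vulkan_mod = dl.ApplyDefaultSchedule(
        dl.gpu.Matmul(),
        dl.gpu.Fallback(),
    )(mod)

exec = tvm.compile(vulkan_mod, target="vulkan")
dev = tvm.device("vulkan", 0)
vm = relax.VirtualMachine(exec, dev)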