DGL 外部函数接口 (FFI)
我们都喜欢Python,因为它易于操作。我们都喜欢C,因为它快速、可靠且类型化。为了兼具两者的优点,DGL主要在Python中,以便快速原型设计,同时将性能关键部分降低到C。因此,DGL开发者经常面临编写C例程并通过一种称为外部函数接口(FFI)的机制将其暴露给Python的场景。
市面上有许多FFI解决方案。在DGL中,我们希望保持其简单、直观且高效,以应对关键用例。这就是为什么当我们遇到TVM项目中的FFI解决方案时,我们立即被它吸引。它利用了函数式编程的思想,因此只暴露了十几个C API,并且可以在此基础上构建新的API。
我们决定(毫不羞耻地)借用这个想法。例如,定义一个暴露给python的C API只需要几行代码:
// file: calculator.cc (put it in dgl/src folder)
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h>
using namespace dgl::runtime;
DGL_REGISTER_GLOBAL("calculator.MyAdd")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
int a = args[0];
int b = args[1];
*rv = a + b;
});
编译并构建库。在python端,在dgl/python/dgl/下创建一个calculator.py文件。
# file: calculator.py
from ._ffi.function import _init_api
def add(a, b):
# MyAdd has been registered via `_ini_api` call below
return MyAdd(a, b)
_init_api("dgl.calculator")
关键在于FFI系统首先会屏蔽函数参数的类型信息,因此所有的C函数调用都可以通过一个C API(DGLFuncCall)进行。类型信息在函数体内通过静态转换来获取,并且我们会进行运行时类型检查以确保类型转换是正确的。只要函数调用不是太轻量级(上面的例子实际上是一个不好的例子),这种来回的开销是可以忽略的。TVM的PackedFunc文档有更多详细信息。
定义新类型
DGLArgs 和 DGLRetValue 仅支持有限数量的类型:
数值类型:int, float, double, …
字符串
函数(以PackedFunc的形式)
NDArray
尽管有限,上述类型系统非常强大,因为它支持函数作为一等公民。例如,如果你想返回多个值,你可以返回一个PackedFunc,它根据给定的整数索引返回每个值。然而,在许多情况下,仍然需要新类型来简化开发过程:
参数/返回值是集合的组合(例如,列表的字典的字典)。
有时我们只想要一个“结构”的概念(例如,给定一个苹果,通过
apple.color获取它的颜色)。
为了实现这一点,我们引入了对象类型系统。例如,定义一个新类型 Calculator:
// file: calculator.cc
#include <dgl/packed_func_ext.h>
using namespace runtime;
class CalculatorObject : public Object {
public:
std::string brand;
int price;
void VisitAttrs(AttrVisitor *v) final {
v->Visit("brand", &brand);
v->Visit("price", &price);
}
static constexpr const char* _type_key = "Calculator";
DGL_DECLARE_OBJECT_TYPE_INFO(CalculatorObject, Object);
};
// This is to define a reference class (the wrapper of an object shared pointer).
// A minimal implementation is as follows, but you could define extra methods.
class Calculator : public ObjectRef {
public:
const CalculatorObject* operator->() const {
return static_cast<const CalculatorObject*>(obj_.get());
}
using ContainerType = CalculatorObject;
};
DGL_REGISTER_GLOBAL("calculator.CreateCaculator")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string brand = args[0];
int price = args[1];
auto o = std::make_shared<CalculatorObject>();
o->brand = brand;
o->price = price;
*rv = o;
}
在 Python 方面:
# file: calculator.py
from dgl._ffi.object import register_object, ObjectBase
from ._ffi.function import _init_api
@register_object
class Calculator(ObjectBase):
@staticmethod
def create(brand, price):
# invoke a C API, the return value is of `Calculator` type
return CreateCalculator(brand, price)
_init_api("dgl.calculator")
然后我们可以简单地通过以下方式创建 Calculator 对象:
calc = Calculator.create("casio", 100)
这个对象的好处在于,它定义了一个访问者模式,这本质上是一种反射机制,用于获取其内部属性。例如,你可以通过简单地访问其属性来打印计算器的品牌。
print(calc.brand)
print(calc.price)
由于字符串键查找,反射确实有点慢。为了加快速度,你可以定义一个属性访问API:
// file: calculator.cc
DGL_REGISTER_GLOBAL("calculator.CaculatorGetBrand")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
Calculator calc = args[0];
*rv = calc->brand;
}
容器
容器也是对象。例如,下面的C API接受一个整数列表并返回它们的总和:
// in file: calculator.cc
#include <dgl/runtime/container.h>
using namespace runtime;
DGL_REGISTER_GLOBAL("calculator.Sum")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
// All the DGL supported values are represented as a ValueObject, which
// contains a data field.
List<Value> values = args[0];
int sum = 0;
for (int i = 0; i < values.size(); ++i) {
sum += static_cast<int>(values[i]->data);
}
}
调用此API很简单——只需传递一个python整数列表。DGL FFI会自动将python列表/元组/字典转换为相应的对象类型。
# in file: calculator.py
from ._ffi.function import _init_api
Sum([0, 1, 2, 3, 4, 5])
_init_api("dgl.calculator")
容器中的元素可以是任何对象,这使得容器可以被组合。下面是一个接受计算器列表并打印出它们价格的API:
// in file: calculator.cc
#include <iostream>
#include <dgl/runtime/container.h>
using namespace runtime;
DGL_REGISTER_GLOBAL("calculator.PrintCalculators")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
List<Calculator> calcs = args[0];
for (int i = 0; i < calcs.size(); ++i) {
std::cout << calcs[i]->price << std::endl;
}
}
请注意,容器不适用于从/向C API传递大量项目。在这些情况下,速度会非常慢。建议先进行基准测试。作为替代方案,对于大量数值,请使用NDArray,并使用dgl.batch将大量DGLGraph批量处理为单个DGLGraph。