DGL 外部函数接口 (FFI)

我们都喜欢Python,因为它易于操作。我们都喜欢C,因为它快速、可靠且类型化。为了兼具两者的优点,DGL主要在Python中,以便快速原型设计,同时将性能关键部分降低到C。因此,DGL开发者经常面临编写C例程并通过一种称为外部函数接口(FFI)的机制将其暴露给Python的场景。

市面上有许多FFI解决方案。在DGL中,我们希望保持其简单、直观且高效,以应对关键用例。这就是为什么当我们遇到TVM项目中的FFI解决方案时,我们立即被它吸引。它利用了函数式编程的思想,因此只暴露了十几个C API,并且可以在此基础上构建新的API。

我们决定(毫不羞耻地)借用这个想法。例如,定义一个暴露给python的C API只需要几行代码:

// file: calculator.cc (put it in dgl/src folder)
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h>

using namespace dgl::runtime;

DGL_REGISTER_GLOBAL("calculator.MyAdd")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    int a = args[0];
    int b = args[1];
    *rv = a + b;
  });

编译并构建库。在python端,在dgl/python/dgl/下创建一个calculator.py文件。

# file: calculator.py
from ._ffi.function import _init_api

def add(a, b):
  # MyAdd has been registered via `_ini_api` call below
  return MyAdd(a, b)

_init_api("dgl.calculator")

关键在于FFI系统首先会屏蔽函数参数的类型信息,因此所有的C函数调用都可以通过一个C API(DGLFuncCall)进行。类型信息在函数体内通过静态转换来获取,并且我们会进行运行时类型检查以确保类型转换是正确的。只要函数调用不是太轻量级(上面的例子实际上是一个不好的例子),这种来回的开销是可以忽略的。TVM的PackedFunc文档有更多详细信息。

定义新类型

DGLArgsDGLRetValue 仅支持有限数量的类型:

  • 数值类型:int, float, double, …

  • 字符串

  • 函数(以PackedFunc的形式)

  • NDArray

尽管有限,上述类型系统非常强大,因为它支持函数作为一等公民。例如,如果你想返回多个值,你可以返回一个PackedFunc,它根据给定的整数索引返回每个值。然而,在许多情况下,仍然需要新类型来简化开发过程:

  • 参数/返回值是集合的组合(例如,列表的字典的字典)。

  • 有时我们只想要一个“结构”的概念(例如,给定一个苹果,通过apple.color获取它的颜色)。

为了实现这一点,我们引入了对象类型系统。例如,定义一个新类型 Calculator

// file: calculator.cc
#include <dgl/packed_func_ext.h>
using namespace runtime;
class CalculatorObject : public Object {
 public:
  std::string brand;
  int price;

  void VisitAttrs(AttrVisitor *v) final {
    v->Visit("brand", &brand);
    v->Visit("price", &price);
  }

  static constexpr const char* _type_key = "Calculator";
  DGL_DECLARE_OBJECT_TYPE_INFO(CalculatorObject, Object);
};

// This is to define a reference class (the wrapper of an object shared pointer).
// A minimal implementation is as follows, but you could define extra methods.
class Calculator : public ObjectRef {
 public:
  const CalculatorObject* operator->() const {
    return static_cast<const CalculatorObject*>(obj_.get());
  }
  using ContainerType = CalculatorObject;
};

DGL_REGISTER_GLOBAL("calculator.CreateCaculator")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  std::string brand = args[0];
  int price = args[1];
  auto o = std::make_shared<CalculatorObject>();
  o->brand = brand;
  o->price = price;
  *rv = o;
}

在 Python 方面:

# file: calculator.py
from dgl._ffi.object import register_object, ObjectBase
from ._ffi.function import _init_api

@register_object
class Calculator(ObjectBase):
  @staticmethod
  def create(brand, price):
    # invoke a C API, the return value is of `Calculator` type
    return CreateCalculator(brand, price)

_init_api("dgl.calculator")

然后我们可以简单地通过以下方式创建 Calculator 对象:

calc = Calculator.create("casio", 100)

这个对象的好处在于,它定义了一个访问者模式,这本质上是一种反射机制,用于获取其内部属性。例如,你可以通过简单地访问其属性来打印计算器的品牌。

print(calc.brand)
print(calc.price)

由于字符串键查找,反射确实有点慢。为了加快速度,你可以定义一个属性访问API:

// file: calculator.cc
DGL_REGISTER_GLOBAL("calculator.CaculatorGetBrand")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  Calculator calc = args[0];
  *rv = calc->brand;
}

容器

容器也是对象。例如,下面的C API接受一个整数列表并返回它们的总和:

// in file: calculator.cc
#include <dgl/runtime/container.h>
using namespace runtime;
DGL_REGISTER_GLOBAL("calculator.Sum")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  // All the DGL supported values are represented as a ValueObject, which
  //   contains a data field.
  List<Value> values = args[0];
  int sum = 0;
  for (int i = 0; i < values.size(); ++i) {
    sum += static_cast<int>(values[i]->data);
  }
}

调用此API很简单——只需传递一个python整数列表。DGL FFI会自动将python列表/元组/字典转换为相应的对象类型。

# in file: calculator.py
from ._ffi.function import _init_api

Sum([0, 1, 2, 3, 4, 5])

_init_api("dgl.calculator")

容器中的元素可以是任何对象,这使得容器可以被组合。下面是一个接受计算器列表并打印出它们价格的API:

// in file: calculator.cc
#include <iostream>
#include <dgl/runtime/container.h>
using namespace runtime;
DGL_REGISTER_GLOBAL("calculator.PrintCalculators")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  List<Calculator> calcs = args[0];
  for (int i = 0; i < calcs.size(); ++i) {
    std::cout << calcs[i]->price << std::endl;
  }
}

请注意,容器不适用于从/向C API传递大量项目。在这些情况下,速度会非常慢。建议先进行基准测试。作为替代方案,对于大量数值,请使用NDArray,并使用dgl.batch将大量DGLGraph批量处理为单个DGLGraph