MLX中的自定义扩展#
您可以在CPU或GPU上使用自定义操作扩展MLX。本指南通过一个简单的示例解释了如何做到这一点。
示例介绍#
假设你想要一个操作,它接收两个数组,x 和 y,分别用系数 alpha 和 beta 缩放它们,然后将它们相加得到结果 z = alpha * x + beta * y。你可以直接在 MLX 中完成这个操作:
import mlx.core as mx
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
return alpha * x + beta * y
此函数执行该操作,同时将实现和函数转换留给MLX。
然而,您可能需要自定义底层实现,可能是为了使其更快或进行自定义微分。在本教程中,我们将介绍如何添加自定义扩展。它将涵盖:
MLX库的结构。
实现一个CPU操作,在适当的时候重定向到Accelerate。
使用Metal实现GPU操作。
添加
vjp和jvp函数转换。构建自定义扩展并将其绑定到python。
操作和原语#
MLX中的操作构建了计算图。原语提供了评估和转换图的规则。让我们从更详细地讨论操作开始。
操作#
操作是操作数组的前端函数。它们在C++ API(Operations)中定义,Python API(Operations)则对它们进行了绑定。
我们想要一个操作,axpby(),它接收两个数组x和y,以及两个标量alpha和beta。这是在C++中定义它的方法:
/**
* Scale and sum two vectors element-wise
* z = alpha * x + beta * y
*
* Follow numpy style broadcasting between x and y
* Inputs are upcasted to floats if needed
**/
array axpby(
const array& x, // Input array x
const array& y, // Input array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
StreamOrDevice s = {} // Stream on which to schedule the operation
);
执行此操作的最简单方法是使用现有操作:
array axpby(
const array& x, // Input array x
const array& y, // Input array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
) {
// Scale x and y on the provided stream
auto ax = multiply(array(alpha), x, s);
auto by = multiply(array(beta), y, s);
// Add and return
return add(ax, by, s);
}
操作本身并不包含作用于数据的实现,也不包含转换规则。相反,它们是一个易于使用的接口,使用Primitive构建块。
基本类型#
一个Primitive是array计算图的一部分。它定义了如何根据输入数组创建输出数组。此外,Primitive具有在CPU或GPU上运行的方法,以及用于函数转换的方法,如vjp和jvp。让我们回到我们的例子以更具体地说明:
class Axpby : public Primitive {
public:
explicit Axpby(Stream stream, float alpha, float beta)
: Primitive(stream), alpha_(alpha), beta_(beta){};
/**
* A primitive must know how to evaluate itself on the CPU/GPU
* for the given inputs and populate the output array.
*
* To avoid unnecessary allocations, the evaluation function
* is responsible for allocating space for the array.
*/
void eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) override;
void eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) override;
/** The Jacobian-vector product. */
std::vector<array> jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) override;
/** The vector-Jacobian product. */
std::vector<array> vjp(
const std::vector<array>& primals,
const array& cotan,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
/**
* The primitive must know how to vectorize itself across
* the given axes. The output is a pair containing the array
* representing the vectorized computation and the axis which
* corresponds to the output vectorized dimension.
*/
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
/** Print the primitive. */
void print(std::ostream& os) override {
os << "Axpby";
}
/** Equivalence check **/
bool is_equivalent(const Primitive& other) const override;
private:
float alpha_;
float beta_;
/** Fall back implementation for evaluation on CPU */
void eval(const std::vector<array>& inputs, array& out);
};
Axpby 类继承自基类 Primitive。Axpby 将 alpha 和 beta 视为参数。然后,它通过 Axpby::eval_cpu() 和 Axpby::eval_gpu() 提供了如何根据输入生成输出数组的实现。它还提供了在 Axpby::jvp()、Axpby::vjp() 和 Axpby::vmap() 中的转换规则。
使用原始类型#
操作可以使用这个Primitive来向计算图中添加一个新的array。一个array可以通过提供其数据类型、形状、计算它的Primitive以及传递给该原语的array输入来构建。
现在让我们根据我们的Axpby原语重新实现我们的操作。
array axpby(
const array& x, // Input array x
const array& y, // Input array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
) {
// Promote dtypes between x and y as needed
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
// Upcast to float32 for non-floating point inputs x and y
auto out_dtype = is_floating_point(promoted_dtype)
? promoted_dtype
: promote_types(promoted_dtype, float32);
// Cast x and y up to the determined dtype (on the same stream s)
auto x_casted = astype(x, out_dtype, s);
auto y_casted = astype(y, out_dtype, s);
// Broadcast the shapes of x and y (on the same stream s)
auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
auto out_shape = broadcasted_inputs[0].shape();
// Construct the array as the output of the Axpby primitive
// with the broadcasted and upcasted arrays as inputs
return array(
/* const std::vector<int>& shape = */ out_shape,
/* Dtype dtype = */ out_dtype,
/* std::unique_ptr<Primitive> primitive = */
std::make_shared<Axpby>(to_stream(s), alpha, beta),
/* const std::vector<array>& inputs = */ broadcasted_inputs);
}
此操作现在处理以下内容:
向上转换输入并解析输出数据类型。
广播输入并解析输出形状。
使用给定的流
alpha和beta构造原始Axpby。使用原始数据和输入构建输出
array。
实现原始#
当我们单独调用操作时,不会发生任何计算。该操作仅构建计算图。当我们评估输出数组时,MLX会安排计算图的执行,并根据用户指定的流/设备调用Axpby::eval_cpu()或Axpby::eval_gpu()。
警告
当调用Primitive::eval_cpu()或Primitive::eval_gpu()时,
尚未为输出数组分配内存。这些函数的实现需要根据需要分配内存。
实现CPU后端#
让我们从实现一个简单且通用的Axpby::eval_cpu()版本开始。我们之前将其声明为Axpby的私有成员函数,称为Axpby::eval()。
我们的简单方法将遍历输出数组的每个元素,找到对应的输入元素 x 和 y 并逐点执行操作。这在模板函数 axpby_impl() 中得到了体现。
template <typename T>
void axpby_impl(
const array& x,
const array& y,
array& out,
float alpha_,
float beta_) {
// We only allocate memory when we are ready to fill the output
// malloc_or_wait synchronously allocates available memory
// There may be a wait executed here if the allocation is requested
// under memory-pressured conditions
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Collect input and output data pointers
const T* x_ptr = x.data<T>();
const T* y_ptr = y.data<T>();
T* out_ptr = out.data<T>();
// Cast alpha and beta to the relevant types
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Do the element-wise operation for each output
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
// Map linear indices to offsets in x and y
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
// We allocate the output to be contiguous and regularly strided
// (defaults to row major) and hence it doesn't need additional mapping
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
}
}
我们的实现应该适用于所有传入的浮点数组。
因此,我们为 float32, float16, bfloat16 和
complex64 添加了调度。如果我们遇到意外的类型,我们会抛出一个错误。
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Dispatch to the correct dtype
if (out.dtype() == float32) {
return axpby_impl<float>(x, y, out, alpha_, beta_);
} else if (out.dtype() == float16) {
return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == bfloat16) {
return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == complex64) {
return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
} else {
throw std::runtime_error(
"[Axpby] Only supports floating point types.");
}
}
这是一个很好的备用实现。在某些情况下,我们可以使用Accelerate框架提供的axpby例程来实现更快的性能:
Accelerate 不提供针对半精度浮点数的
axpby实现。我们只能将其用于float32类型。Accelerate 假设输入
x和y是连续的,并且所有元素之间都有固定的步幅。我们只有在x和y都是行连续或列连续时才会指向 Accelerate。Accelerate 执行常规操作
Y = (alpha * X) + (beta * Y)在原地进行。 MLX 期望将输出写入一个新数组。我们必须将y的元素复制到输出中,并将其用作axpby的输入。
让我们编写一个在适当条件下使用Accelerate的实现。
它为输出分配数据,将y复制到其中,然后调用
catlas_saxpby()从accelerate。
template <typename T>
void axpby_impl_accelerate(
const array& x,
const array& y,
array& out,
float alpha_,
float beta_) {
// Accelerate library provides catlas_saxpby which does
// Y = (alpha * X) + (beta * Y) in place
// To use it, we first copy the data in y over to the output array
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// We then copy over the elements using the contiguous vector specialization
copy_inplace(y, out, CopyType::Vector);
// Get x and y pointers for catlas_saxpby
const T* x_ptr = x.data<T>();
T* y_ptr = out.data<T>();
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Call the inplace accelerate operator
catlas_saxpby(
/* N = */ out.size(),
/* ALPHA = */ alpha,
/* X = */ x_ptr,
/* INCX = */ 1,
/* BETA = */ beta,
/* Y = */ y_ptr,
/* INCY = */ 1);
}
对于不符合加速标准的输入,我们回退到
Axpby::eval()。考虑到这一点,让我们完成我们的
Axpby::eval_cpu()。
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Accelerate specialization for contiguous single precision float arrays
if (out.dtype() == float32 &&
((x.flags().row_contiguous && y.flags().row_contiguous) ||
(x.flags().col_contiguous && y.flags().col_contiguous))) {
axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
return;
}
// Fall back to common back-end if specializations are not available
eval(inputs, outputs);
}
仅此就足以在CPU流上运行操作axpby()!如果您不打算在GPU上运行此操作或对包含Axpby的计算图使用转换,您可以在此停止实现原语,并享受Accelerate库带来的加速。
实现GPU后端#
Apple silicon设备使用Metal着色语言来访问其GPU,而MLX中的GPU内核是使用Metal编写的。
让我们保持GPU内核简单。我们将启动与输出中元素数量完全相同的线程。每个线程将从x和y中选择它需要的元素,执行点对点操作,并更新其在输出中分配的元素。
template <typename T>
[[kernel]] void axpby_general(
device const T* x [[buffer(0)]],
device const T* y [[buffer(1)]],
device T* out [[buffer(2)]],
constant const float& alpha [[buffer(3)]],
constant const float& beta [[buffer(4)]],
constant const int* shape [[buffer(5)]],
constant const int64_t* x_strides [[buffer(6)]],
constant const int64_t* y_strides [[buffer(7)]],
constant const int& ndim [[buffer(8)]],
uint index [[thread_position_in_grid]]) {
// Convert linear indices to offsets in array
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
// Do the operation and update the output
out[index] =
static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
}
然后我们需要为所有浮点类型实例化此模板,并为每个实例化分配一个唯一的主机名,以便我们能够识别它。
instantiate_kernel("axpby_general_float32", axpby_general, float)
instantiate_kernel("axpby_general_float16", axpby_general, float16_t)
instantiate_kernel("axpby_general_bfloat16", axpby_general, bfloat16_t)
instantiate_kernel("axpby_general_complex64", axpby_general, complex64_t)
确定内核、设置输入、解析网格维度并调度到GPU的逻辑包含在Axpby::eval_gpu()中,如下所示。
/** Evaluate primitive on GPU */
void Axpby::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
// Prepare inputs
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Each primitive carries the stream it should execute on
// and each stream carries its device identifiers
auto& s = stream();
// We get the needed metal device using the stream
auto& d = metal::device(s.device);
// Allocate output memory
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Resolve name of kernel
std::ostringstream kname;
kname << "axpby_" << "general_" << type_to_name(out);
// Make sure the metal library is available
d.register_library("mlx_ext");
// Make a kernel from this metal library
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
// Kernel parameters are registered with buffer indices corresponding to
// those in the kernel declaration at axpby.metal
int ndim = out.ndim();
size_t nelem = out.size();
// Encode input arrays to kernel
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(y, 1);
// Encode output arrays to kernel
compute_encoder.set_output_array(out, 2);
// Encode alpha and beta
compute_encoder.set_bytes(alpha_, 3);
compute_encoder.set_bytes(beta_, 4);
// Encode shape, strides and ndim
compute_encoder.set_vector_bytes(x.shape(), 5);
compute_encoder.set_vector_bytes(x.strides(), 6);
compute_encoder.set_bytes(y.strides(), 7);
compute_encoder.set_bytes(ndim, 8);
// We launch 1 thread for each input and make sure that the number of
// threads in any given threadgroup is not higher than the max allowed
size_t tgp_size = std::min(nelem, kernel->maxTotalThreadsPerThreadgroup());
// Fix the 3D size of each threadgroup (in terms of threads)
MTL::Size group_dims = MTL::Size(tgp_size, 1, 1);
// Fix the 3D size of the launch grid (in terms of threads)
MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
// Launch the grid with the given number of threads divided among
// the given threadgroups
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
我们现在可以在CPU和GPU上调用axpby()操作了!
在继续之前,关于MLX和Metal需要注意的几点。MLX会跟踪活动的command_buffer以及与之关联的MTLCommandBuffer。我们依赖d.get_command_encoder()来获取活动的Metal计算命令编码器,而不是构建一个新的并在最后调用compute_encoder->end_encoding()。MLX会将内核(计算管道)添加到活动的命令缓冲区中,直到达到某个指定的限制或需要刷新命令缓冲区以进行同步。
原始转换#
接下来,让我们为Primitive中的转换添加实现。
这些转换可以建立在其他操作之上,包括我们刚刚定义的操作:
/** The Jacobian-vector product. */
std::vector<array> Axpby::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents
// The jvp transform on the primitive can built with ops
// that are scheduled on the same stream as the primitive
// If argnums = {0}, we only push along x in which case the
// jvp is just the tangent scaled by alpha
// Similarly, if argnums = {1}, the jvp is just the tangent
// scaled by beta
if (argnums.size() > 1) {
auto scale = argnums[0] == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, tangents[0].dtype());
return {multiply(scale_arr, tangents[0], stream())};
}
// If, argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta
else {
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
}
}
/** The vector-Jacobian product. */
std::vector<array> Axpby::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<int>& /* unused */) {
// Reverse mode diff
std::vector<array> vjps;
for (auto arg : argnums) {
auto scale = arg == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, cotangents[0].dtype());
vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
}
return vjps;
}
注意,开始使用Primitive时,不需要完全定义转换。
/** Vectorize primitive along given axis */
std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
throw std::runtime_error("[Axpby] vmap not implemented.");
}
构建和绑定#
让我们首先看一下整体目录结构。
extensions/axpby/定义了C++扩展库extensions/mlx_sample_extensions设置了相关Python包的结构extensions/bindings.cpp为我们的操作提供了Python绑定extensions/CMakeLists.txt包含用于构建库和Python绑定的CMake规则extensions/setup.py包含用于构建和安装Python包的setuptools规则
绑定到Python#
我们使用nanobind为C++库构建一个Python API。由于已经提供了如mlx.core.array、mlx.core.stream等组件的绑定,添加我们的axpby()非常简单。
NB_MODULE(_ext, m) {
m.doc() = "Sample extension for MLX";
m.def(
"axpby",
&axpby,
"x"_a,
"y"_a,
"alpha"_a,
"beta"_a,
nb::kw_only(),
"stream"_a = nb::none(),
R"(
Scale and sum two vectors element-wise
``z = alpha * x + beta * y``
Follows numpy style broadcasting between ``x`` and ``y``
Inputs are upcasted to floats if needed
Args:
x (array): Input array.
y (array): Input array.
alpha (float): Scaling factor for ``x``.
beta (float): Scaling factor for ``y``.
Returns:
array: ``alpha * x + beta * y``
)");
}
上述示例中的大部分复杂性来自于额外的装饰,例如字面名称和文档字符串。
警告
mlx.core 必须在导入 mlx_sample_extensions 之前导入,如上面 nanobind 模块所定义,以确保 mlx.core 组件(如 mlx.core.array)的转换器可用。
使用CMake构建#
构建C++扩展库只需要你find_package(MLX
CONFIG)然后将其链接到你的库。
# Add library
add_library(mlx_ext)
# Add sources
target_sources(
mlx_ext
PUBLIC
${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
)
# Add include headers
target_include_directories(
mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}
)
# Link to mlx
target_link_libraries(mlx_ext PUBLIC mlx)
我们还需要构建附带的Metal库。为了方便起见,我们提供了一个mlx_build_metallib()函数,该函数根据给定的源文件、头文件、目标路径等构建一个.metallib目标(在cmake/extension.cmake中定义,并自动与MLX包一起导入)。
以下是实际应用中的样子:
# Build metallib
if(MLX_BUILD_METAL)
mlx_build_metallib(
TARGET mlx_ext_metallib
TITLE mlx_ext
SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
)
add_dependencies(
mlx_ext
mlx_ext_metallib
)
endif()
最后,我们构建了nanobind绑定
nanobind_add_module(
_ext
NB_STATIC STABLE_ABI LTO NOMINSIZE
NB_DOMAIN mlx
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
)
target_link_libraries(_ext PRIVATE mlx_ext)
if(BUILD_SHARED_LIBS)
target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
endif()
使用setuptools#构建
一旦我们按照上述方式设置了CMake构建规则,我们就可以使用mlx.extension中定义的构建工具:
from mlx import extension
from setuptools import setup
if __name__ == "__main__":
setup(
name="mlx_sample_extensions",
version="0.0.0",
description="Sample C++ and Metal extensions for MLX primitives.",
ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
cmdclass={"build_ext": extension.CMakeBuild},
packages=["mlx_sample_extensions"],
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
extras_require={"dev":[]},
zip_safe=False,
python_requires=">=3.8",
)
注意
我们将extensions/mlx_sample_extensions视为包目录,即使它只包含一个__init__.py以确保以下内容:
mlx.core必须在导入_ext之前导入C++扩展库和metal库与Python绑定位于同一位置,如果安装了该包,它们将一起被复制。
要构建包,首先使用pip install -r requirements.txt安装构建依赖项。然后,您可以使用python setup.py build_ext -j8 --inplace(在extensions/中)进行开发中的就地构建。
这导致了目录结构:
当你尝试使用命令 python -m pip install . 进行安装时(在
extensions/ 目录下),该包将以与
extensions/mlx_sample_extensions 相同的结构安装,并且 C++ 和 Metal 库将与 Python 绑定一起复制,因为它们被指定为
package_data。
用法#
按照上述方法安装扩展后,您应该能够简单地导入Python包并像使用其他MLX操作一样使用它。
让我们来看一个简单的脚本及其结果:
import mlx.core as mx
from mlx_sample_extensions import axpby
a = mx.ones((3, 4))
b = mx.ones((3, 4))
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c correct: {mx.all(c == 6.0).item()}")
输出:
c shape: [3, 4]
c dtype: float32
c correctness: True
结果#
让我们运行一个快速的基准测试,看看我们新的axpby操作与我们最初在CPU上定义的简单simple_axpby()相比如何。
import mlx.core as mx
from mlx_sample_extensions import axpby
import time
mx.set_default_device(mx.cpu)
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
return alpha * x + beta * y
M = 256
N = 512
x = mx.random.normal((M, N))
y = mx.random.normal((M, N))
alpha = 4.0
beta = 2.0
mx.eval(x, y)
def bench(f):
# Warm up
for i in range(100):
z = f(x, y, alpha, beta)
mx.eval(z)
# Timed run
s = time.time()
for i in range(5000):
z = f(x, y, alpha, beta)
mx.eval(z)
e = time.time()
return e - s
simple_time = bench(simple_axpby)
custom_time = bench(axpby)
print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")
结果是 Simple axpby: 0.114 s | Custom axpby: 0.109 s。我们立即看到了适度的改进!
此操作现在可以很好地用于构建其他操作,在
mlx.nn.Module 调用中,也可以作为图形转换的一部分,如
grad()。
脚本#
下载代码
完整的示例代码可在mlx中找到。