MLX中的自定义扩展#

您可以在CPU或GPU上使用自定义操作扩展MLX。本指南通过一个简单的示例解释了如何做到这一点。

示例介绍#

假设你想要一个操作,它接收两个数组,xy,分别用系数 alphabeta 缩放它们,然后将它们相加得到结果 z = alpha * x + beta * y。你可以直接在 MLX 中完成这个操作:

import mlx.core as mx

def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
    return alpha * x + beta * y

此函数执行该操作,同时将实现和函数转换留给MLX。

然而,您可能需要自定义底层实现,可能是为了使其更快或进行自定义微分。在本教程中,我们将介绍如何添加自定义扩展。它将涵盖:

  • MLX库的结构。

  • 实现一个CPU操作,在适当的时候重定向到Accelerate

  • 使用Metal实现GPU操作。

  • 添加vjpjvp函数转换。

  • 构建自定义扩展并将其绑定到python。

操作和原语#

MLX中的操作构建了计算图。原语提供了评估和转换图的规则。让我们从更详细地讨论操作开始。

操作#

操作是操作数组的前端函数。它们在C++ API(Operations)中定义,Python API(Operations)则对它们进行了绑定。

我们想要一个操作,axpby(),它接收两个数组xy,以及两个标量alphabeta。这是在C++中定义它的方法:

/**
*  Scale and sum two vectors element-wise
*  z = alpha * x + beta * y
*
*  Follow numpy style broadcasting between x and y
*  Inputs are upcasted to floats if needed
**/
array axpby(
    const array& x, // Input array x
    const array& y, // Input array y
    const float alpha, // Scaling factor for x
    const float beta, // Scaling factor for y
    StreamOrDevice s = {} // Stream on which to schedule the operation
);

执行此操作的最简单方法是使用现有操作:

array axpby(
    const array& x, // Input array x
    const array& y, // Input array y
    const float alpha, // Scaling factor for x
    const float beta, // Scaling factor for y
    StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
) {
    // Scale x and y on the provided stream
    auto ax = multiply(array(alpha), x, s);
    auto by = multiply(array(beta), y, s);

    // Add and return
    return add(ax, by, s);
}

操作本身并不包含作用于数据的实现,也不包含转换规则。相反,它们是一个易于使用的接口,使用Primitive构建块。

基本类型#

一个Primitivearray计算图的一部分。它定义了如何根据输入数组创建输出数组。此外,Primitive具有在CPU或GPU上运行的方法,以及用于函数转换的方法,如vjpjvp。让我们回到我们的例子以更具体地说明:

class Axpby : public Primitive {
  public:
    explicit Axpby(Stream stream, float alpha, float beta)
        : Primitive(stream), alpha_(alpha), beta_(beta){};

    /**
    * A primitive must know how to evaluate itself on the CPU/GPU
    * for the given inputs and populate the output array.
    *
    * To avoid unnecessary allocations, the evaluation function
    * is responsible for allocating space for the array.
    */
    void eval_cpu(
        const std::vector<array>& inputs,
        std::vector<array>& outputs) override;
    void eval_gpu(
        const std::vector<array>& inputs,
        std::vector<array>& outputs) override;

    /** The Jacobian-vector product. */
    std::vector<array> jvp(
        const std::vector<array>& primals,
        const std::vector<array>& tangents,
        const std::vector<int>& argnums) override;

    /** The vector-Jacobian product. */
    std::vector<array> vjp(
        const std::vector<array>& primals,
        const array& cotan,
        const std::vector<int>& argnums,
        const std::vector<array>& outputs) override;

    /**
    * The primitive must know how to vectorize itself across
    * the given axes. The output is a pair containing the array
    * representing the vectorized computation and the axis which
    * corresponds to the output vectorized dimension.
    */
    virtual std::pair<std::vector<array>, std::vector<int>> vmap(
        const std::vector<array>& inputs,
        const std::vector<int>& axes) override;

    /** Print the primitive. */
    void print(std::ostream& os) override {
        os << "Axpby";
    }

    /** Equivalence check **/
    bool is_equivalent(const Primitive& other) const override;

  private:
    float alpha_;
    float beta_;

    /** Fall back implementation for evaluation on CPU */
    void eval(const std::vector<array>& inputs, array& out);
};

Axpby 类继承自基类 PrimitiveAxpbyalphabeta 视为参数。然后,它通过 Axpby::eval_cpu()Axpby::eval_gpu() 提供了如何根据输入生成输出数组的实现。它还提供了在 Axpby::jvp()Axpby::vjp()Axpby::vmap() 中的转换规则。

使用原始类型#

操作可以使用这个Primitive来向计算图中添加一个新的array。一个array可以通过提供其数据类型、形状、计算它的Primitive以及传递给该原语的array输入来构建。

现在让我们根据我们的Axpby原语重新实现我们的操作。

array axpby(
    const array& x, // Input array x
    const array& y, // Input array y
    const float alpha, // Scaling factor for x
    const float beta, // Scaling factor for y
    StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
) {
    // Promote dtypes between x and y as needed
    auto promoted_dtype = promote_types(x.dtype(), y.dtype());

    // Upcast to float32 for non-floating point inputs x and y
    auto out_dtype = is_floating_point(promoted_dtype)
        ? promoted_dtype
        : promote_types(promoted_dtype, float32);

    // Cast x and y up to the determined dtype (on the same stream s)
    auto x_casted = astype(x, out_dtype, s);
    auto y_casted = astype(y, out_dtype, s);

    // Broadcast the shapes of x and y (on the same stream s)
    auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
    auto out_shape = broadcasted_inputs[0].shape();

    // Construct the array as the output of the Axpby primitive
    // with the broadcasted and upcasted arrays as inputs
    return array(
        /* const std::vector<int>& shape = */ out_shape,
        /* Dtype dtype = */ out_dtype,
        /* std::unique_ptr<Primitive> primitive = */
        std::make_shared<Axpby>(to_stream(s), alpha, beta),
        /* const std::vector<array>& inputs = */ broadcasted_inputs);
}

此操作现在处理以下内容:

  1. 向上转换输入并解析输出数据类型。

  2. 广播输入并解析输出形状。

  3. 使用给定的流 alphabeta 构造原始 Axpby

  4. 使用原始数据和输入构建输出array

实现原始#

当我们单独调用操作时,不会发生任何计算。该操作仅构建计算图。当我们评估输出数组时,MLX会安排计算图的执行,并根据用户指定的流/设备调用Axpby::eval_cpu()Axpby::eval_gpu()

警告

当调用Primitive::eval_cpu()Primitive::eval_gpu()时, 尚未为输出数组分配内存。这些函数的实现需要根据需要分配内存。

实现CPU后端#

让我们从实现一个简单且通用的Axpby::eval_cpu()版本开始。我们之前将其声明为Axpby的私有成员函数,称为Axpby::eval()

我们的简单方法将遍历输出数组的每个元素,找到对应的输入元素 xy 并逐点执行操作。这在模板函数 axpby_impl() 中得到了体现。

template <typename T>
void axpby_impl(
        const array& x,
        const array& y,
        array& out,
        float alpha_,
        float beta_) {
    // We only allocate memory when we are ready to fill the output
    // malloc_or_wait synchronously allocates available memory
    // There may be a wait executed here if the allocation is requested
    // under memory-pressured conditions
    out.set_data(allocator::malloc_or_wait(out.nbytes()));

    // Collect input and output data pointers
    const T* x_ptr = x.data<T>();
    const T* y_ptr = y.data<T>();
    T* out_ptr = out.data<T>();

    // Cast alpha and beta to the relevant types
    T alpha = static_cast<T>(alpha_);
    T beta = static_cast<T>(beta_);

    // Do the element-wise operation for each output
    for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
        // Map linear indices to offsets in x and y
        auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
        auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());

        // We allocate the output to be contiguous and regularly strided
        // (defaults to row major) and hence it doesn't need additional mapping
        out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
    }
}

我们的实现应该适用于所有传入的浮点数组。 因此,我们为 float32, float16, bfloat16complex64 添加了调度。如果我们遇到意外的类型,我们会抛出一个错误。

/** Fall back implementation for evaluation on CPU */
void Axpby::eval(
  const std::vector<array>& inputs,
  const std::vector<array>& outputs) {
    auto& x = inputs[0];
    auto& y = inputs[1];
    auto& out = outputs[0];

    // Dispatch to the correct dtype
    if (out.dtype() == float32) {
        return axpby_impl<float>(x, y, out, alpha_, beta_);
    } else if (out.dtype() == float16) {
        return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
    } else if (out.dtype() == bfloat16) {
        return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
    } else if (out.dtype() == complex64) {
        return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
    } else {
        throw std::runtime_error(
            "[Axpby] Only supports floating point types.");
    }
}

这是一个很好的备用实现。在某些情况下,我们可以使用Accelerate框架提供的axpby例程来实现更快的性能:

  1. Accelerate 不提供针对半精度浮点数的 axpby 实现。我们只能将其用于 float32 类型。

  2. Accelerate 假设输入 xy 是连续的,并且所有元素之间都有固定的步幅。我们只有在 xy 都是行连续或列连续时才会指向 Accelerate。

  3. Accelerate 执行常规操作 Y = (alpha * X) + (beta * Y) 在原地进行。 MLX 期望将输出写入一个新数组。我们必须将 y 的元素复制到输出中,并将其用作 axpby 的输入。

让我们编写一个在适当条件下使用Accelerate的实现。 它为输出分配数据,将y复制到其中,然后调用 catlas_saxpby()从accelerate。

template <typename T>
void axpby_impl_accelerate(
        const array& x,
        const array& y,
        array& out,
        float alpha_,
        float beta_) {
    // Accelerate library provides catlas_saxpby which does
    // Y = (alpha * X) + (beta * Y) in place
    // To use it, we first copy the data in y over to the output array
    out.set_data(allocator::malloc_or_wait(out.nbytes()));

    // We then copy over the elements using the contiguous vector specialization
    copy_inplace(y, out, CopyType::Vector);

    // Get x and y pointers for catlas_saxpby
    const T* x_ptr = x.data<T>();
    T* y_ptr = out.data<T>();

    T alpha = static_cast<T>(alpha_);
    T beta = static_cast<T>(beta_);

    // Call the inplace accelerate operator
    catlas_saxpby(
        /* N = */ out.size(),
        /* ALPHA = */ alpha,
        /* X = */ x_ptr,
        /* INCX = */ 1,
        /* BETA = */ beta,
        /* Y = */ y_ptr,
        /* INCY = */ 1);
}

对于不符合加速标准的输入,我们回退到 Axpby::eval()。考虑到这一点,让我们完成我们的 Axpby::eval_cpu()

/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(
  const std::vector<array>& inputs,
  const std::vector<array>& outputs) {
    assert(inputs.size() == 2);
    auto& x = inputs[0];
    auto& y = inputs[1];
    auto& out = outputs[0];

    // Accelerate specialization for contiguous single precision float arrays
    if (out.dtype() == float32 &&
        ((x.flags().row_contiguous && y.flags().row_contiguous) ||
        (x.flags().col_contiguous && y.flags().col_contiguous))) {
        axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
        return;
    }

    // Fall back to common back-end if specializations are not available
    eval(inputs, outputs);
}

仅此就足以在CPU流上运行操作axpby()!如果您不打算在GPU上运行此操作或对包含Axpby的计算图使用转换,您可以在此停止实现原语,并享受Accelerate库带来的加速。

实现GPU后端#

Apple silicon设备使用Metal着色语言来访问其GPU,而MLX中的GPU内核是使用Metal编写的。

注意

如果你是Metal的新手,这里有一些有用的资源:

让我们保持GPU内核简单。我们将启动与输出中元素数量完全相同的线程。每个线程将从xy中选择它需要的元素,执行点对点操作,并更新其在输出中分配的元素。

template <typename T>
[[kernel]] void axpby_general(
        device const T* x [[buffer(0)]],
        device const T* y [[buffer(1)]],
        device T* out [[buffer(2)]],
        constant const float& alpha [[buffer(3)]],
        constant const float& beta [[buffer(4)]],
        constant const int* shape [[buffer(5)]],
        constant const int64_t* x_strides [[buffer(6)]],
        constant const int64_t* y_strides [[buffer(7)]],
        constant const int& ndim [[buffer(8)]],
        uint index [[thread_position_in_grid]]) {
    // Convert linear indices to offsets in array
    auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
    auto y_offset = elem_to_loc(index, shape, y_strides, ndim);

    // Do the operation and update the output
    out[index] =
        static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
}

然后我们需要为所有浮点类型实例化此模板,并为每个实例化分配一个唯一的主机名,以便我们能够识别它。

instantiate_kernel("axpby_general_float32", axpby_general, float)
instantiate_kernel("axpby_general_float16", axpby_general, float16_t)
instantiate_kernel("axpby_general_bfloat16", axpby_general, bfloat16_t)
instantiate_kernel("axpby_general_complex64", axpby_general, complex64_t)

确定内核、设置输入、解析网格维度并调度到GPU的逻辑包含在Axpby::eval_gpu()中,如下所示。

/** Evaluate primitive on GPU */
void Axpby::eval_gpu(
  const std::vector<array>& inputs,
  std::vector<array>& outputs) {
    // Prepare inputs
    assert(inputs.size() == 2);
    auto& x = inputs[0];
    auto& y = inputs[1];
    auto& out = outputs[0];

    // Each primitive carries the stream it should execute on
    // and each stream carries its device identifiers
    auto& s = stream();
    // We get the needed metal device using the stream
    auto& d = metal::device(s.device);

    // Allocate output memory
    out.set_data(allocator::malloc_or_wait(out.nbytes()));

    // Resolve name of kernel
    std::ostringstream kname;
    kname << "axpby_" << "general_" << type_to_name(out);

    // Make sure the metal library is available
    d.register_library("mlx_ext");

    // Make a kernel from this metal library
    auto kernel = d.get_kernel(kname.str(), "mlx_ext");

    // Prepare to encode kernel
    auto& compute_encoder = d.get_command_encoder(s.index);
    compute_encoder.set_compute_pipeline_state(kernel);

    // Kernel parameters are registered with buffer indices corresponding to
    // those in the kernel declaration at axpby.metal
    int ndim = out.ndim();
    size_t nelem = out.size();

    // Encode input arrays to kernel
    compute_encoder.set_input_array(x, 0);
    compute_encoder.set_input_array(y, 1);

    // Encode output arrays to kernel
    compute_encoder.set_output_array(out, 2);

    // Encode alpha and beta
    compute_encoder.set_bytes(alpha_, 3);
    compute_encoder.set_bytes(beta_, 4);

    // Encode shape, strides and ndim
    compute_encoder.set_vector_bytes(x.shape(), 5);
    compute_encoder.set_vector_bytes(x.strides(), 6);
    compute_encoder.set_bytes(y.strides(), 7);
    compute_encoder.set_bytes(ndim, 8);

    // We launch 1 thread for each input and make sure that the number of
    // threads in any given threadgroup is not higher than the max allowed
    size_t tgp_size = std::min(nelem, kernel->maxTotalThreadsPerThreadgroup());

    // Fix the 3D size of each threadgroup (in terms of threads)
    MTL::Size group_dims = MTL::Size(tgp_size, 1, 1);

    // Fix the 3D size of the launch grid (in terms of threads)
    MTL::Size grid_dims = MTL::Size(nelem, 1, 1);

    // Launch the grid with the given number of threads divided among
    // the given threadgroups
    compute_encoder.dispatch_threads(grid_dims, group_dims);
}

我们现在可以在CPU和GPU上调用axpby()操作了!

在继续之前,关于MLX和Metal需要注意的几点。MLX会跟踪活动的command_buffer以及与之关联的MTLCommandBuffer。我们依赖d.get_command_encoder()来获取活动的Metal计算命令编码器,而不是构建一个新的并在最后调用compute_encoder->end_encoding()。MLX会将内核(计算管道)添加到活动的命令缓冲区中,直到达到某个指定的限制或需要刷新命令缓冲区以进行同步。

原始转换#

接下来,让我们为Primitive中的转换添加实现。 这些转换可以建立在其他操作之上,包括我们刚刚定义的操作:

/** The Jacobian-vector product. */
std::vector<array> Axpby::jvp(
        const std::vector<array>& primals,
        const std::vector<array>& tangents,
        const std::vector<int>& argnums) {
    // Forward mode diff that pushes along the tangents
    // The jvp transform on the primitive can built with ops
    // that are scheduled on the same stream as the primitive

    // If argnums = {0}, we only push along x in which case the
    // jvp is just the tangent scaled by alpha
    // Similarly, if argnums = {1}, the jvp is just the tangent
    // scaled by beta
    if (argnums.size() > 1) {
        auto scale = argnums[0] == 0 ? alpha_ : beta_;
        auto scale_arr = array(scale, tangents[0].dtype());
        return {multiply(scale_arr, tangents[0], stream())};
    }
    // If, argnums = {0, 1}, we take contributions from both
    // which gives us jvp = tangent_x * alpha + tangent_y * beta
    else {
        return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
    }
}
/** The vector-Jacobian product. */
std::vector<array> Axpby::vjp(
        const std::vector<array>& primals,
        const std::vector<array>& cotangents,
        const std::vector<int>& argnums,
        const std::vector<int>& /* unused */) {
    // Reverse mode diff
    std::vector<array> vjps;
    for (auto arg : argnums) {
        auto scale = arg == 0 ? alpha_ : beta_;
        auto scale_arr = array(scale, cotangents[0].dtype());
        vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
    }
    return vjps;
}

注意,开始使用Primitive时,不需要完全定义转换。

/** Vectorize primitive along given axis */
std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
        const std::vector<array>& inputs,
        const std::vector<int>& axes) {
    throw std::runtime_error("[Axpby] vmap not implemented.");
}

构建和绑定#

让我们首先看一下整体目录结构。

扩展
├── axpby
│ ├── axpby.cpp
│ ├── axpby.h
│ └── axpby.metal
├── mlx_sample_extensions
│ └── __init__.py
├── bindings.cpp
├── CMakeLists.txt
└── setup.py
  • extensions/axpby/ 定义了C++扩展库

  • extensions/mlx_sample_extensions 设置了相关Python包的结构

  • extensions/bindings.cpp 为我们的操作提供了Python绑定

  • extensions/CMakeLists.txt 包含用于构建库和Python绑定的CMake规则

  • extensions/setup.py 包含用于构建和安装Python包的setuptools规则

绑定到Python#

我们使用nanobind为C++库构建一个Python API。由于已经提供了如mlx.core.arraymlx.core.stream等组件的绑定,添加我们的axpby()非常简单。

NB_MODULE(_ext, m) {
     m.doc() = "Sample extension for MLX";

     m.def(
         "axpby",
         &axpby,
         "x"_a,
         "y"_a,
         "alpha"_a,
         "beta"_a,
         nb::kw_only(),
         "stream"_a = nb::none(),
         R"(
             Scale and sum two vectors element-wise
             ``z = alpha * x + beta * y``

             Follows numpy style broadcasting between ``x`` and ``y``
             Inputs are upcasted to floats if needed

             Args:
                 x (array): Input array.
                 y (array): Input array.
                 alpha (float): Scaling factor for ``x``.
                 beta (float): Scaling factor for ``y``.

             Returns:
                 array: ``alpha * x + beta * y``
         )");
 }

上述示例中的大部分复杂性来自于额外的装饰,例如字面名称和文档字符串。

警告

mlx.core 必须在导入 mlx_sample_extensions 之前导入,如上面 nanobind 模块所定义,以确保 mlx.core 组件(如 mlx.core.array)的转换器可用。

使用CMake构建#

构建C++扩展库只需要你find_package(MLX CONFIG)然后将其链接到你的库。

# Add library
add_library(mlx_ext)

# Add sources
target_sources(
    mlx_ext
    PUBLIC
    ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
)

# Add include headers
target_include_directories(
    mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}
)

# Link to mlx
target_link_libraries(mlx_ext PUBLIC mlx)

我们还需要构建附带的Metal库。为了方便起见,我们提供了一个mlx_build_metallib()函数,该函数根据给定的源文件、头文件、目标路径等构建一个.metallib目标(在cmake/extension.cmake中定义,并自动与MLX包一起导入)。

以下是实际应用中的样子:

# Build metallib
if(MLX_BUILD_METAL)

mlx_build_metallib(
    TARGET mlx_ext_metallib
    TITLE mlx_ext
    SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
    INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
    OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
)

add_dependencies(
    mlx_ext
    mlx_ext_metallib
)

endif()

最后,我们构建了nanobind绑定

nanobind_add_module(
  _ext
  NB_STATIC STABLE_ABI LTO NOMINSIZE
  NB_DOMAIN mlx
  ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
)
target_link_libraries(_ext PRIVATE mlx_ext)

if(BUILD_SHARED_LIBS)
  target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
endif()

使用setuptools#构建

一旦我们按照上述方式设置了CMake构建规则,我们就可以使用mlx.extension中定义的构建工具:

from mlx import extension
from setuptools import setup

if __name__ == "__main__":
    setup(
        name="mlx_sample_extensions",
        version="0.0.0",
        description="Sample C++ and Metal extensions for MLX primitives.",
        ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
        cmdclass={"build_ext": extension.CMakeBuild},
        packages=["mlx_sample_extensions"],
        package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
        extras_require={"dev":[]},
        zip_safe=False,
        python_requires=">=3.8",
    )

注意

我们将extensions/mlx_sample_extensions视为包目录,即使它只包含一个__init__.py以确保以下内容:

  • mlx.core 必须在导入 _ext 之前导入

  • C++扩展库和metal库与Python绑定位于同一位置,如果安装了该包,它们将一起被复制。

要构建包,首先使用pip install -r requirements.txt安装构建依赖项。然后,您可以使用python setup.py build_ext -j8 --inplace(在extensions/中)进行开发中的就地构建。

这导致了目录结构:

扩展
├── mlx_sample_extensions
│ ├── __init__.py
│ ├── libmlx_ext.dylib # C++扩展库
│ ├── mlx_ext.metallib # 金属库
│ └── _ext.cpython-3x-darwin.so # Python 绑定
请提供需要翻译的html内容。

当你尝试使用命令 python -m pip install . 进行安装时(在 extensions/ 目录下),该包将以与 extensions/mlx_sample_extensions 相同的结构安装,并且 C++ 和 Metal 库将与 Python 绑定一起复制,因为它们被指定为 package_data

用法#

按照上述方法安装扩展后,您应该能够简单地导入Python包并像使用其他MLX操作一样使用它。

让我们来看一个简单的脚本及其结果:

import mlx.core as mx
from mlx_sample_extensions import axpby

a = mx.ones((3, 4))
b = mx.ones((3, 4))
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)

print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c correct: {mx.all(c == 6.0).item()}")

输出:

c shape: [3, 4]
c dtype: float32
c correctness: True

结果#

让我们运行一个快速的基准测试,看看我们新的axpby操作与我们最初在CPU上定义的简单simple_axpby()相比如何。

import mlx.core as mx
from mlx_sample_extensions import axpby
import time

mx.set_default_device(mx.cpu)

def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
    return alpha * x + beta * y

M = 256
N = 512

x = mx.random.normal((M, N))
y = mx.random.normal((M, N))
alpha = 4.0
beta = 2.0

mx.eval(x, y)

def bench(f):
    # Warm up
    for i in range(100):
        z = f(x, y, alpha, beta)
        mx.eval(z)

    # Timed run
    s = time.time()
    for i in range(5000):
        z = f(x, y, alpha, beta)
        mx.eval(z)
    e = time.time()
    return e - s

simple_time = bench(simple_axpby)
custom_time = bench(axpby)

print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")

结果是 Simple axpby: 0.114 s | Custom axpby: 0.109 s。我们立即看到了适度的改进!

此操作现在可以很好地用于构建其他操作,在 mlx.nn.Module 调用中,也可以作为图形转换的一部分,如 grad()

脚本#

下载代码

完整的示例代码可在mlx中找到。