JAX: 高性能数组计算#
JAX 是一个面向加速器的数组计算和程序转换的 Python 库,专为高性能数值计算和大规模机器学习设计。
熟悉的API
JAX 提供了一个熟悉的 NumPy 风格 API,以便研究人员和工程师易于采用。
变换
JAX 包含了用于编译、批处理、自动微分和并行化的可组合函数变换。
随处运行
相同的代码可以在多个后端执行,包括CPU、GPU和TPU。
入门指南
用户指南
开发者笔记
如果你想训练神经网络,使用 Flax 并从其教程开始。对于一个基于 JAX 的端到端 transformer 库,请参见 MaxText。
生态系统#
JAX 本身范围较窄,专注于高效的数组操作和程序转换。围绕 JAX 构建的是一个不断发展的机器学习和数值计算工具生态系统;以下只是其中的一小部分示例:
优化器与求解器
已经开发了许多基于JAX的库;社区维护的 Awesome JAX 页面保持了一个最新的列表。