JAX: 高性能数组计算

目录

JAX: 高性能数组计算#

JAX 是一个面向加速器的数组计算和程序转换的 Python 库,专为高性能数值计算和大规模机器学习设计。

熟悉的API

JAX 提供了一个熟悉的 NumPy 风格 API,以便研究人员和工程师易于采用。

变换

JAX 包含了用于编译、批处理、自动微分和并行化的可组合函数变换。

随处运行

相同的代码可以在多个后端执行,包括CPU、GPU和TPU。

入门指南
开始使用 JAX
用户指南
用户指南
开发者笔记
开发者笔记

如果你想训练神经网络,使用 Flax 并从其教程开始。对于一个基于 JAX 的端到端 transformer 库,请参见 MaxText

生态系统#

JAX 本身范围较窄,专注于高效的数组操作和程序转换。围绕 JAX 构建的是一个不断发展的机器学习和数值计算工具生态系统;以下只是其中的一小部分示例:

神经网络

优化器与求解器

杂项工具

概率编程

物理与模拟

已经开发了许多基于JAX的库;社区维护的 Awesome JAX 页面保持了一个最新的列表。