在 JAX 中排序副作用#

sharadmv@ 2022年5月9日

概述#

当我们编写 JAX 代码时,我们通常可以假装自己在编写单线程、急切执行的 Python 代码,尽管在底层,JAX 及其运行时可能会在后台异步执行它。只要我们编写纯(无副作用)代码,这些性能优化通常对我们来说是不可见的,并且不会干扰我们的单线程思维模型。异步执行很棒——我们无需考虑就能获得高性能、并行代码!

然而,在存在副作用的情况下,这种幻觉开始瓦解,我们心理模型的裂缝开始显现。具体来说,当我们考虑副作用发生的顺序时,这些差异就会显现出来。

在这篇设计笔记中,我们探讨了JAX执行模型与副作用顺序之间的关系。我们还提供了一种强制执行“单线程”副作用顺序的方法。

背景#

当我们编写以下Python代码时

def f():
  print("hello")
  return 2
def g():
  print("world")
  return 3
f()
g()

我们期望在 "world" 之前打印 "hello"。这可能看起来很明显,但请考虑以下 JAX 代码:

@partial(jax.jit, device=<device 0>)
def f():
  return 2

@partial(jax.jit, device=<device 1>)
def g():
  return 3
f()
g()

在许多情况下,JAX 会并行执行 fg,将计算分派到不同的线程上——g 实际上可能在 f 之前执行。并行执行是一种很好的性能优化,特别是如果设备之间的复制成本很高(更多详情请参见 异步分派说明)。然而,在实践中,我们通常不需要考虑异步分派,因为我们编写的是纯函数,只关心函数的输入和输出——我们自然会在未来的值上阻塞。

然而,现在假设我们有一个 jax.print 函数,它可以在 JIT 编译的 JAX 函数内部工作(host_callback.id_print 是这种函数的一个例子)。让我们回到之前的例子,只是这次加入了打印功能。

@partial(jax.jit, device=<device 0>)
def f():
  jax.print("hello")
  return 2

@partial(jax.jit, device=<device 1>)
def g():
  jax.print("world")
  return 3
f()
g()

由于异步分派,我们实际上可以看到 "world""hello" 之前被打印出来。打印副作用的重排打破了单线程执行模型的假象。

另一个可以“揭示”乱序执行副作用的例子是当我们编译JAX程序时。考虑以下JAX代码:

@jax.jit
def f(x):
  jax.print("hello")
  jax.print("world")
  return x

尽管在Python中,我们在打印 "world" 之前写了打印 "hello",但像XLA这样的编译器可以自由地重新排序它们,因为打印之间没有明确的数据依赖关系。

动机#

我们希望支持“有序”效果。当我们说有序时,我们指的是效果发生的顺序与我们执行单线程 Python 程序时的顺序相同。这是我们的主要需求。在存在显式并行性(如 pmap 或用户线程)的情况下,我们不需要保持这种行为,但至少如果用户没有显式请求并行性,我们希望保留单线程顺序。

在我们深入探讨之前,让我们先退一步,问问自己,为了性能,我们是否可以重新排序效果,反之,我们是否需要完全强制执行效果的顺序?在某些情况下,我们不需要排序。也许某些副作用不应该对JAX程序的性能产生不利影响。然而,对于其他副作用,我们可能希望强制执行单线程程序顺序,以便用户不会遇到反直觉的行为。考虑一个日志记录效果。

@jax.jit
def f(x, y):
  log_value(x)
  log_value(y)
f(1, 2)

如果 log 正在修改一个全局列表,我们可能会期望在添加 y 之前先添加 x。为了获得更严格的效果,我们可能希望有一个选项来排序这些效果。

强制有序效果#

我们用来强制执行计算顺序的主要工具是 数据依赖。简单来说,如果一个函数 g 的输入是函数 f 的输出,那么 f 必须在 g 之前执行。

然而,我们可能会有像打印这样的副作用,它们完全没有输入,因此我们无法简单地对它们进行排序。因此,我们使用 tokens 作为一种向计算中注入人工数据依赖的手段。

什么是令牌?令牌只是一个可以在计算中传入和传出的虚拟值。通过在多个计算中传入和传出相同的令牌,我们强制它们必须按某种顺序发生。让我们以前面的打印示例为例,看看如果加入令牌会是什么样子:

@jax.jit
def f(token, x):
  token = jax.print(token, "hello")
  token = jax.print(token, "world")
  return token, x

如果我们重写 jax.print 使其接受并返回一个标记,我们现在已经将两个打印操作串联起来,因为第二个打印的输入依赖于第一个打印的输出。token 的实际值可以是任何东西,但实际上我们会看到,这些标记对用户来说是不可见的。

运行时令牌 vs. 编译器令牌#

在这里,我们将实际开始讨论实现细节。在实践中,我们将需要两种不同类型的令牌来序列化效果:每种令牌对应上述重新排序的来源之一。我们需要 运行时令牌 来序列化异步分派的副作用计算,并且我们需要 编译器令牌 来序列化计算中的效果。

在实践中,我们的计算将被重写为如下形式:

@jax.jit
def f(runtime_token, x):
  compiler_token = new_compiler_token()
  compiler_token = jax.print(compiler_token, "hello")
  compiler_token = jax.print(compiler_token, "world")
  return runtime_token, x

注意运行时令牌仅在JIT边界使用,而编译器令牌仅在编译代码中使用。编译器令牌在“降低”过程中创建(我们将Python代码转换为较低级别的表示形式,如HLO或StableHLO),但运行时令牌需要在Python中管理,因为它们在JIT函数中被传递进出。

此外,请注意运行时令牌与编译器令牌是“断开”的,这意味着它们之间没有数据依赖关系。这可能会带来潜在的危险,因为如果我们失去了两个调度函数调用体之间的数据依赖关系。然而,如果我们假设“严格执行”——即一个调度的函数只有在所有输入都准备好并且其所有输出将在同一时间准备好时才会开始执行——我们可以安全地创建一个新的编译器令牌并返回一个非输出依赖的运行时令牌。

管理运行时令牌#

为了代表用户管理运行时令牌,我们需要挂钩到 JAX 的调度机制中。每当我们调用一个 JIT 函数时,最终都会进入一个类似这样的函数:

def _execute(compiled_computation, *args):
  outputs = compiled_computation.execute(*args)
  return outputs

在这一点上,我们需要将运行时令牌“注入”到计算中,并从计算的输出中“提取”它们:

def _execute(compiled_computation, *args):
  runtime_token = get_runtime_token() # Grab global token
  runtime_token, *outputs = compiled_computation.execute(runtime_token, *args)
  update_runtime_token(runtime_token) # Update global token
  return outputs

runtime_token 到底是什么?我们需要能够将其传递给 compiled_computation,这意味着它需要是某种数组(目前,因为在编译的 JAX 代码内部和外部没有共享的令牌表示)。实际上,我们可以使用 (0,) 形状的数组来最小化开销。

我们还需要考虑多设备使用场景,例如第一个示例中,我们首先在设备0上调用一个JIT编译的函数,然后在设备1上调用另一个函数。在这种情况下,我们需要将第一个计算返回的运行时令牌(存在于设备0上)复制到设备1,以便我们可以将其传递到第二个计算中。如果两个连续的计算共享同一个设备,则不需要此复制操作。

添加编译器标记#

当我们将Python代码降低到HLO或StableHLO时,我们需要在计算开始时创建一个令牌,并确保在需要排序的副作用计算中可用。副作用计算将把令牌作为输入并将其作为输出返回。

这个令牌线程的实现涉及升级 JAX 的降低机制,以自动完成这种簿记工作。主要挑战涉及处理高阶原语,如调用原语和控制流原语。我们不会在这篇设计笔记中详细讨论如何处理这些。

阻塞输出令牌#

为副作用计算添加对运行时和编译器令牌的支持对于排序很重要,但令牌还有另一个微妙的用例,即阻塞副作用计算。即使我们不希望副作用计算被排序,我们可能仍然希望等待其完成。目前我们有 jax.block_until_ready,它会等待直到未来值的结果准备好。然而,对于副作用计算,我们可能有一些没有返回值但仍在执行副作用的函数。以这里的一个简单例子为例:

@jax.jit
def f():
  jax.print("hello world")
  return
f() # Executed asynchronously

这个编译的计算没有显式的输入和输出。如果它是一个有序的打印效果,我们可以阻塞返回的运行时令牌。然而,当这是一个无序计算时,我们不做任何令牌线程。当我们没有输出值来调用 block_until_ready 时,我们如何等待 f() 完成执行?我们可以应用相同的令牌策略,只是我们只返回运行时令牌,不把它们作为输入。这将给我们一个值来阻塞,只有在 f() 完成执行后才会准备好。我们称这些令牌为 输出令牌。我们最终得到一个看起来像这样的函数:

@jax.jit
def f():
  jax.print("hello world")
  return new_runtime_token()
f() # Executed asynchronously

在底层,我们将以管理运行时令牌的相同方式管理输出令牌,但提供一种方法让用户阻塞当前的输出令牌集合。与运行时令牌不同,输出令牌需要是 设备特定的。考虑单一设备的使用场景:

@jax.jit
def f():
  jax.print("hello")

@jax.jit
def g():
  jax.print("world")

f()
g()

由于 f()g() 在同一设备上执行,阻塞 g() 的输出令牌实际上会阻塞 f(),因为(截至目前!)JAX 运行时不支持在同一设备上交错执行计算。当然,如果这一点发生变化,我们将不得不重新设计整个方案。

然而,考虑这两种设备使用情况:

@partial(jax.jit, device=<device 0>)
def f():
  jax.print("hello")

@partial(jax.jit, device=<device 1>)
def g():
  jax.print("world")

f()
g()

在这里,我们不想显式地顺序执行 f()g(),而是希望等待它们都完成。我们需要为 f()g() 各准备一个输出令牌,并且我们将阻塞在这两个令牌上:

@partial(jax.jit, device=<device 0>)
def f():
  jax.print("hello")
  return new_runtime_token()

@partial(jax.jit, device=<device 1>)
def g():
  jax.print("world")
  return new_runtime_token()

t0 = f()
t1 = g()
block_until_ready((t0, t1))

因此,我们需要一个每个设备的输出令牌,以便在提供阻止副作用计算的能力的同时,避免在不同设备上进行序列化计算。我们最终得到了以下(近似)对JAX调度机制的更改:

def _execute(compiled_computation, *args):
  output_token, *outputs = compiled_computation.execute(runtime_token, *args)
  update_output_token(output_token, compiled_computation.device)
  return outputs

我们还需要公开一个函数,该函数会阻塞在输出令牌上:

def effects_barrier():
  output_token.block_until_ready()

需要注意的是,由于大多数 JAX 计算会返回一个值来阻塞,因此阻塞输出标记可能并不常见。然而,输出标记对于测试和性能分析非常有帮助,并且支持它们是好的,这样我们就能有一个一致且连贯的效果系统。

更多细节#

  • 上述所有令牌管理基础设施都将是 线程局部 的。这意味着每个用户线程将拥有自己独立的运行时令牌流。排序仅在用户线程级别得到保证。

  • 在实践中,我们为每个效果设置一个运行时令牌。该效果的不同实例将被排序。这是为了避免对可能彼此无关的效果计算进行排序。从技术上讲,这违背了我们最初强制执行单线程Python程序顺序的目标,但这是一个可以通过同时使用“效果”特定令牌和“全局”令牌来调节的权衡。