优化器包装器#

apply_if_finite(inner, max_consecutive_errors)

一个包装优化器的函数,使其对少量的NaNs或Infs具有鲁棒性。

ApplyIfFiniteState(notfinite_count, ...)

apply_if_finite返回的GradientTransformation的状态。

flatten(inner)

扁平化参数和梯度,以初始化和更新内部变换。

lookahead(fast_optimizer, sync_period, ...)

前瞻优化器。

LookaheadParams(fast, slow)

持有用于前瞻优化器的慢速和快速参数对。

LookaheadState(fast_state, steps_since_sync)

lookahead返回的GradientTransformation的状态。

masked(inner, mask, *[, ...])

掩码更新,使得只有部分被转换,其余的直接通过。

MaskedState(inner_state)

维护掩码转换的内部变换状态。

MultiSteps(opt, every_k_schedule[, ...])

一个优化器包装器,用于在多个步骤中累积梯度。

MultiStepsState(mini_step, gradient_step, ...)

MultiSteps 返回的 GradientTransformation 的状态。

ShouldSkipUpdateFunction(*args, **kwargs)

skip_large_updates(updates, gradient_step, ...)

如果updates的全局范数平方足够小,则返回True。

skip_not_finite(updates, gradient_step, params)

如果任何一个updates包含无穷大或NaN,则返回True。

如果有限则申请#

optax.apply_if_finite(inner: optax.GradientTransformation, max_consecutive_errors: int) optax.GradientTransformation[来源]#

一个将优化器包装起来的函数,使其对一些NaN或Infs具有鲁棒性。

这个函数的目的是防止在梯度包含NaNs或Infs时发生任何优化。也就是说,当在梯度中检测到NaN或Inf时,包装的优化器会忽略该梯度更新。如果在给定数量的更新后NaNs或Infs仍然存在,则包装的优化器会放弃并接受该更新。

Parameters:
  • inner – 要包装的内部转换。

  • max_consecutive_errors – 最大连续梯度更新次数,其中包含NaNs或Infs,包装的优化器将忽略这些更新。在忽略这么多次更新后,优化器将放弃并接受。

Returns:

新的 optax.GradientTransformationExtraArgs.

class optax.ApplyIfFiniteState(notfinite_count: Any, last_finite: Any, total_notfinite: Any, inner_state: Any)[来源]#

apply_if_finite返回的GradientTransformation的状态。

notfinite_count#

包含Inf或NaN的连续梯度更新的数量。当进行一次没有Inf或NaN的梯度更新时,此数字会重置为0。

Type:

任何

last_finite#

最后的梯度更新是否包含Inf或NaN。

Type:

任意

total_notfinite#

自该优化器初始化以来,包含Inf或NaN的总梯度更新次数。此数字从不重置。 inner_state:内层 GradientTransformation 的状态。

Type:

任何

扁平化#

optax.flatten(inner: 基础.渐变变换) 基础.渐变变换额外参数[来源]#

扁平化参数和梯度以用于内变换的初始化和更新。

这可以减少对许多小变量进行大量计算的开销,但代价是略微增加了内存使用。

Parameters:

inner – 内部转换,用于扁平化输入。

Returns:

新的 optax.GradientTransformationExtraArgs

前瞻#

optax.lookahead(fast_optimizer: optax.GradientTransformation, sync_period: int, slow_step_size: float, reset_state: bool = False) optax.GradientTransformation[来源]#

前瞻优化器。

使用快速优化器执行步骤,并定期更新一组慢参数。可选地在通过调用快速优化器的初始化函数进行同步后重置快速优化器状态。

由前瞻优化器返回的更新在应用之前不应被修改,否则快速参数和慢速参数将无法正确同步。

Parameters:
  • fast_optimizer – 在前瞻的内部循环中使用的优化器。

  • sync_period – 在同步参数之前要进行的快速优化器步骤的数量。必须 >= 1。

  • slow_step_size – 慢参数更新的步长。

  • reset_state – 是否在每次同步后重置快速优化器的优化器状态。

Returns:

A optax.GradientTransformation 具有初始化和更新功能。传递给更新功能的更新应仅使用快速预览参数计算。

参考文献

张等人,Lookahead Optimizer: k steps forward, 1 step back,2019

class optax.LookaheadParams(fast: optax.Params, slow: optax.Params)[来源]#

保存一对用于前瞻优化器的慢速和快速参数。

梯度应该始终使用快速参数进行计算。慢参数应用于测试和推断,因为它们具有更好的泛化能力。有关详细讨论,请参见参考文献。

fast#

快速参数。

Type:

base.Params

slow#

慢参数。

Type:

base.Params

参考文献

Zhang et al, Lookahead Optimizer: k steps forward, 1 step back, 2019

class optax.LookaheadState(fast_state: base.OptState, steps_since_sync: jnp.ndarray)[来源]#

lookahead返回的GradientTransformation的状态。

fast_state#

快速优化器的优化器状态。

Type:

base.OptState

steps_since_sync#

自从慢速和快速参数同步以来,采取的快速优化器步骤的数量。

Type:

jnp.ndarray

掩码更新#

optax.masked(inner: 基础.渐变变换, mask: base.PyTree | Callable[[base.Params], base.PyTree], *, mask_compatible_extra_args: bool = False) 基础.渐变变换额外参数[来源]#

掩码更新,使得只有部分被转换,其余部分被传递。

例如,通常会跳过BatchNorm缩放和所有偏置参数的权重衰减。因为在许多网络中,这些是唯一的1D参数,您可以例如创建一个掩码函数来将它们屏蔽,如下所示:

mask_fn = lambda p: jax.tree.map(lambda x: x.ndim != 1, p)
weight_decay = optax.masked(optax.add_decayed_weights(0.001), mask_fn)

您可以选择提前创建掩码 pytree:

mask = jax.tree.map(lambda x: x.ndim != 1, params)
weight_decay = optax.masked(optax.add_decayed_weights(0.001), mask)

对于 inner 转换,状态将仅对具有 True 的掩码值的参数进行存储。

请注意,当使用 tree_map_params 时,可能需要传递参数 is_leaf=lambda v: isinstance(v, optax.MaskedNode),如果树映射需要额外的参数并且形状与原始输入树相同。

Parameters:
  • inner – 内部转换以进行遮罩。

  • mask – 一个与params PyTree具有相同结构(或是其前缀)的PyTree,或者是一个可调用的对象,它根据params/updates返回这样的pytree。叶子应该是布尔值,True表示您希望应用转换的叶子/子树,而False表示您希望跳过的部分。对于梯度转换来说,mask必须是静态的,以便能够进行jit编译。

  • mask_compatible_extra_args – 是否对具有与params/updates相同树结构的extra_arg字段应用相同的掩码。

Returns:

新的 optax.GradientTransformationExtraArgs 包装 inner

class optax.MaskedState(inner_state: Any)[来源]#

维护掩码变换的内部变换状态。

多步骤更新#

class optax.MultiSteps(opt: optax.GradientTransformation, every_k_schedule: int | Callable[[TypeAliasForwardRef('chex.Array')], TypeAliasForwardRef('chex.Array')], use_grad_mean: bool = True, should_skip_update_fn: 应跳过更新功能 | None = None)[来源]#

一个优化器包装器,用于在多个步骤中累积梯度。

此包装器将传递给其 update 函数的更新在连续步骤中收集起来,直到达到给定数量的计划步骤。在这些中间步骤中,优化器返回的值是与作为输入传递的更新形状相同的零树。

一旦达到预定数量的中间“微步骤”,当前时间累积的梯度将传递给包装优化器的更新函数(内部优化器的状态适当更新),然后返回给调用者。包装器的累积梯度随后被重置为零,并且该过程重新开始。

每次梯度更新的迷你步骤数由一个函数控制,并且在训练过程中可以变化,这也允许在训练过程中改变批量大小。

class optax.MultiStepsState(mini_step: chex.Array, gradient_step: chex.Array, inner_opt_state: Any, acc_grads: Any, skip_state: chex.ArrayTree = ())[来源]#

MultiSteps返回的GradientTransformation的状态。

mini_step#

当前的小步计数器。在更新时,这个计数器要么增加1,要么重置为0。

Type:

chex.Array

gradient_step#

梯度步骤计数器。仅在积累了足够的小步骤后才会增加。

Type:

chex.Array

inner_opt_state#

包装优化器的状态。

Type:

任何

acc_grads#

多个小步骤中的累积梯度。

Type:

任何

skip_state#

一个任意的 py 树。这在传递 should_skip_update_fnMultiSteps 时是唯一相关的。

Type:

chex.ArrayTree

class optax.ShouldSkipUpdateFunction(*args, **kwargs)[来源]#
optax.skip_large_updates(updates: optax.Updates, gradient_step: chex.Array, params: TypeAliasForwardRef('optax.Params') | None, max_squared_norm: float) tuple[TypeAliasForwardRef('chex.Array'), TypeAliasForwardRef('chex.ArrayTree')][来源]#

如果updates的全局范数平方足够小,则返回True。

Parameters:
  • 更新 – 请参见 ShouldSkipUpdateFunction.

  • gradient_step – 请参见 ShouldSkipUpdateFunction.

  • params – 参见 ShouldSkipUpdateFunction

  • max_squared_norm – 更新中可以接受的最大平方范数。

Returns:

  • 第一个元素是布尔类型的标量数组。

  • 第二个元素是一个字典,包含以下键: - should_skip: 当 ||updates||^2 大于 max_squared_norm 时。 - norm_squared: updates 的总体平方范数。

Return type:

一个元组

optax.skip_not_finite(updates: optax.Updates, gradient_step: chex.Array, params: TypeAliasForwardRef('optax.Params') | None) tuple[TypeAliasForwardRef('chex.Array'), TypeAliasForwardRef('chex.ArrayTree')][来源]#

如果任何一个 updates 包含 inf 或 NaN,则返回 True。

Parameters:
  • 更新 – 请参见 ShouldSkipUpdateFunction

  • gradient_step – 参见 ShouldSkipUpdateFunction

  • 参数 – 请参见 ShouldSkipUpdateFunction

Returns:

  • 第一个元素是一个布尔型的标量数组。

  • 第二个元素是一个字典,键包括: - should_skip: 当且仅当 updates 包含无穷大或 NaN 时为 True。 - num_not_finite: 在 updates 中找到的无穷大和 NaN 的总数。

Return type:

一个元组