应用更新#

apply_updates(params, updates)

将更新应用到相应的参数。

incremental_update(new_tensors, old_tensors, ...)

通过Polyak平均法增量更新参数。

periodic_update(new_tensors, old_tensors, ...)

定期用新值更新所有参数。

应用更新#

optax.apply_updates(params: optax.Params, updates: optax.Updates) optax.Params[来源]#

对相应的参数应用更新。

这是一个实用函数,它对一组参数进行更新,然后将更新后的参数返回给调用者。作为一个例子,更新可能是由一系列`GradientTransformations`变换的梯度。这个函数是为了方便而暴露的,但它只是添加更新和参数;你也可以手动对参数应用更新,使用jax.tree.map(例如,如果你想在应用之前以自定义方式操作更新)。

Parameters:
  • 参数 – 一棵参数树。

  • updates – 一个更新的树,树的结构和叶子节点的形状必须与params匹配。

Returns:

更新参数,与 params 具有相同的结构、形状和类型。

增量更新#

optax.incremental_update(new_tensors: optax.Params, old_tensors: optax.Params, step_size: chex.Numeric) optax.Params[来源]#

通过波列克平均法增量更新参数。

Polyak平均跟踪模型过去参数的(指数移动)平均值,以便在测试/评估时使用。

Parameters:
  • new_tensors – 张量的最新值。

  • old_tensors – 张量值的移动平均。

  • step_size – 用于在每一步更新polyak平均值的步长。

Returns:

更新的移动平均 step_size*new+(1-step_size)*old 的参数。

参考文献

[Polyak et al, 1991](https://epubs.siam.org/doi/10.1137/0330046)

定期更新#

optax.periodic_update(new_tensors: optax.Params, old_tensors: optax.Params, steps: chex.Array, update_period: int) optax.Params[来源]#

定期使用新值更新所有参数。

模型参数的慢速复制,每 K 次实际更新更新一次,可以用于实现自我监督的形式(在有监督学习中),或用于稳定时间差分学习更新(在强化学习中)。

Parameters:
  • new_tensors – 张量的最新值。

  • old_tensors – 模型参数的慢速复制。

  • 步骤 – 在“在线”网络上的更新步骤数。

  • update_period – 每多少步更新一次“目标”网络。

Returns:

模型参数的慢速复制,每隔update_period步骤更新一次。

参考文献

[Grill et al., 2020](https://arxiv.org/abs/2006.07733) [Mnih et al., 2015](https://arxiv.org/abs/1312.5602)