应用更新#
|
将更新应用到相应的参数。 |
|
通过Polyak平均法增量更新参数。 |
|
定期用新值更新所有参数。 |
应用更新#
- optax.apply_updates(params: optax.Params, updates: optax.Updates) optax.Params[来源]#
对相应的参数应用更新。
这是一个实用函数,它对一组参数进行更新,然后将更新后的参数返回给调用者。作为一个例子,更新可能是由一系列`GradientTransformations`变换的梯度。这个函数是为了方便而暴露的,但它只是添加更新和参数;你也可以手动对参数应用更新,使用jax.tree.map(例如,如果你想在应用之前以自定义方式操作更新)。
- Parameters:
参数 – 一棵参数树。
updates – 一个更新的树,树的结构和叶子节点的形状必须与params匹配。
- Returns:
更新参数,与 params 具有相同的结构、形状和类型。
增量更新#
- optax.incremental_update(new_tensors: optax.Params, old_tensors: optax.Params, step_size: chex.Numeric) optax.Params[来源]#
通过波列克平均法增量更新参数。
Polyak平均跟踪模型过去参数的(指数移动)平均值,以便在测试/评估时使用。
- Parameters:
new_tensors – 张量的最新值。
old_tensors – 张量值的移动平均。
step_size – 用于在每一步更新polyak平均值的步长。
- Returns:
更新的移动平均 step_size*new+(1-step_size)*old 的参数。
参考文献
[Polyak et al, 1991](https://epubs.siam.org/doi/10.1137/0330046)
定期更新#
- optax.periodic_update(new_tensors: optax.Params, old_tensors: optax.Params, steps: chex.Array, update_period: int) optax.Params[来源]#
定期使用新值更新所有参数。
模型参数的慢速复制,每 K 次实际更新更新一次,可以用于实现自我监督的形式(在有监督学习中),或用于稳定时间差分学习更新(在强化学习中)。
- Parameters:
new_tensors – 张量的最新值。
old_tensors – 模型参数的慢速复制。
步骤 – 在“在线”网络上的更新步骤数。
update_period – 每多少步更新一次“目标”网络。
- Returns:
模型参数的慢速复制,每隔update_period步骤更新一次。
参考文献
[Grill et al., 2020](https://arxiv.org/abs/2006.07733) [Mnih et al., 2015](https://arxiv.org/abs/1312.5602)