组合优化器#
|
应用一系列可链接的更新转换。 |
|
应用一系列可命名的可链接更新转换。 |
|
将参数分区,并对每个子集应用不同的转换。 |
链#
- optax.chain(*args: 基础.渐变变换) 基础.渐变变换额外参数[来源]#
应用一系列可链接的更新转换。
此函数创建一个新的
optax.GradientTransformation(),按顺序应用一系列梯度转换。新的转换的init函数通过连接各个转换的状态构造优化器状态,而update函数按给定顺序应用更新。- Parameters:
*args – 任意数量的
transform-s ofGradientTransformation或GradientTransformationExtraArgs。- Returns:
A
GradientTransformationExtraArgs,通过链接输入变换创建。请注意,无论参数类型如何,生成的变换始终支持额外的参数。传递给返回变换的任何额外参数将仅传递给支持额外参数的链中的那些变换。
示例
一个缩放 -0.1 的 adam 更新的变换:
>>> import optax >>> transform1 = optax.scale_by_adam() >>> transform2 = optax.scale(-0.1) >>> chained_transform = optax.chain(transform1, transform2) >>> params = {'a': 1.0} >>> state = chained_transform.init(params) >>> updates = {'a': -0.5} >>> updates, new_state = chained_transform.update(updates, state, params)
链中的优化器可能需要额外的参数:
>>> import optax >>> opt1 = optax.scale(0.1) # scale incoming gradients >>> opt2 = optax.polyak_sgd() # requires a `value` extra arg for `update` >>> chained_transform = optax.chain(opt1, opt2) >>> state = chained_transform.init(0.5) >>> extra_args = {"value": 1.0} >>> updates, new_state = chained_transform.update( ... 0.7, state, 0.7, **extra_args # extra args for all transforms ... )
- optax.named_chain(*args: tuple[str, 基础.渐变变换]) 基础.渐变变换额外参数[来源]#
应用一系列命名的可链式更新转换。
一个变体的
optax.chain(),允许为每个转换命名。Here the
argsare(name, transformation)pairs, constituted of a stringnameand an associated transformationtransformation. The gradient transformation must be an instance ofGradientTransformationorGradientTransformationExtraArgs.每个
name被用作相应转换在named_chain状态中的键。因此,具有给定name的转换状态可以轻松检索为opt_state[name]。- Parameters:
*args – an arbitrary number of
(name, transform)pairs, constituted of a stringnameand an associated transformationtransform. The latter is aGradientTransformationorGradientTransformationExtraArgs.- Returns:
一个单一的 (init_fn, update_fn) 元组。
示例
>>> import optax >>> opt1 = optax.scale(0.1) # scale incoming gradients >>> opt2 = optax.polyak_sgd() # requires a `value` extra arg for `update` >>> chained_transform = optax.named_chain(("scale", opt1), ("sgd", opt2)) >>> state = chained_transform.init(0.5) >>> extra_args = {"value": 1.0} >>> updates, new_state = chained_transform.update( ... 0.7, state, 0.7, **extra_args # extra args for all transforms ... ) >>> tuple(new_state.keys()) == ("scale", "sgd") True
多重变换#
- optax.multi_transform(transforms: Mapping[Hashable, 基础.渐变变换], param_labels: base.PyTree | Callable[[base.PyTree], base.PyTree], *, mask_compatible_extra_args: bool = False) 基础.渐变变换额外参数[来源]#
对分区参数进行分割,并对每个子集应用不同的转换。
有时候您可能希望对不同的参数应用不同的变换。例如,您可能想对神经网络的权重应用Adam,但对偏置使用SGD。此功能允许您这样做。
- Parameters:
transforms – 从标签到变换的映射。每个变换仅会应用于具有相同标签的参数。
param_labels – 一个与参数/更新相同形状或作为前缀的PyTree(或一个根据参数作为输入返回一个的函数)。这个PyTree的叶子对应于变换的键(因此叶子上的值必须是键的一个子集)。
mask_compatible_extra_args – 是否将相同的掩码应用于与 params/updates 具有相同树结构的 extra_arg 字段。
- Returns:
一个
optax.GradientTransformationExtraArgs()实现了init和update函数。
示例
下面是一个应用Adam到权重,SGD到偏置的2层神经网络的例子:
>>> import optax >>> import jax >>> import jax.numpy as jnp >>> def map_nested_fn(fn): ... '''Recursively apply `fn` to key-value pairs of a nested dict.''' ... def map_fn(nested_dict): ... return {k: (map_fn(v) if isinstance(v, dict) else fn(k, v)) ... for k, v in nested_dict.items()} ... return map_fn >>> params = {'linear_1': {'w': jnp.zeros((5, 6)), 'b': jnp.zeros(5)}, ... 'linear_2': {'w': jnp.zeros((6, 1)), 'b': jnp.zeros(1)}} >>> gradients = jax.tree.map(jnp.ones_like, params) # dummy gradients >>> label_fn = map_nested_fn(lambda k, _: k) >>> tx = optax.partition( ... {'w': optax.adam(1.0), 'b': optax.sgd(1.0)}, label_fn) >>> state = tx.init(params) >>> updates, new_state = tx.update(gradients, state, params) >>> new_params = optax.apply_updates(params, updates)
您可以直接提供标签的
label_fn的 PyTree,而不是提供一个标签函数。这个 PyTree 也可以是参数 PyTree 的前缀。下面的 GAN 假代码对此进行了演示:>>> generator_params = ... >>> discriminator_params = ... >>> all_params = (generator_params, discriminator_params) >>> param_labels = ('generator', 'discriminator') >>> tx = optax.partition( >>> {'generator': optax.adam(0.1), 'discriminator': optax.adam(0.5)}, >>> param_labels)
如果您希望不优化某些参数,可以将
optax.partition()用optax.masked()包裹起来。