预测#

投影可以用来执行约束优化。 对集合 \(\mathcal{C}\) 的欧几里得投影是:

\[\text{proj}_{\mathcal{C}}(u) := \underset{v}{\text{argmin}} ~ ||u - v||^2_2 \textrm{ 受限于 } v \in \mathcal{C}.\]

例如,这里是一个我们如何将参数投影到非负正交的示例:

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> num_weights = 2
>>> xs = jnp.array([[-1.8, 2.2], [-2.0, 1.2]])
>>> ys = jnp.array([0.5, 0.8])
>>> optimizer = optax.adam(learning_rate=1e-3)
>>> params = {'w': jnp.zeros(num_weights)}
>>> opt_state = optimizer.init(params)
>>> loss = lambda params, x, y: jnp.mean((params['w'].dot(x) - y) ** 2)
>>> grads = jax.grad(loss)(params, xs, ys)
>>> updates, opt_state = optimizer.update(grads, opt_state)
>>> params = optax.apply_updates(params, updates)
>>> params = optax.projections.projection_non_negative(params)

可用的投影#

projection_box(tree, lower, upper)

投影到方框约束。

projection_hypercube(tree[, scale])

投影到(单位)超立方体。

projection_l1_ball(tree[, scale])

投影到l1球上。

projection_l1_sphere(tree[, scale])

投影到 l1 球面上。

projection_l2_ball(tree[, scale])

投影到l2球体。

projection_l2_sphere(tree[, scale])

投影到 l2 球面上。

projection_linf_ball(tree[, scale])

投影到l-infinity球面。

projection_non_negative(tree)

投影到非负象限。

projection_simplex(tree[, scale])

投影到一个单纯形上。

投影到一个盒子#

optax.projections.projection_box(tree: Any, lower: Any, upper: Any) Any[源]#

投影到盒子约束。

\[\underset{p}{\text{argmin}} ~ ||x - p||_2^2 \quad \textrm{约束条件} \quad \text{下限} \le p \le \text{上限}\]

其中 \(x\) 是输入树。

Parameters:
  • tree – 要投影的树。

  • lower – 下界,一个标量或与tree具有相同结构的树。

  • upper – 上界,一个标量或与 tree 具有相同结构的树。

Returns:

投影树,具有与 tree 相同的结构。

投影到超立方体#

optax.projections.projection_hypercube(tree: Any, scale: Any = 1.0) Any[来源]#

投影到(单位)超立方体。

\[\underset{p}{\text{argmin}} ~ ||x - p||_2^2 \quad \textrm{条件是} \quad 0 \le p \le \text{scale}\]

其中 \(x\) 是输入树。

默认情况下,我们投影到单位超立方体(scale=1.0)。

这是围绕 projection_box 的一个便利封装。

Parameters:
  • tree – 要投影的树。

  • scale – 超立方体的缩放,标量或树(默认值:1.0)。

Returns:

投影树,与 tree 具有相同结构。

投影到L1球面#

optax.projections.projection_l1_ball(tree: Any, scale: float = 1.0) Any[源]#

投影到l1球面上。

此函数解决以下约束优化问题,
其中 x 是输入树。

\[\underset{y}{\text{argmin}} ~ ||x - y||_2^2 \quad \textrm{subject to} \quad ||y||_1 \le \text{scale}\]
Parameters:
  • tree – 要投影的树。

  • scale – 球的半径。

Returns:

投影树,具有与 tree 相同的结构。

示例

>>> import jax.numpy as jnp
>>> from optax import tree_utils, projections
>>> tree = {"w": jnp.array([2.5, 3.2]), "b": 0.5}
>>> print(tree_utils.tree_l1_norm(tree))
6.2
>>> new_tree = projections.projection_l1_ball(tree)
>>> print(tree_utils.tree_l1_norm(new_tree))
1.0000002

在版本 0.2.4 中添加。

投影到L1球面#

optax.projections.projection_l1_sphere(tree: Any, scale: float = 1.0) Any[来源]#

投影到 l1 球面。

此函数解决以下约束优化问题,其中 x 是输入树。

\[\underset{y}{\text{argmin}} ~ ||x - y||_2^2 \quad \textrm{subject to} \quad ||y||_1 = \text{scale}\]
Parameters:
  • tree – 要投影的树。

  • scale – 球体的半径。

Returns:

投影树,与 tree 具有相同结构。

投影到L2球面#

optax.projections.projection_l2_ball(tree: Any, scale: float = 1.0) Any[来源]#

投影到l2球上。

这个函数解决了以下约束优化问题,x 是输入树。

\[\underset{y}{\text{argmin}} ~ ||x - y||_2^2 \quad \textrm{满足条件} \quad ||y||_2 \le \text{scale}\]
Parameters:
  • tree – 要投影的树。

  • scale – 球的半径。

Returns:

投影树,与 tree 具有相同结构。

在版本 0.2.4 中添加。

投影到L2球面#

optax.projections.projection_l2_sphere(tree: Any, scale: float = 1.0) Any[来源]#

投影到 l2 球面。

此函数解决以下约束优化问题,其中 x 是输入树。

\[\underset{y}{\text{argmin}} ~ ||x - y||_2^2 \quad \textrm{subject to} \quad ||y||_2 = \text{value}\]
Parameters:
  • tree – 要投影的树。

  • scale – 球体的半径。

Returns:

投影树,与 tree 具有相同结构。

在版本 0.2.4 中添加。

投影到L-infinity球#

optax.projections.projection_linf_ball(tree: Any, scale: float = 1.0) Any[源]#

投影到l-infinity球体上。

此函数解决以下约束优化问题,
其中 x 是输入树。

\[\underset{y}{\text{argmin}} ~ ||x - y||_2^2 \quad \textrm{满足} \quad ||y||_{\infty} \le \text{scale}\]
Parameters:
  • tree – 要投影的树。

  • scale – 球的半径。

Returns:

投影树,与 tree 具有相同结构。

投影到非负正交象限#

optax.projections.projection_non_negative(tree: Any) Any[源]#

投影到非负正交区域。

\[\underset{p}{\text{argmin}} ~ ||x - p||_2^2 \quad \textrm{条件是} \quad p \ge 0\]

其中 \(x\) 是输入树。

Parameters:

tree – 用于投影的树。

Returns:

投影树,与 tree 具有相同结构。

投影到一个单纯形#

optax.projections.projection_simplex(tree: Any, scale: chex.Numeric = 1.0) Any[源]#

投影到单纯形上。

这个函数解决了以下约束优化问题,x 是输入树。

\[\underset{p}{\text{argmin}} ~ ||x - p||_2^2 \quad \textrm{目标是} \quad p \ge 0, p^\top 1 = \text{scale}\]

默认情况下,投影是 onto 概率单纯形(单位单纯形)。

Parameters:
  • tree – 要投影的树。

  • scale – 投影树应该总和的值(默认值:1.0)。

Returns:

投影树,具有与 tree 相同结构的树。

示例

这里是一个使用树的示例:

>>> import jax.numpy as jnp
>>> from optax import tree_utils, projections
>>> tree = {"w": jnp.array([2.5, 3.2]), "b": 0.5}
>>> print(tree_utils.tree_sum(tree))
6.2
>>> new_tree = projections.projection_simplex(tree)
>>> print(tree_utils.tree_sum(new_tree))
1.0000002

在版本 0.2.3 中添加。