预测#
投影可以用来执行约束优化。 对集合 \(\mathcal{C}\) 的欧几里得投影是:
例如,这里是一个我们如何将参数投影到非负正交的示例:
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> num_weights = 2
>>> xs = jnp.array([[-1.8, 2.2], [-2.0, 1.2]])
>>> ys = jnp.array([0.5, 0.8])
>>> optimizer = optax.adam(learning_rate=1e-3)
>>> params = {'w': jnp.zeros(num_weights)}
>>> opt_state = optimizer.init(params)
>>> loss = lambda params, x, y: jnp.mean((params['w'].dot(x) - y) ** 2)
>>> grads = jax.grad(loss)(params, xs, ys)
>>> updates, opt_state = optimizer.update(grads, opt_state)
>>> params = optax.apply_updates(params, updates)
>>> params = optax.projections.projection_non_negative(params)
可用的投影#
|
投影到方框约束。 |
|
投影到(单位)超立方体。 |
|
投影到l1球上。 |
|
投影到 l1 球面上。 |
|
投影到l2球体。 |
|
投影到 l2 球面上。 |
|
投影到l-infinity球面。 |
|
投影到非负象限。 |
|
投影到一个单纯形上。 |
投影到一个盒子#
- optax.projections.projection_box(tree: Any, lower: Any, upper: Any) Any[源]#
投影到盒子约束。
\[\underset{p}{\text{argmin}} ~ ||x - p||_2^2 \quad \textrm{约束条件} \quad \text{下限} \le p \le \text{上限}\]其中 \(x\) 是输入树。
- Parameters:
tree – 要投影的树。
lower – 下界,一个标量或与
tree具有相同结构的树。upper – 上界,一个标量或与
tree具有相同结构的树。
- Returns:
投影树,具有与
tree相同的结构。
投影到超立方体#
- optax.projections.projection_hypercube(tree: Any, scale: Any = 1.0) Any[来源]#
投影到(单位)超立方体。
\[\underset{p}{\text{argmin}} ~ ||x - p||_2^2 \quad \textrm{条件是} \quad 0 \le p \le \text{scale}\]其中 \(x\) 是输入树。
默认情况下,我们投影到单位超立方体(scale=1.0)。
这是围绕
projection_box的一个便利封装。- Parameters:
tree – 要投影的树。
scale – 超立方体的缩放,标量或树(默认值:1.0)。
- Returns:
投影树,与
tree具有相同结构。
投影到L1球面#
- optax.projections.projection_l1_ball(tree: Any, scale: float = 1.0) Any[源]#
投影到l1球面上。
此函数解决以下约束优化问题,
其中x是输入树。\[\underset{y}{\text{argmin}} ~ ||x - y||_2^2 \quad \textrm{subject to} \quad ||y||_1 \le \text{scale}\]- Parameters:
tree – 要投影的树。
scale – 球的半径。
- Returns:
投影树,具有与
tree相同的结构。
示例
>>> import jax.numpy as jnp >>> from optax import tree_utils, projections >>> tree = {"w": jnp.array([2.5, 3.2]), "b": 0.5} >>> print(tree_utils.tree_l1_norm(tree)) 6.2 >>> new_tree = projections.projection_l1_ball(tree) >>> print(tree_utils.tree_l1_norm(new_tree)) 1.0000002
在版本 0.2.4 中添加。
投影到L1球面#
投影到L2球面#
投影到L2球面#
- optax.projections.projection_l2_sphere(tree: Any, scale: float = 1.0) Any[来源]#
投影到 l2 球面。
此函数解决以下约束优化问题,其中
x是输入树。\[\underset{y}{\text{argmin}} ~ ||x - y||_2^2 \quad \textrm{subject to} \quad ||y||_2 = \text{value}\]- Parameters:
tree – 要投影的树。
scale – 球体的半径。
- Returns:
投影树,与
tree具有相同结构。
在版本 0.2.4 中添加。
投影到L-infinity球#
投影到非负正交象限#
投影到一个单纯形#
- optax.projections.projection_simplex(tree: Any, scale: chex.Numeric = 1.0) Any[源]#
投影到单纯形上。
这个函数解决了以下约束优化问题,
x是输入树。\[\underset{p}{\text{argmin}} ~ ||x - p||_2^2 \quad \textrm{目标是} \quad p \ge 0, p^\top 1 = \text{scale}\]默认情况下,投影是 onto 概率单纯形(单位单纯形)。
- Parameters:
tree – 要投影的树。
scale – 投影树应该总和的值(默认值:1.0)。
- Returns:
投影树,具有与
tree相同结构的树。
示例
这里是一个使用树的示例:
>>> import jax.numpy as jnp >>> from optax import tree_utils, projections >>> tree = {"w": jnp.array([2.5, 3.2]), "b": 0.5} >>> print(tree_utils.tree_sum(tree)) 6.2 >>> new_tree = projections.projection_simplex(tree) >>> print(tree_utils.tree_sum(new_tree)) 1.0000002
在版本 0.2.3 中添加。