控制变量#
控制增量协方差方法。 |
|
|
使用控制变量获得雅可比。 |
|
移动平均基线。 |
控制增量方法#
- optax.monte_carlo.control_delta_method(function: Callable[[chex.Array], float]) ControlVariate[来源]#
控制 delta 协变方法。
- Control variate obtained by performing a second order Taylor expansion
在输入分布均值下的成本函数 f。
仅对高斯随机变量实现。
有关详细信息,请参见: https://icml.cc/2012/papers/687.pdf
- Parameters:
函数 – 用于计算控制变量的函数。该函数接收一个参数(来自分布的样本)并返回一个浮点值。
- Returns:
一个由三个函数组成的元组,用于计算控制变量、控制变量的期望值,并更新控制变量状态。
自版本 0.2.4 起已被弃用: 此函数将在 0.3.0 中被移除
控制变量雅可比#
- optax.monte_carlo.control_variates_jacobians(function: Callable[[chex.Array], float], control_variate_from_function: Callable[[Callable[[chex.Array], float]], ControlVariate], grad_estimator: Callable[..., jnp.ndarray], params: base.Params, dist_builder: Callable[..., Any], rng: chex.PRNGKey, num_samples: int, control_variate_state: CvState = None, estimate_cv_coeffs: bool = False, estimate_cv_coeffs_num_samples: int = 20) tuple[Sequence[chex.Array], CvState][来源]#
使用控制变量获得雅可比。
我们将单独计算每个项。第一个项将使用随机梯度估计。第二个项将使用蒙特卡洛估计和自动微分来计算 nabla_{theta} h(x; theta)。第三个项将使用自动微分来计算,因为我们将自己限制在控制变元中,这些变元以封闭形式计算这个期望。
此函数在计算控制变量系数之前更新控制变量的状态(一次)。
- Parameters:
function – 要估计 grads_{params} E_dist f(x) 的函数 f(x)。 该函数接受一个参数(来自分布的样本)并返回一个浮点值。
control_variate_from_function – 用于减少方差的控制变量。请参见control_delta_method和moving_avg_baseline的示例。
grad_estimator – 要用于计算梯度的梯度估计器。 注意,并非所有控制变量都会对所有 估计器减少方差。例如,moving_avg_baseline 对于测量值估计器或路径估计器没有任何影响。
params – 一个 jnp 数组的元组。构造分布的参数,以及我们希望计算雅可比矩阵的参数。
dist_builder – 一个构造函数,基于参数中指定的输入参数构建分布。 dist_builder(params) 应该返回一个有效的分布。
rng – 一个 PRNGKey 密钥。
num_samples – 整数,计算梯度所使用的样本数量。
control_variate_state – 控制变量状态。这用于保持状态的控制变量(例如移动平均基线)。
estimate_cv_coeffs – 布尔值。是否通过 estimate_control_variate_coefficients 来估计最优控制变量系数。
estimate_cv_coeffs_num_samples – 用于估计最佳系数的样本数量。这些需要是新的样本,以确保目标是无偏的。
- Returns:
一个大小为 params 的元组,每个元素是 num_samples x param.shape 雅可比向量,包含每个样本获得的梯度估计。 该向量的均值是关于参数的梯度,可以用于学习。整个雅可比向量可用于评估 估计器方差。
更新后的 CV 状态。
- Return type:
一个大小为二的元组
自版本 0.2.4 起已被弃用: 此函数将在 0.3.0 中被移除
移动平均基线#
- optax.monte_carlo.moving_avg_baseline(function: Callable[[chex.Array], float], decay: float = 0.99, zero_debias: bool = True, use_decay_early_training_heuristic=True) ControlVariate[来源]#
移动平均基线。
它对路径估计或测度值估计器没有影响。
- Parameters:
function – 用于计算控制变量的函数。该函数接受一个参数(来自分布的样本),并返回一个浮点值。
decay – 移动平均的衰减率。
zero_debias – 是否使用零去偏移的移动平均。
use_decay_early_training_heuristic – 是否使用启发式方法,该方法在训练早期基于 min(decay, (1.0 + i) / (10.0 + i)) 覆盖衰减值。这样可以稳定训练,并且是从 Tensorflow 代码库中改编的。
- Returns:
一个由三个函数组成的元组,用于计算控制变量、控制变量的期望值,并更新控制变量状态。
自版本 0.2.4 起已被弃用: 此函数将在 0.3.0 中被移除