torch.optim.lbfgs 的源代码
import torch
from .optimizer import Optimizer
__all__ = ['LBFGS']
def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
# 从 https://github.com/torch/optim/blob/master/polyinterp.lua 移植
# 计算插值区域的边界
if bounds is not None:
xmin_bound, xmax_bound = bounds
else:
xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)
# 代码用于最常见的情况:两点之间的三次插值
# 具有两个点的函数值和导数值
# 在这种情况下(x2 是最远的点)的解:
# d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
# d2 = sqrt(d1^2 - g1*g2);
# min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
# t_new = min(max(min_pos,xmin_bound),xmax_bound);
d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
d2_square = d1**2 - g1 * g2
if d2_square >= 0:
d2 = d2_square.sqrt()
if x1 <= x2:
min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
else:
min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
return min(max(min_pos, xmin_bound), xmax_bound)
else:
return (xmin_bound + xmax_bound) / 2.
def _strong_wolfe(obj_func,
x,
t,
d,
f,
g,
gtd,
c1=1e-4,
c2=0.9,
tolerance_change=1e-9,
max_ls=25):
# 从 https://github.com/torch/optim/blob/master/lswolfe.lua 移植
d_norm = d.abs().max()
g = g.clone(memory_format=torch.contiguous_format)
# 使用初始步长评估目标函数和梯度
f_new, g_new = obj_func(x, t, d)
ls_func_evals = 1
gtd_new = g_new.dot(d)
# 括住一个区间,该区间包含满足 Wolfe 准则的点
t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd
done = False
ls_iter = 0
while ls_iter < max_ls:
# 检查条件
if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev):
bracket = [t_prev, t]
bracket_f = [f_prev, f_new]
bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
bracket_gtd = [gtd_prev, gtd_new]
break
if abs(gtd_new) <= -c2 * gtd:
bracket = [t]
bracket_f = [f_new]
bracket_g = [g_new]
done = True
break
if gtd_new >= 0:
bracket =