未验证 提交 6e37a2c0 编写于 作者: H Hui Zhang 提交者: GitHub

fix lbfgs error (#50820)

上级 db170b2b
...@@ -236,7 +236,7 @@ class LBFGS(Optimizer): ...@@ -236,7 +236,7 @@ class LBFGS(Optimizer):
with paddle.no_grad(): with paddle.no_grad():
# Make sure the closure is always called with grad enabled # Make sure the closure is always called with grad enabled
closure = paddle.set_grad_enabled(True)(closure) closure = paddle.enable_grad()(closure)
lr = self.lr lr = self.lr
max_iter = self.max_iter max_iter = self.max_iter
...@@ -376,7 +376,7 @@ class LBFGS(Optimizer): ...@@ -376,7 +376,7 @@ class LBFGS(Optimizer):
# no line search, simply move with fixed-step # no line search, simply move with fixed-step
self._add_grad(alpha, d) self._add_grad(alpha, d)
if n_iter != max_iter: if n_iter != max_iter:
with paddle.set_grad_enabled(True): with paddle.enable_grad():
loss = float(closure()) loss = float(closure())
flat_grad = self._gather_flat_grad() flat_grad = self._gather_flat_grad()
opt_cond = flat_grad.abs().max() <= tolerance_grad opt_cond = flat_grad.abs().max() <= tolerance_grad
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册