未验证 提交 08a772cb 编写于 作者: M mapingshuo 提交者: GitHub

fix API param bug of recompute.backward() (#22582)

* fix API param bug of recompute.backward(), test=develop
上级 61fef975
......@@ -3850,12 +3850,12 @@ class RecomputeOptimizer(Optimizer):
sgd = fluid.optimizer.Adam(learning_rate=0.01)
sgd = fluid.optimizer.RecomputeOptimizer(sgd)
sgd._set_checkpoints([fc_1, pred])
params_grads = sgd.backward(
cost,
startup_program=None,
parameter_list=None,
no_grad_set=None,
checkpoints=[fc_1, pred])
no_grad_set=None)
program = cost.block.program
with framework.program_guard(program, None):
......@@ -3871,8 +3871,7 @@ class RecomputeOptimizer(Optimizer):
startup_program=None,
parameter_list=None,
no_grad_set=None,
callbacks=None,
checkpoints=None):
callbacks=None):
"""
call append_backward with checkpoints.
......@@ -3906,12 +3905,12 @@ class RecomputeOptimizer(Optimizer):
sgd = fluid.optimizer.Adam(learning_rate=0.01)
sgd = fluid.optimizer.RecomputeOptimizer(sgd)
sgd._set_checkpoints([fc_1, pred])
params_grads = sgd.backward(
cost,
startup_program=None,
parameter_list=None,
no_grad_set=None,
checkpoints=[fc_1, pred])
no_grad_set=None)
print("Finished backward")
"""
......@@ -3958,12 +3957,12 @@ class RecomputeOptimizer(Optimizer):
sgd = fluid.optimizer.Adam(learning_rate=0.01)
sgd = fluid.optimizer.RecomputeOptimizer(sgd)
sgd._set_checkpoints([fc_1, pred])
params_grads = sgd.backward(
cost,
startup_program=None,
parameter_list=None,
no_grad_set=None,
checkpoints=[fc_1, pred])
no_grad_set=None)
optimize_ops = sgd.apply_optimize(
cost, startup_program=None, params_grads=params_grads)
......@@ -3993,8 +3992,7 @@ class RecomputeOptimizer(Optimizer):
loss,
startup_program=startup_program,
parameter_list=parameter_list,
no_grad_set=no_grad_set,
checkpoints=self._checkpoints)
no_grad_set=no_grad_set)
if grad_clip:
# TODO(guru4elephant): should add grad_clip for static graph
......
......@@ -791,8 +791,7 @@ class TestRecomputeOptimizer(unittest.TestCase):
mean_out,
startup_program=None,
parameter_list=None,
no_grad_set=None,
checkpoints=[b1_out])
no_grad_set=None)
# apply gradient
program = mean_out.block.program
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册