From 0d0d76eb5be98d84af886297f22effcb4e26155c Mon Sep 17 00:00:00 2001
From: Aurelius84
Date: Mon, 28 Mar 2022 10:34:00 +0800
Subject: [PATCH] Fix bug while specifying target grad in high order gradient
 (#40940)

* Fix bug while specifying target grad in high order gradient

* add more unittest

* add more unittest
---
 python/paddle/fluid/backward.py             | 25 +++++--
 .../tests/unittests/test_calc_gradient.py   | 69 +++++++++++++++++++
 2 files changed, 87 insertions(+), 7 deletions(-)

diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py
index 1637b33723b..0988f670955 100755
--- a/python/paddle/fluid/backward.py
+++ b/python/paddle/fluid/backward.py
@@ -1052,7 +1052,8 @@ def _append_backward_ops_(block,
                           callbacks=None,
                           input_grad_names_set=None,
                           op_path_dict=None,
-                          distop_context=None):
+                          distop_context=None,
+                          rename_var_map=None):
     """
     Create all grad ops, and insert them into given block
 
@@ -1073,6 +1074,8 @@ def _append_backward_ops_(block,
         op_path_dict(dict): op_path_dict will be changed.
             key(int) block index
             val(list) the op path of block(index)
+        rename_var_map(dict): used to associate target_grad var name with first grad_op input name.
+            Only used in for high order gradient.
     """
     if callbacks is not None:
         assert (isinstance(callbacks, (list, tuple)))
@@ -1084,7 +1087,9 @@ def _append_backward_ops_(block,
     grad_op_descs = []
     program = block.program
 
-    rename_var_map = {}
+    if rename_var_map is None:
+        rename_var_map = {}
+    assert isinstance(rename_var_map, dict)
 
     # add grad_op_desc by reversed ops
     for op in reversed(ops):
@@ -1109,7 +1114,7 @@ def _append_backward_ops_(block,
         # Getting op's corresponding grad_op
         grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
             op.desc, cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list)
-        # Build the mapping between the forward op and bacckward op (Only for auto parallel)
+        # Build the mapping between the forward op and backward op (Only for auto parallel)
         if distop_context is not None:
             for op_desc in grad_op_desc:
                 assert op_desc.id() not in distop_context.grad_op_id_to_op_id
@@ -1154,11 +1159,12 @@ def _append_backward_ops_(block,
         # But this strategy is not suited for while op for some control flow,
         # for example, for while op, the grads maybe generated in next loop.
         if input_grad_names_set is not None:
+            is_grad_name = lambda name: name.find(core.grad_var_suffix()) != -1 or name in input_grad_names_set
             is_append_grad = False
             for op_desc in grad_op_desc:
                 input_grad_names = [
                     name for name in op_desc.input_arg_names()
-                    if name.find(core.grad_var_suffix()) != -1
+                    if is_grad_name(name)
                 ]
                 # some code of gradient ops, like increment, are not very
                 # standard, there is no @GRAD in these ops' inputs.
@@ -1921,10 +1927,11 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
     input_grad_names_set = set()
 
     target_grad_map = {}
+    rename_var_map = {}
     for i, grad in enumerate(target_gradients):
         target = targets[i]
+        grad_name = _append_grad_suffix_(target.name)
         if grad is None:
-            grad_name = _append_grad_suffix_(target.name)
             target_shape = target.name + '_shape'
             block.desc.append_op().copy_from(
                 _create_op_desc_("shape", {'Input': [target.name]},
@@ -1949,11 +1956,14 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
                         target.name, grad.name))
             target_grad_map[_append_grad_suffix_(target.name)] = grad.name
             input_grad_names_set.add(grad.name)
+            rename_var_map[grad_name] = grad.name
 
     # For double backward, input_grad_names is used for filter
-    # some non-used gradients op.
+    # some non-used gradients op. rename_var_map is used to
+    # associate target_grad var name with first grad_op input name.
     if prog._appending_grad_times == 1:
         input_grad_names_set = None
+        rename_var_map = {}
 
     for input in inputs:
         if input.block.program != prog:
@@ -1980,7 +1990,8 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
         no_grad_dict,
         grad_to_var,
         input_grad_names_set=input_grad_names_set,
-        op_path_dict=op_path_dict)
+        op_path_dict=op_path_dict,
+        rename_var_map=rename_var_map)
 
     # Because calc_gradient may be called multiple times,
     # we need rename the internal gradient variables so that they have
diff --git a/python/paddle/fluid/tests/unittests/test_calc_gradient.py b/python/paddle/fluid/tests/unittests/test_calc_gradient.py
index 339a66b0626..40e5abccb2d 100644
--- a/python/paddle/fluid/tests/unittests/test_calc_gradient.py
+++ b/python/paddle/fluid/tests/unittests/test_calc_gradient.py
@@ -100,5 +100,74 @@ class TestGradientWithPrune(unittest.TestCase):
         self.assertTrue(np.array_equal(out[0], [2., 0., 0.]))
 
 
+class TestDoubleGradient(unittest.TestCase):
+    def build_program(self):
+        start_prog = paddle.static.Program()
+        main_prog = paddle.static.Program()
+
+        with paddle.static.program_guard(main_prog, start_prog):
+            x = paddle.static.data('x', shape=[2, 2])
+            x.stop_gradient = False
+            y = x * x
+
+            v = paddle.ones([2, 2])
+            v.stop_gradient = False
+
+            grad_y = paddle.zeros_like(y)
+            grad_y.stop_gradient = False
+            grad_x = paddle.static.gradients(y, x, grad_y)
+            # test with single targets
+            jvp = paddle.static.gradients(grad_x, grad_y, v)
+
+        return start_prog, main_prog, [grad_x, jvp]
+
+    def test_calc_gradient(self):
+        start_prog, main_prog, fetch_list = self.build_program()
+        exe = paddle.static.Executor()
+        exe.run(start_prog)
+        ans = exe.run(main_prog,
+                      feed={'x': np.ones([2, 2]).astype(np.float32)},
+                      fetch_list=fetch_list)
+        self.assertEqual(len(ans), 2)
+        self.assertListEqual(ans[0].tolist(), [[0., 0.], [0., 0.]])
+        self.assertListEqual(ans[1].tolist(), [[2., 2.], [2., 2.]])
+
+
+class TestDoubleGradient2(unittest.TestCase):
+    def build_program(self):
+        start_prog = paddle.static.Program()
+        main_prog = paddle.static.Program()
+
+        with paddle.static.program_guard(main_prog, start_prog):
+            x = paddle.static.data('x', shape=[2, 2])
+            x.stop_gradient = False
+            y = x * x
+            y2 = y + x
+
+            v = paddle.ones([2, 2])
+            v.stop_gradient = False
+
+            grad_y = paddle.zeros_like(y)
+            grad_y.stop_gradient = False
+            grad_x = paddle.static.gradients(y, x, grad_y)
+            grad_x2 = paddle.static.gradients(y2, x, grad_y)
+            # test with multi targets
+            jvp = paddle.static.gradients([grad_x[0], grad_x2[0]], grad_y,
+                                          [v, v])
+
+        return start_prog, main_prog, [grad_x, jvp]
+
+    def test_calc_gradient(self):
+        start_prog, main_prog, fetch_list = self.build_program()
+        exe = paddle.static.Executor()
+        exe.run(start_prog)
+        ans = exe.run(main_prog,
+                      feed={'x': np.ones([2, 2]).astype(np.float32)},
+                      fetch_list=fetch_list)
+        self.assertEqual(len(ans), 2)
+        self.assertListEqual(ans[0].tolist(), [[0., 0.], [0., 0.]])
+        self.assertListEqual(ans[1].tolist(), [[5., 5.], [5., 5.]])
+
+
 if __name__ == "__main__":
     unittest.main()
-- 
GitLab
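
For context, the scenario this patch targets can be reproduced outside the test suite with the short standalone sketch below. It is adapted from TestDoubleGradient in the diff above and is not part of the patch itself; it assumes a PaddlePaddle 2.x install and enables static-graph mode explicitly via paddle.enable_static(), which the unittest runner normally takes care of.

import numpy as np
import paddle

paddle.enable_static()

start_prog = paddle.static.Program()
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
    x = paddle.static.data('x', shape=[2, 2])
    x.stop_gradient = False
    y = x * x

    v = paddle.ones([2, 2])
    v.stop_gradient = False

    # First backward pass with an explicitly specified target gradient.
    grad_y = paddle.zeros_like(y)
    grad_y.stop_gradient = False
    grad_x = paddle.static.gradients(y, x, grad_y)

    # Second backward pass w.r.t. the specified target gradient; this is the
    # code path exercised by the rename_var_map change in backward.py.
    jvp = paddle.static.gradients(grad_x, grad_y, v)

exe = paddle.static.Executor()
exe.run(start_prog)
out = exe.run(main_prog,
              feed={'x': np.ones([2, 2]).astype(np.float32)},
              fetch_list=[grad_x, jvp])
print(out[0])  # 2 * x * grad_y -> all zeros, since grad_y is zeros
print(out[1])  # d(grad_x)/d(grad_y) applied to v -> 2 * x * v = all twos

With x set to ones, the two fetched values match the assertions in TestDoubleGradient: [[0., 0.], [0., 0.]] and [[2., 2.], [2., 2.]].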