Unverified commit 0d0d76eb authored by Aurelius84, committed by GitHub

Fix bug while specifying target grad in high order gradient (#40940)

* Fix bug while specifying target grad in high order gradient

* add more unittests

* add more unittests
Parent 3f4099ee
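The problem this commit addresses: when a target gradient is specified explicitly (the `target_gradients` argument of `paddle.static.gradients` / `calc_gradient`) and a second, higher-order backward pass is then built on top of the result, the user-provided grad variable has to be associated with the auto-generated `<target>@GRAD` name that the first grad op expects as input. The change below threads a new `rename_var_map` from `calc_gradient` into `_append_backward_ops_` to carry that association. A minimal sketch of the usage pattern involved, mirroring the new `TestDoubleGradient` case added in this commit (values and variable names are illustrative):

```python
import numpy as np
import paddle

paddle.enable_static()

main_prog, start_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
    x = paddle.static.data('x', shape=[2, 2])
    x.stop_gradient = False
    y = x * x
    v = paddle.ones([2, 2])
    v.stop_gradient = False

    # First backward pass with an explicitly specified target gradient grad_y.
    grad_y = paddle.zeros_like(y)
    grad_y.stop_gradient = False
    grad_x = paddle.static.gradients(y, x, grad_y)

    # Second backward pass (high order): differentiate grad_x w.r.t. grad_y.
    jvp = paddle.static.gradients(grad_x, grad_y, v)

exe = paddle.static.Executor()
exe.run(start_prog)
out = exe.run(main_prog,
              feed={'x': np.ones([2, 2], dtype=np.float32)},
              fetch_list=[grad_x, jvp])
print(out[0])  # grad_y * 2x with grad_y == zeros -> all zeros
print(out[1])  # d(grad_x)/d(grad_y) contracted with v -> 2 * x -> all twos
```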
@@ -1052,7 +1052,8 @@ def _append_backward_ops_(block,
                           callbacks=None,
                           input_grad_names_set=None,
                           op_path_dict=None,
-                          distop_context=None):
+                          distop_context=None,
+                          rename_var_map=None):
     """
     Create all grad ops, and insert them into given block
@@ -1073,6 +1074,8 @@ def _append_backward_ops_(block,
         op_path_dict(dict): op_path_dict will be changed.
             key(int) block index
             val(list) the op path of block(index)
+        rename_var_map(dict): used to associate target_grad var name with first grad_op input name.
+            Only used in high order gradient.
     """
     if callbacks is not None:
         assert (isinstance(callbacks, (list, tuple)))
@@ -1084,7 +1087,9 @@ def _append_backward_ops_(block,
     grad_op_descs = []
     program = block.program
 
-    rename_var_map = {}
+    if rename_var_map is None:
+        rename_var_map = {}
+    assert isinstance(rename_var_map, dict)
 
     # add grad_op_desc by reversed ops
     for op in reversed(ops):
@@ -1109,7 +1114,7 @@ def _append_backward_ops_(block,
         # Getting op's corresponding grad_op
         grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
             op.desc, cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list)
-        # Build the mapping between the forward op and bacckward op (Only for auto parallel)
+        # Build the mapping between the forward op and backward op (Only for auto parallel)
         if distop_context is not None:
             for op_desc in grad_op_desc:
                 assert op_desc.id() not in distop_context.grad_op_id_to_op_id
@@ -1154,11 +1159,12 @@ def _append_backward_ops_(block,
         # But this strategy is not suited for while op for some control flow,
         # for example, for while op, the grads maybe generated in next loop.
         if input_grad_names_set is not None:
+            is_grad_name = lambda name: name.find(core.grad_var_suffix()) != -1 or name in input_grad_names_set
             is_append_grad = False
             for op_desc in grad_op_desc:
                 input_grad_names = [
                     name for name in op_desc.input_arg_names()
-                    if name.find(core.grad_var_suffix()) != -1
+                    if is_grad_name(name)
                 ]
                 # some code of gradient ops, like increment, are not very
                 # standard, there is no @GRAD in these ops' inputs.
@@ -1921,10 +1927,11 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
     input_grad_names_set = set()
     target_grad_map = {}
+    rename_var_map = {}
     for i, grad in enumerate(target_gradients):
         target = targets[i]
-        if grad is None:
-            grad_name = _append_grad_suffix_(target.name)
+        grad_name = _append_grad_suffix_(target.name)
+        if grad is None:
             target_shape = target.name + '_shape'
             block.desc.append_op().copy_from(
                 _create_op_desc_("shape", {'Input': [target.name]},
@@ -1949,11 +1956,14 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
                         target.name, grad.name))
             target_grad_map[_append_grad_suffix_(target.name)] = grad.name
             input_grad_names_set.add(grad.name)
+            rename_var_map[grad_name] = grad.name
 
     # For double backward, input_grad_names is used for filter
-    # some non-used gradients op.
+    # some non-used gradients op. rename_var_map is used to
+    # associate target_grad var name with first grad_op input name.
     if prog._appending_grad_times == 1:
         input_grad_names_set = None
+        rename_var_map = {}
 
     for input in inputs:
         if input.block.program != prog:
@@ -1980,7 +1990,8 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
         no_grad_dict,
         grad_to_var,
         input_grad_names_set=input_grad_names_set,
-        op_path_dict=op_path_dict)
+        op_path_dict=op_path_dict,
+        rename_var_map=rename_var_map)
 
     # Because calc_gradient may be called multiple times,
     # we need rename the internal gradient variables so that they have
...
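To summarize the `calc_gradient` side of the change: for each target whose gradient is supplied by the caller, the auto-generated grad name (`_append_grad_suffix_(target.name)`, e.g. `y@GRAD`) is now recorded in `rename_var_map` together with the name of the user-provided grad variable, and the map is handed to `_append_backward_ops_`; for a plain first-order backward pass (`prog._appending_grad_times == 1`) the map is cleared, since the association is only needed for high order gradients. The standalone sketch below only illustrates how the map is populated; it is not the Paddle implementation, and the renaming step that consumes the map inside `_append_backward_ops_` is elided from this diff.

```python
from collections import namedtuple

# Illustrative sketch (not Paddle code): how rename_var_map is populated for
# targets whose gradient is explicitly specified by the caller.
GRAD_SUFFIX = "@GRAD"  # what core.grad_var_suffix() returns


def build_rename_var_map(targets, target_gradients):
    rename_var_map = {}
    for target, grad in zip(targets, target_gradients):
        grad_name = target.name + GRAD_SUFFIX  # _append_grad_suffix_(target.name)
        if grad is not None:
            # e.g. maps "y@GRAD" -> "grad_y", so the first grad op's input can
            # later be renamed to the user-provided variable.
            rename_var_map[grad_name] = grad.name
    return rename_var_map


Var = namedtuple("Var", ["name"])
print(build_rename_var_map([Var("y")], [Var("grad_y")]))  # {'y@GRAD': 'grad_y'}
```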
@@ -100,5 +100,74 @@ class TestGradientWithPrune(unittest.TestCase):
         self.assertTrue(np.array_equal(out[0], [2., 0., 0.]))
 
 
+class TestDoubleGradient(unittest.TestCase):
+    def build_program(self):
+        start_prog = paddle.static.Program()
+        main_prog = paddle.static.Program()
+        with paddle.static.program_guard(main_prog, start_prog):
+            x = paddle.static.data('x', shape=[2, 2])
+            x.stop_gradient = False
+            y = x * x
+            v = paddle.ones([2, 2])
+            v.stop_gradient = False
+            grad_y = paddle.zeros_like(y)
+            grad_y.stop_gradient = False
+            grad_x = paddle.static.gradients(y, x, grad_y)
+            # test with single targets
+            jvp = paddle.static.gradients(grad_x, grad_y, v)
+
+        return start_prog, main_prog, [grad_x, jvp]
+
+    def test_calc_gradient(self):
+        start_prog, main_prog, fetch_list = self.build_program()
+        exe = paddle.static.Executor()
+        exe.run(start_prog)
+        ans = exe.run(main_prog,
+                      feed={'x': np.ones([2, 2]).astype(np.float32)},
+                      fetch_list=fetch_list)
+        self.assertEqual(len(ans), 2)
+        self.assertListEqual(ans[0].tolist(), [[0., 0.], [0., 0.]])
+        self.assertListEqual(ans[1].tolist(), [[2., 2.], [2., 2.]])
+
+
+class TestDoubleGradient2(unittest.TestCase):
+    def build_program(self):
+        start_prog = paddle.static.Program()
+        main_prog = paddle.static.Program()
+        with paddle.static.program_guard(main_prog, start_prog):
+            x = paddle.static.data('x', shape=[2, 2])
+            x.stop_gradient = False
+            y = x * x
+            y2 = y + x
+            v = paddle.ones([2, 2])
+            v.stop_gradient = False
+            grad_y = paddle.zeros_like(y)
+            grad_y.stop_gradient = False
+            grad_x = paddle.static.gradients(y, x, grad_y)
+            grad_x2 = paddle.static.gradients(y2, x, grad_y)
+            # test with multi targets
+            jvp = paddle.static.gradients([grad_x[0], grad_x2[0]], grad_y,
+                                          [v, v])
+
+        return start_prog, main_prog, [grad_x, jvp]
+
+    def test_calc_gradient(self):
+        start_prog, main_prog, fetch_list = self.build_program()
+        exe = paddle.static.Executor()
+        exe.run(start_prog)
+        ans = exe.run(main_prog,
+                      feed={'x': np.ones([2, 2]).astype(np.float32)},
+                      fetch_list=fetch_list)
+        self.assertEqual(len(ans), 2)
+        self.assertListEqual(ans[0].tolist(), [[0., 0.], [0., 0.]])
+        self.assertListEqual(ans[1].tolist(), [[5., 5.], [5., 5.]])
+
+
 if __name__ == "__main__":
     unittest.main()
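For reference, the values asserted in the two tests can be checked by hand: with x fed as ones and grad_y built as zeros, the first fetch grad_x = grad_y * 2x is all zeros; the second fetch differentiates grad_x (and, in the second test, also grad_x2 for y2 = x*x + x) with respect to grad_y and contracts with v = ones, giving 2x*v = 2 and 2x*v + (2x + 1)*v = 5 per element. A quick NumPy check of those numbers:

```python
import numpy as np

x = np.ones([2, 2], dtype=np.float32)
v = np.ones([2, 2], dtype=np.float32)

# TestDoubleGradient: d(grad_x)/d(grad_y) = 2x, contracted with v.
print((2 * x * v).tolist())                    # [[2.0, 2.0], [2.0, 2.0]]
# TestDoubleGradient2: grad_x contributes 2x*v, grad_x2 contributes (2x+1)*v.
print((2 * x * v + (2 * x + 1) * v).tolist())  # [[5.0, 5.0], [5.0, 5.0]]
```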