Unverified commit 0d0d76eb authored by Aurelius84, committed by GitHub

Fix bug while specifying target grad in high order gradient (#40940)

* Fix bug while specifying target grad in high order gradient

* add more unittest

* add more unittest
Parent 3f4099ee
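For context, here is a minimal, self-contained sketch of the scenario this commit targets, adapted from the TestDoubleGradient unit test added in this patch: the second paddle.static.gradients call computes a higher-order gradient while explicitly specifying a target gradient (v), which is the case named in the commit title.

```python
import numpy as np
import paddle

paddle.enable_static()

start_prog = paddle.static.Program()
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
    x = paddle.static.data('x', shape=[2, 2])
    x.stop_gradient = False
    y = x * x

    # First-order gradient with an explicitly specified target gradient.
    grad_y = paddle.zeros_like(y)
    grad_y.stop_gradient = False
    grad_x = paddle.static.gradients(y, x, grad_y)

    # Second-order gradient that again specifies a target gradient (v):
    # the "target grad in high order gradient" case this commit fixes.
    v = paddle.ones([2, 2])
    v.stop_gradient = False
    jvp = paddle.static.gradients(grad_x, grad_y, v)

exe = paddle.static.Executor()
exe.run(start_prog)
out = exe.run(main_prog,
              feed={'x': np.ones([2, 2], dtype=np.float32)},
              fetch_list=[grad_x, jvp])
# Expected (see the new unit test below):
# out[0] == [[0., 0.], [0., 0.]], out[1] == [[2., 2.], [2., 2.]]
```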
@@ -1052,7 +1052,8 @@ def _append_backward_ops_(block,
callbacks=None,
input_grad_names_set=None,
op_path_dict=None,
distop_context=None):
distop_context=None,
rename_var_map=None):
"""
Create all grad ops, and insert them into given block
@@ -1073,6 +1074,8 @@ def _append_backward_ops_(block,
op_path_dict(dict): op_path_dict will be changed.
key(int) block index
val(list) the op path of block(index)
rename_var_map(dict): used to associate target_grad var name with first grad_op input name.
Only used for high order gradient.
"""
if callbacks is not None:
assert (isinstance(callbacks, (list, tuple)))
@@ -1084,7 +1087,9 @@ def _append_backward_ops_(block,
grad_op_descs = []
program = block.program
if rename_var_map is None:
rename_var_map = {}
assert isinstance(rename_var_map, dict)
# add grad_op_desc by reversed ops
for op in reversed(ops):
@@ -1109,7 +1114,7 @@ def _append_backward_ops_(block,
# Getting op's corresponding grad_op
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list)
# Build the mapping between the forward op and bacckward op (Only for auto parallel)
# Build the mapping between the forward op and backward op (Only for auto parallel)
if distop_context is not None:
for op_desc in grad_op_desc:
assert op_desc.id() not in distop_context.grad_op_id_to_op_id
@@ -1154,11 +1159,12 @@ def _append_backward_ops_(block,
# But this strategy is not suited for while op for some control flow,
# for example, for while op, the grads maybe generated in next loop.
if input_grad_names_set is not None:
is_grad_name = lambda name: name.find(core.grad_var_suffix()) != -1 or name in input_grad_names_set
is_append_grad = False
for op_desc in grad_op_desc:
input_grad_names = [
name for name in op_desc.input_arg_names()
if name.find(core.grad_var_suffix()) != -1
if is_grad_name(name)
]
# some code of gradient ops, like increment, are not very
# standard, there is no @GRAD in these ops' inputs.
@@ -1921,10 +1927,11 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
input_grad_names_set = set()
target_grad_map = {}
rename_var_map = {}
for i, grad in enumerate(target_gradients):
target = targets[i]
if grad is None:
grad_name = _append_grad_suffix_(target.name)
if grad is None:
target_shape = target.name + '_shape'
block.desc.append_op().copy_from(
_create_op_desc_("shape", {'Input': [target.name]},
@@ -1949,11 +1956,14 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
target.name, grad.name))
target_grad_map[_append_grad_suffix_(target.name)] = grad.name
input_grad_names_set.add(grad.name)
rename_var_map[grad_name] = grad.name
# For double backward, input_grad_names is used for filter
# some non-used gradients op.
# some non-used gradients op. rename_var_map is used to
# associate target_grad var name with first grad_op input name.
if prog._appending_grad_times == 1:
input_grad_names_set = None
rename_var_map = {}
for input in inputs:
if input.block.program != prog:
@@ -1980,7 +1990,8 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
no_grad_dict,
grad_to_var,
input_grad_names_set=input_grad_names_set,
op_path_dict=op_path_dict)
op_path_dict=op_path_dict,
rename_var_map=rename_var_map)
# Because calc_gradient may be called multiple times,
# we need rename the internal gradient variables so that they have
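To summarize the change above: when the caller supplies target_gradients, calc_gradient now records a mapping from each target's auto-generated gradient name (target.name plus the @GRAD suffix) to the name of the user-provided gradient variable and passes it to _append_backward_ops_ as rename_var_map, so the first grad op reads the user-supplied variable instead of a freshly created <target>@GRAD var; in addition, is_grad_name treats a name as a gradient either when it carries the @GRAD suffix or when it appears in input_grad_names_set. The snippet below is only a conceptual illustration of that renaming; GRAD_SUFFIX, rename_grad_inputs and the example names are hypothetical stand-ins, not Paddle internals.

```python
# Conceptual sketch only -- GRAD_SUFFIX and rename_grad_inputs are
# hypothetical stand-ins, not functions from the patched module.
GRAD_SUFFIX = "@GRAD"  # plays the role of core.grad_var_suffix()

# The caller passed a variable named 'grad_y' as the target gradient of y,
# so the auto-generated name 'y@GRAD' must resolve to 'grad_y'.
rename_var_map = {"y" + GRAD_SUFFIX: "grad_y"}
input_grad_names_set = {"grad_y"}

def rename_grad_inputs(input_arg_names, rename_var_map):
    """Rename a grad op's inputs so the user-supplied target gradient
    feeds the backward graph instead of a fresh y@GRAD variable."""
    return [rename_var_map.get(name, name) for name in input_arg_names]

def is_grad_name(name):
    # A name counts as a gradient if it has the @GRAD suffix *or* is one
    # of the user-supplied gradient variables (the new check in the patch).
    return GRAD_SUFFIX in name or name in input_grad_names_set

print(rename_grad_inputs(["y" + GRAD_SUFFIX, "x"], rename_var_map))  # ['grad_y', 'x']
print(is_grad_name("grad_y"), is_grad_name("x"))                     # True False
```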
@@ -100,5 +100,74 @@ class TestGradientWithPrune(unittest.TestCase):
self.assertTrue(np.array_equal(out[0], [2., 0., 0.]))
class TestDoubleGradient(unittest.TestCase):
def build_program(self):
start_prog = paddle.static.Program()
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
x = paddle.static.data('x', shape=[2, 2])
x.stop_gradient = False
y = x * x
v = paddle.ones([2, 2])
v.stop_gradient = False
grad_y = paddle.zeros_like(y)
grad_y.stop_gradient = False
grad_x = paddle.static.gradients(y, x, grad_y)
# test with single targets
jvp = paddle.static.gradients(grad_x, grad_y, v)
return start_prog, main_prog, [grad_x, jvp]
def test_calc_gradient(self):
start_prog, main_prog, fetch_list = self.build_program()
exe = paddle.static.Executor()
exe.run(start_prog)
ans = exe.run(main_prog,
feed={'x': np.ones([2, 2]).astype(np.float32)},
fetch_list=fetch_list)
self.assertEqual(len(ans), 2)
self.assertListEqual(ans[0].tolist(), [[0., 0.], [0., 0.]])
self.assertListEqual(ans[1].tolist(), [[2., 2.], [2., 2.]])
class TestDoubleGradient2(unittest.TestCase):
def build_program(self):
start_prog = paddle.static.Program()
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
x = paddle.static.data('x', shape=[2, 2])
x.stop_gradient = False
y = x * x
y2 = y + x
v = paddle.ones([2, 2])
v.stop_gradient = False
grad_y = paddle.zeros_like(y)
grad_y.stop_gradient = False
grad_x = paddle.static.gradients(y, x, grad_y)
grad_x2 = paddle.static.gradients(y2, x, grad_y)
# test with multi targets
jvp = paddle.static.gradients([grad_x[0], grad_x2[0]], grad_y,
[v, v])
return start_prog, main_prog, [grad_x, jvp]
def test_calc_gradient(self):
start_prog, main_prog, fetch_list = self.build_program()
exe = paddle.static.Executor()
exe.run(start_prog)
ans = exe.run(main_prog,
feed={'x': np.ones([2, 2]).astype(np.float32)},
fetch_list=fetch_list)
self.assertEqual(len(ans), 2)
self.assertListEqual(ans[0].tolist(), [[0., 0.], [0., 0.]])
self.assertListEqual(ans[1].tolist(), [[5., 5.], [5., 5.]])
if __name__ == "__main__":
unittest.main()
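As a sanity check on the expected values in the two new tests (a derivation under the tests' own setup, not part of the patch): with y = x * x elementwise, grad_x = gradients(y, x, grad_y) equals 2 * x * grad_y, so fetching grad_x with grad_y = 0 yields zeros, while differentiating grad_x with respect to grad_y against v = 1 yields 2 * x = 2 at x = 1; the second test's extra target y2 = y + x contributes grad_x2 = (2 * x + 1) * grad_y, giving 2 + 3 = 5.

```python
import numpy as np

x = np.ones([2, 2], dtype=np.float32)
v = np.ones([2, 2], dtype=np.float32)

# grad_x = 2 * x * grad_y is linear in grad_y, so its derivative w.r.t.
# grad_y contracted with v is simply 2 * x * v.
print(2 * x * v)                    # [[2. 2.] [2. 2.]] -> TestDoubleGradient
# The second target y2 = y + x adds grad_x2 = (2 * x + 1) * grad_y.
print(2 * x * v + (2 * x + 1) * v)  # [[5. 5.] [5. 5.]] -> TestDoubleGradient2
```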