Unverified commit 435fc4f0, authored by liym27, committed by GitHub

[while grad]Support pruning op in find_op_path about while sub-block when appending backward (#25330)

Prune OPs that are not related to the loss in the while sub-block when constructing the backward OP path.
Parent aaa7cbd5
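For context, the sketch below is a minimal program that exercises the new code path: a `while_loop` whose body also contains an op that does not feed the loss, so the sub-block op path must be pruned before appending backward ops. It is adapted from the `test_while_loop_backward2` case added by this commit, with one extra loss-unrelated op added for illustration; it assumes the Paddle 1.8-era `fluid` API and runs on CPU for simplicity.

```python
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.backward import append_backward


def cond(i, x):
    return i < 5


def body(i, x):
    # `unused` does not contribute to the loss; with this change it is kept
    # out of the while sub-block's backward op path.
    unused = layers.fill_constant(shape=[1], dtype='float32', value=0.0)
    x = x + i
    i = i + 1
    return [i, x]


main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
    i = fluid.data(name='i', shape=[1], dtype='float32')
    i.stop_gradient = False
    x = fluid.data(name='x', shape=[1], dtype='float32')
    x.stop_gradient = False

    out = layers.while_loop(cond, body, [i, x])
    mean = layers.mean(out[1])
    # Builds the pruned op path (per block) before creating the grad ops.
    append_backward(mean)

exe = fluid.Executor(fluid.CPUPlace())
res = exe.run(main_program,
              feed={'i': np.ones(1).astype('float32'),
                    'x': np.ones(1).astype('float32')},
              fetch_list=[mean.name, i.grad_name])
```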
@@ -797,32 +797,51 @@ def _append_backward_ops_with_checkpoints_(
     return program_stat, checkpoints_name, vars_should_be_hold, recompute_segments
 
 
-def _get_sub_block_path(sub_block, sub_block_op_desc, no_grad_set):
+def _get_sub_block_path(sub_block,
+                        sub_block_op_desc,
+                        no_grad_set,
+                        op_path_dict,
+                        sub_block_target_names=None):
     """
     Get output vars in subblock which will be assigned to parent block.
-    It is used to find the grad path in subblock
+    It is used to find the grad path in subblock.
+
+    Args:
+        sub_block(Block): The sub-block in which to get op path.
+        sub_block_op_desc: The op desc of the sub-block op such as 'while', 'conditional_block' and 'recurrent'.
+        no_grad_set(set): The set of no grad var name. no_grad_set will be changed.
+        op_path_dict(dict): op_path_dict will be changed.
+            key(int) block index
+            val(list) the op path of block(index)
+        sub_block_target_names(set): Target var names of sub-block.
+
+    Return:
+        The forward op path of sub-block corresponding to backward op.
     """
     assert sub_block_op_desc.has_attr(
         "sub_block") and sub_block.idx == sub_block_op_desc._block_attr_id(
             "sub_block")
-    # TODO(huihuangzheng): add support for recurrent op and while op
-    if sub_block_op_desc.type == "conditional_block":
-        sub_outputs = []
-        sub_assign_to_out_ops = []
-        for var in sub_block_op_desc.output_arg_names:
+    assert isinstance(sub_block_target_names, (set, type(None)))
+
+    if sub_block_target_names is None:
+        sub_block_target_names = sub_block_op_desc.output_arg_names
+
+    # TODO(huihuangzheng): add support for recurrent op.
+    if sub_block_op_desc.type in ["conditional_block", "while"]:
+        # Step1: get the output vars in sub-block
+        sub_outputs = [
+            sub_block._var_recursive(var) for var in sub_block_target_names
+        ]
+        for var in sub_block_target_names:
             for op_desc in sub_block.ops:
-                if op_desc.type == "assign" and var in op_desc.output_arg_names:
-                    sub_assign_to_out_ops.append(op_desc)
+                if var in op_desc.output_arg_names:
                     for name in op_desc.input_arg_names:
-                        if sub_block.has_var(name):
-                            sub_outputs.append(sub_block.var(name))
+                        sub_outputs.append(sub_block._var_recursive(name))
 
+        # Step2: find op path of sub-block
+        is_while = sub_block_op_desc.type in ["while"]
         sub_block_op_path = _find_op_path_(sub_block, sub_outputs, [],
-                                           no_grad_set)
-        # TODO better way than finding in list
-        for op_desc in sub_assign_to_out_ops:
-            if op_desc not in sub_block_op_path:
-                sub_block_op_path.append(op_desc)
+                                           no_grad_set, op_path_dict, is_while)
         return sub_block_op_path
     return sub_block.ops
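As a reading aid (not part of the patch), the contract of the new `op_path_dict` argument is just a shared plain dict: the key is a sub-block index and the value is that block's pruned forward op path, computed once in `_find_op_path_` and reused by `_append_backward_ops_`. A toy sketch with hypothetical op names:

```python
# Toy illustration of the op_path_dict contract; the op names and the block id
# below are made up, not taken from a real program.
op_path_dict = {}  # block index -> pruned list of forward ops for that block

# _find_op_path_ records the pruned path of each sub-block it visits ...
op_path_dict[1] = ["elementwise_add", "increment"]

# ... so _append_backward_ops_ can later fetch it via the control-flow op's
# "sub_block" attribute instead of recomputing it:
sub_block_id = 1  # stands in for op._block_attr_id("sub_block") in Paddle
sub_block_path = op_path_dict[sub_block_id]
assert sub_block_path == ["elementwise_add", "increment"]
```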
@@ -846,7 +865,8 @@ def _append_backward_ops_(block,
                           no_grad_dict,
                           grad_to_var,
                           callbacks=None,
-                          input_grad_names_set=None):
+                          input_grad_names_set=None,
+                          op_path_dict=None):
     """
     Create all grad ops, and insert them into given block
@@ -864,6 +884,9 @@ def _append_backward_ops_(block,
         input_grad_names_set(set): this set is used to store the gradients' name which is
             generated by backward ops, and input_grad_names_set can help to prune the unnecessary
             backward ops.
+        op_path_dict(dict): op_path_dict will be changed.
+            key(int) block index
+            val(list) the op path of block(index)
     """
     if callbacks is not None:
         assert (isinstance(callbacks, list))
@@ -888,11 +911,10 @@ def _append_backward_ops_(block,
             # see follwing comments for why set None here.
             pre_input_grad_names_set = copy.copy(input_grad_names_set)
             input_grad_names_set = None
-            sub_block_path = _get_sub_block_path(sub_block, op,
-                                                 no_grad_dict[sub_block.idx])
+            sub_block_path = op_path_dict[op._block_attr_id("sub_block")]
             _append_backward_ops_(sub_block, sub_block_path, grad_sub_block,
                                   no_grad_dict, grad_to_var, callbacks,
-                                  input_grad_names_set)
+                                  input_grad_names_set, op_path_dict)
             input_grad_names_set = pre_input_grad_names_set
 
             program._rollback()
@@ -1358,7 +1380,10 @@ def append_backward(loss,
     block_no_grad_set = set(
         map(_strip_grad_suffix_, no_grad_dict[block_idx]))
-    op_path = _find_op_path_(block, [loss], [], block_no_grad_set)
+
+    op_path_dict = dict()
+    op_path = _find_op_path_(block, [loss], [], block_no_grad_set,
+                             op_path_dict)
     no_grad_vars = _find_no_grad_vars(block, op_path, [loss],
                                       block_no_grad_set)
@@ -1371,13 +1396,13 @@ def append_backward(loss,
     # For double backward, input_grad_names is used for filtering
     # some non-used gradients op(s).
-    # Todo(liym27): need a better design.
+    # TODO(liym27): need a better design.
     # not support double grad in control flow sub-block now.
     if not is_in_control_flow:
         if program._appending_grad_times > 1:
             input_grad_names_set = set([_append_grad_suffix_(loss.name)])
 
-    # Todo: support _append_backward_ops_with_checkpoints_ in
+    # TODO: support _append_backward_ops_with_checkpoints_ in
     # sub-block (control flow)
     if checkpoints != None and \
             isinstance(checkpoints, list) and \
@@ -1400,7 +1425,8 @@ def append_backward(loss,
             no_grad_dict,
             grad_to_var,
             callbacks,
-            input_grad_names_set=input_grad_names_set)
+            input_grad_names_set=input_grad_names_set,
+            op_path_dict=op_path_dict)
 
     grad_info_map = dict()
@@ -1508,13 +1534,14 @@ def _get_output_names(cur_block, targets):
     """
     block = targets[0].block if targets else cur_block
-    prog = cur_block.program
-    if _is_ancestor_block(block, cur_block):
-        return set()
 
     current_output_names = set([out.name for out in targets])
 
-    # if `cur_block` is an ancestor of `targets[0].block`, run while loop
+    # 1. If `targets` in cur_block or the ancestral block of `cur_block`
+    if block.idx == cur_block.idx or _is_ancestor_block(block, cur_block):
+        return current_output_names
+
+    # 2. If `cur_block` is an ancestor of `targets[0].block`, run while loop
+    prog = cur_block.program
     while block.idx != cur_block.idx:
         assert block.parent_idx != -1
         parent_block = prog.block(block.parent_idx)
@@ -1554,12 +1581,32 @@ def _find_no_grad_vars(block, op_path, targets, no_grad_set):
     return set(no_grad_var)
 
 
-def _find_op_path_(block, outputs, inputs, no_grad_set):
+def _find_op_path_(block,
+                   targets,
+                   inputs,
+                   no_grad_set,
+                   op_path_dict=None,
+                   is_while=False):
     """
-    no_grad_set will also be changed
+    It is used to find the grad path in `block`.
+
+    Args:
+        block(Block): The block in which to get op path.
+        targets(list[Variable]): The target variables.
+        inputs(list[Variable]): The input variables.
+        no_grad_set(set): The set of no grad var name. no_grad_set will be changed.
+        op_path_dict(dict): op_path_dict will be changed.
+            key(int) block index
+            val(list) the op path of block(index)
+        is_while(bool): Whether or not `block` is while block.
+
+    Return:
+        The forward op path of block corresponding to backward op.
     """
     input_names = set([inp.name for inp in inputs])
-    output_names = _get_output_names(block, outputs)
+    output_names = _get_output_names(block, targets)
+    if op_path_dict is None:
+        op_path_dict = dict()
 
     relevant_op_flags = [True] * len(block.ops)
@@ -1576,6 +1623,15 @@ def _find_op_path_(block, outputs, inputs, no_grad_set):
                 relevant_op_flags[i] = False
 
     for i, op in reversed(list(enumerate(block.ops))):
+        if op.has_attr("sub_block"):
+            sub_block_id = op._block_attr_id("sub_block")
+            sub_block = block.program.block(sub_block_id)
+            sub_block_target_names = output_names & set(op.output_arg_names)
+            sub_block_path = _get_sub_block_path(sub_block, op,
+                                                 set(), op_path_dict,
+                                                 sub_block_target_names)
+            op_path_dict[sub_block_id] = sub_block_path
+
         if _some_in_set_(
                 op.desc.output_arg_names(),
                 output_names) and core.has_non_empty_grad_op_maker(op.type):
@@ -1585,6 +1641,14 @@ def _find_op_path_(block, outputs, inputs, no_grad_set):
         else:
             relevant_op_flags[i] = False
 
+    if is_while:
+        # If block is while block, dealing with op specifically again.
+        # TODO(liym27): Consider special types of ops.
+        for i, op in reversed(list(enumerate(block.ops))):
+            if relevant_op_flags[i] == False \
+                    and _some_in_set_(op.desc.output_arg_names(), output_names):
+                relevant_op_flags[i] = True
+
     op_path = [
         block.ops[i] for i in range(len(block.ops)) if relevant_op_flags[i]
     ]
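To make the pruning logic concrete, here is a framework-free toy version of the relevance sweep that `_find_op_path_` performs: walk the ops in reverse, keep an op only if it writes a name that is still needed, and fold its inputs into the needed set. The real function additionally honors `no_grad_set`, records sub-block paths in `op_path_dict`, and applies the extra `is_while` pass shown in the hunk above; the op and variable names below are made up for illustration.

```python
# Toy version of the relevance sweep in _find_op_path_ (illustrative only).
# Each "op" is (type, inputs, outputs); all names are hypothetical.
ops = [
    ("increment",       ["i"],      ["i"]),
    ("elementwise_add", ["x", "i"], ["x"]),
    ("fill_constant",   [],         ["unused"]),   # unrelated to the loss
]
output_names = {"x"}  # names the loss depends on

relevant = [True] * len(ops)
for idx in reversed(range(len(ops))):
    op_type, op_in, op_out = ops[idx]
    if any(name in output_names for name in op_out):
        # op produces something still needed: keep it and require its inputs
        output_names.update(op_in)
    else:
        relevant[idx] = False

op_path = [ops[idx] for idx in range(len(ops)) if relevant[idx]]
# fill_constant is pruned; increment and elementwise_add stay, because "x"
# depends on "i" through elementwise_add.
print([op[0] for op in op_path])  # ['increment', 'elementwise_add']
```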
@@ -1688,7 +1752,10 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
             raise "input must be in the same program as targets"
 
     block_no_grad_set = set(map(_strip_grad_suffix_, no_grad_dict[0]))
-    op_path = _find_op_path_(block, targets, inputs, block_no_grad_set)
+
+    op_path_dict = dict()
+    op_path = _find_op_path_(block, targets, inputs, block_no_grad_set,
+                             op_path_dict)
     no_grad_dict[0].update(list(map(_append_grad_suffix_, block_no_grad_set)))
     grad_to_var = dict()
     grad_info_map = dict()
@@ -1698,7 +1765,8 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
         block,
         no_grad_dict,
         grad_to_var,
-        input_grad_names_set=input_grad_names_set)
+        input_grad_names_set=input_grad_names_set,
+        op_path_dict=op_path_dict)
 
     # Because calc_gradient may be called multiple times,
     # we need rename the internal gradient variables so that they have
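Since `calc_gradient` also backs the public `fluid.gradients` API, the same op-path bookkeeping now applies there as well. A minimal usage sketch under that assumption (Paddle 1.8-era API, no while loop needed to trigger the path construction):

```python
import paddle.fluid as fluid
import paddle.fluid.layers as layers

main_program = fluid.Program()
with fluid.program_guard(main_program):
    x = fluid.data(name='x', shape=[1], dtype='float32')
    x.stop_gradient = False
    y = layers.elementwise_mul(x, x)
    # fluid.gradients() goes through calc_gradient(), which now builds
    # op_path_dict before appending the backward ops.
    x_grad = fluid.gradients([y], [x])
```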
@@ -199,10 +199,16 @@ class TestApiWhileLoop_Backward(unittest.TestCase):
         def cond(i, x):
             return layers.less_than(i, eleven)
 
-        def body(i, x):
+        def body(j, x):
+            # TODO: In while block, if the var created in parent block
+            # participates in the calculation of gradient, the result of gradient
+            # is incorrect because each step scope always returns the same value
+            # generated by last step.
+            # Here we call `assign` op in while block to avoid this bug, and working on fixing it in next PR.
+            i = layers.assign(j)
             x = layers.elementwise_mul(x=i, y=i)
-            i = layers.increment(i)
-            return [i, x]
+            j = layers.increment(j)
+            return [j, x]
 
         main_program = Program()
         startup_program = Program()
@@ -232,7 +238,48 @@ class TestApiWhileLoop_Backward(unittest.TestCase):
                               'x': feed_x},
                       fetch_list=[mean.name, i.grad_name])
         self.assertTrue(np.allclose(np.asarray(res[0]), data))
-        self.assertTrue(np.allclose(np.asarray(res[1]), i_grad))
+        self.assertTrue(
+            np.allclose(np.asarray(res[1]), i_grad),
+            msg=" \nres = \n{} \n\n ans = \n{}".format(res[1], i_grad))
+
+    def test_while_loop_backward2(self):
+        def cond(i, x):
+            return i < 5
+
+        def body(i, x):
+            x = x + i
+            i = i + 1
+            return [i, x]
+
+        main_program = Program()
+        startup_program = Program()
+        with fluid.program_guard(main_program, startup_program):
+            i = fluid.data(name='i', shape=[1], dtype='float32')
+            i.stop_gradient = False
+            x = fluid.data(name='x', shape=[1], dtype='float32')
+            x.stop_gradient = False
+
+            out = layers.while_loop(cond, body, [i, x])
+            mean = layers.mean(out[1])
+            append_backward(mean)
+
+        place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
+        ) else fluid.CPUPlace()
+        exe = fluid.Executor(place)
+
+        feed_i = np.ones(1).astype('float32')
+        feed_x = np.ones(1).astype('float32')
+        data = np.asarray([11]).astype('float32')
+        i_grad = np.asarray([1]).astype('float32')
+
+        res = exe.run(main_program,
+                      feed={'i': feed_i,
+                            'x': feed_x},
+                      fetch_list=[mean.name, i.grad_name])
+        self.assertTrue(np.allclose(np.asarray(res[0]), data))
+        self.assertTrue(
+            np.allclose(np.asarray(res[1]), i_grad),
+            msg=" \nres = \n{} \n\n ans = \n{}".format(res[1], i_grad))
+
 
 class TestApiWhileLoop_NestedWithBackwardAndLoDTensorArray(unittest.TestCase):