From 435fc4f0af3967d519d0b4d05aa00aae674c93f7 Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Mon, 13 Jul 2020 10:33:59 +0800 Subject: [PATCH] [while grad]Support pruning op in find_op_path about while sub-block when appending backward (#25330) Prune OPs which are not related with loss in while sub-block when constructing backward OP path. --- python/paddle/fluid/backward.py | 148 +++++++++++++----- .../tests/unittests/test_while_loop_op.py | 57 ++++++- 2 files changed, 160 insertions(+), 45 deletions(-) diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index e72f7a04e60..898c7d29564 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -45,7 +45,7 @@ class ProgramStats(object): input_names = [] for name in self.var_op_deps: if len(self.var_op_deps[name]["var_as_output_ops"]) == 0 and \ - len(self.var_op_deps[name]["var_as_input_ops"]) > 0: + len(self.var_op_deps[name]["var_as_input_ops"]) > 0: if self.block.var(name).persistable: continue input_names.append(name) @@ -433,7 +433,7 @@ def _addup_repetitive_outputs_(op_descs, block_idx): ] + arg_names[arg_idx:] new_name = var_name + "@RENAME@block" + str(block_idx) + "@" + \ - str(var_rename_count[var_name]) + str(var_rename_count[var_name]) var_rename_count[var_name] += 1 arg_names[arg_idx] = new_name op_desc.set_output(param_name, arg_names) @@ -611,7 +611,7 @@ def _find_not_need_ops(grad_op_descs, forward_ops, input_grad_names_set): not_need_op_descs_set = set(not_need_op_descs) grad_op_descs_set = set(grad_op_descs) # If a backward computational graph is simply one sub-graph header, the - # not_need_op_descs will be whole graph, this IF clause avoids it. + # not_need_op_descs will be whole graph, this IF clause avoids it. if grad_op_descs_set == not_need_op_descs_set: return set() return not_need_op_descs_set @@ -662,7 +662,7 @@ def _append_backward_ops_with_checkpoints_( checkpoints_name = list(set(checkpoints_name)) local_block = block.program._create_block() buffer_block = block.program._create_block() - # 0) deal with forward recomputing program descs + # 0) deal with forward recomputing program descs program_stat = ProgramStats(block, ops) program_stat.modify_forward_desc_for_recompute() program_stat.build_stats() @@ -797,32 +797,51 @@ def _append_backward_ops_with_checkpoints_( return program_stat, checkpoints_name, vars_should_be_hold, recompute_segments -def _get_sub_block_path(sub_block, sub_block_op_desc, no_grad_set): +def _get_sub_block_path(sub_block, + sub_block_op_desc, + no_grad_set, + op_path_dict, + sub_block_target_names=None): """ Get output vars in subblock which will be assigned to parent block. - It is used to find the grad path in subblock + It is used to find the grad path in subblock. + + Args: + sub_block(Block): The sub-block in which to get op path. + sub_block_op_desc: The op desc of the sub-block op such as 'while', 'conditional_block' and 'recurrent'. + no_grad_set(set): The set of no grad var name. no_grad_set will be changed. + op_path_dict(dict): op_path_dict will be changed. + key(int) block index + val(list) the op path of block(index) + sub_block_target_names(set): Target var names of sub-block. + Return: + The forward op path of sub-block corresponding to backward op. 
""" + assert sub_block_op_desc.has_attr( "sub_block") and sub_block.idx == sub_block_op_desc._block_attr_id( "sub_block") - # TODO(huihuangzheng): add support for recurrent op and while op - if sub_block_op_desc.type == "conditional_block": - sub_outputs = [] - sub_assign_to_out_ops = [] - for var in sub_block_op_desc.output_arg_names: + assert isinstance(sub_block_target_names, (set, type(None))) + + if sub_block_target_names is None: + sub_block_target_names = sub_block_op_desc.output_arg_names + + # TODO(huihuangzheng): add support for recurrent op. + if sub_block_op_desc.type in ["conditional_block", "while"]: + # Step1: get the output vars in sub-block + sub_outputs = [ + sub_block._var_recursive(var) for var in sub_block_target_names + ] + for var in sub_block_target_names: for op_desc in sub_block.ops: - if op_desc.type == "assign" and var in op_desc.output_arg_names: - sub_assign_to_out_ops.append(op_desc) + if var in op_desc.output_arg_names: for name in op_desc.input_arg_names: - if sub_block.has_var(name): - sub_outputs.append(sub_block.var(name)) + sub_outputs.append(sub_block._var_recursive(name)) + # Step2: find op path of sub-block + is_while = sub_block_op_desc.type in ["while"] sub_block_op_path = _find_op_path_(sub_block, sub_outputs, [], - no_grad_set) - # TODO better way than finding in list - for op_desc in sub_assign_to_out_ops: - if op_desc not in sub_block_op_path: - sub_block_op_path.append(op_desc) + no_grad_set, op_path_dict, is_while) return sub_block_op_path return sub_block.ops @@ -846,7 +865,8 @@ def _append_backward_ops_(block, no_grad_dict, grad_to_var, callbacks=None, - input_grad_names_set=None): + input_grad_names_set=None, + op_path_dict=None): """ Create all grad ops, and insert them into given block @@ -864,6 +884,9 @@ def _append_backward_ops_(block, input_grad_names_set(set): this set is used to store the gradients' name which is generated by backward ops, and input_grad_names_set can help to prune the unnecessary backward ops. + op_path_dict(dict): op_path_dict will be changed. + key(int) block index + val(list) the op path of block(index) """ if callbacks is not None: assert (isinstance(callbacks, list)) @@ -888,11 +911,10 @@ def _append_backward_ops_(block, # see follwing comments for why set None here. pre_input_grad_names_set = copy.copy(input_grad_names_set) input_grad_names_set = None - sub_block_path = _get_sub_block_path(sub_block, op, - no_grad_dict[sub_block.idx]) + sub_block_path = op_path_dict[op._block_attr_id("sub_block")] _append_backward_ops_(sub_block, sub_block_path, grad_sub_block, no_grad_dict, grad_to_var, callbacks, - input_grad_names_set) + input_grad_names_set, op_path_dict) input_grad_names_set = pre_input_grad_names_set program._rollback() @@ -1013,7 +1035,7 @@ def _find_parent_op_(sub_block): "sub_block") == sub_block_id: return op - # NOTE(paddle-dev): When optimizer is added in conditional block, + # NOTE(paddle-dev): When optimizer is added in conditional block, # sub_block may not be found. 
return None @@ -1072,7 +1094,7 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map): if var != core.empty_var_name() ] - # If the outputs of grad op is empty, just remove it + # If the outputs of grad op is empty, just remove it if not outputs: ops_to_remove.append(op_idx) continue @@ -1358,7 +1380,10 @@ def append_backward(loss, block_no_grad_set = set( map(_strip_grad_suffix_, no_grad_dict[block_idx])) - op_path = _find_op_path_(block, [loss], [], block_no_grad_set) + + op_path_dict = dict() + op_path = _find_op_path_(block, [loss], [], block_no_grad_set, + op_path_dict) no_grad_vars = _find_no_grad_vars(block, op_path, [loss], block_no_grad_set) @@ -1371,13 +1396,13 @@ def append_backward(loss, # For double backward, input_grad_names is used for filtering # some non-used gradients op(s). - # Todo(liym27): need a better design. + # TODO(liym27): need a better design. # not support double grad in control flow sub-block now. if not is_in_control_flow: if program._appending_grad_times > 1: input_grad_names_set = set([_append_grad_suffix_(loss.name)]) - # Todo: support _append_backward_ops_with_checkpoints_ in + # TODO: support _append_backward_ops_with_checkpoints_ in # sub-block (control flow) if checkpoints != None and \ isinstance(checkpoints, list) and \ @@ -1400,7 +1425,8 @@ def append_backward(loss, no_grad_dict, grad_to_var, callbacks, - input_grad_names_set=input_grad_names_set) + input_grad_names_set=input_grad_names_set, + op_path_dict=op_path_dict) grad_info_map = dict() @@ -1508,13 +1534,14 @@ def _get_output_names(cur_block, targets): """ block = targets[0].block if targets else cur_block - prog = cur_block.program - if _is_ancestor_block(block, cur_block): - return set() - current_output_names = set([out.name for out in targets]) - # if `cur_block` is an ancestor of `targets[0].block`, run while loop + # 1. If `targets` in cur_block or the ancestral block of `cur_block` + if block.idx == cur_block.idx or _is_ancestor_block(block, cur_block): + return current_output_names + + # 2. If `cur_block` is an ancestor of `targets[0].block`, run while loop + prog = cur_block.program while block.idx != cur_block.idx: assert block.parent_idx != -1 parent_block = prog.block(block.parent_idx) @@ -1554,12 +1581,32 @@ def _find_no_grad_vars(block, op_path, targets, no_grad_set): return set(no_grad_var) -def _find_op_path_(block, outputs, inputs, no_grad_set): +def _find_op_path_(block, + targets, + inputs, + no_grad_set, + op_path_dict=None, + is_while=False): """ - no_grad_set will also be changed + It is used to find the grad path in `block`. + + Args: + block(Block): The block in which to get op path. + targets(list[Variable]): The target variables. + inputs(list[Variable]): The input variables. + no_grad_set(set): The set of no grad var name. no_grad_set will be changed. + op_path_dict(dict): op_path_dict will be changed. op_path_dict will be changed. + key(int) block index + val(list) the op path of block(index) + is_while(bool): Whether or not `block` is while block + Return: + The forward op path of block corresponding to backward op. 
""" + input_names = set([inp.name for inp in inputs]) - output_names = _get_output_names(block, outputs) + output_names = _get_output_names(block, targets) + if op_path_dict is None: + op_path_dict = dict() relevant_op_flags = [True] * len(block.ops) @@ -1576,6 +1623,15 @@ def _find_op_path_(block, outputs, inputs, no_grad_set): relevant_op_flags[i] = False for i, op in reversed(list(enumerate(block.ops))): + if op.has_attr("sub_block"): + sub_block_id = op._block_attr_id("sub_block") + sub_block = block.program.block(sub_block_id) + sub_block_target_names = output_names & set(op.output_arg_names) + sub_block_path = _get_sub_block_path(sub_block, op, + set(), op_path_dict, + sub_block_target_names) + op_path_dict[sub_block_id] = sub_block_path + if _some_in_set_( op.desc.output_arg_names(), output_names) and core.has_non_empty_grad_op_maker(op.type): @@ -1585,6 +1641,14 @@ def _find_op_path_(block, outputs, inputs, no_grad_set): else: relevant_op_flags[i] = False + if is_while: + # If block is while block, dealing with op specifically again. + # TODO(liym27): Consider special types of ops. + for i, op in reversed(list(enumerate(block.ops))): + if relevant_op_flags[i] == False \ + and _some_in_set_(op.desc.output_arg_names(),output_names): + relevant_op_flags[i] = True + op_path = [ block.ops[i] for i in range(len(block.ops)) if relevant_op_flags[i] ] @@ -1688,7 +1752,10 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None): raise "input must be in the same program as targets" block_no_grad_set = set(map(_strip_grad_suffix_, no_grad_dict[0])) - op_path = _find_op_path_(block, targets, inputs, block_no_grad_set) + + op_path_dict = dict() + op_path = _find_op_path_(block, targets, inputs, block_no_grad_set, + op_path_dict) no_grad_dict[0].update(list(map(_append_grad_suffix_, block_no_grad_set))) grad_to_var = dict() grad_info_map = dict() @@ -1698,7 +1765,8 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None): block, no_grad_dict, grad_to_var, - input_grad_names_set=input_grad_names_set) + input_grad_names_set=input_grad_names_set, + op_path_dict=op_path_dict) # Because calc_gradient may be called multiple times, # we need rename the internal gradient variables so that they have diff --git a/python/paddle/fluid/tests/unittests/test_while_loop_op.py b/python/paddle/fluid/tests/unittests/test_while_loop_op.py index 224dfd7f0a7..aa692eb5367 100644 --- a/python/paddle/fluid/tests/unittests/test_while_loop_op.py +++ b/python/paddle/fluid/tests/unittests/test_while_loop_op.py @@ -199,10 +199,16 @@ class TestApiWhileLoop_Backward(unittest.TestCase): def cond(i, x): return layers.less_than(i, eleven) - def body(i, x): + def body(j, x): + # TODO: In while block, if the var created in parent block + # participates in the calculation of gradient, the result of gradient + # is incorrect because each step scope always returns the same value + # generated by last step. + # Here we call `assign` op in while block to avoid this bug, and working on fixing it in next PR. 
+ i = layers.assign(j) x = layers.elementwise_mul(x=i, y=i) - i = layers.increment(i) - return [i, x] + j = layers.increment(j) + return [j, x] main_program = Program() startup_program = Program() @@ -232,7 +238,48 @@ class TestApiWhileLoop_Backward(unittest.TestCase): 'x': feed_x}, fetch_list=[mean.name, i.grad_name]) self.assertTrue(np.allclose(np.asarray(res[0]), data)) - self.assertTrue(np.allclose(np.asarray(res[1]), i_grad)) + self.assertTrue( + np.allclose(np.asarray(res[1]), i_grad), + msg=" \nres = \n{} \n\n ans = \n{}".format(res[1], i_grad)) + + def test_while_loop_backward2(self): + def cond(i, x): + return i < 5 + + def body(i, x): + x = x + i + i = i + 1 + return [i, x] + + main_program = Program() + startup_program = Program() + with fluid.program_guard(main_program, startup_program): + i = fluid.data(name='i', shape=[1], dtype='float32') + i.stop_gradient = False + x = fluid.data(name='x', shape=[1], dtype='float32') + x.stop_gradient = False + + out = layers.while_loop(cond, body, [i, x]) + mean = layers.mean(out[1]) + append_backward(mean) + + place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + exe = fluid.Executor(place) + + feed_i = np.ones(1).astype('float32') + feed_x = np.ones(1).astype('float32') + data = np.asarray([11]).astype('float32') + i_grad = np.asarray([1]).astype('float32') + + res = exe.run(main_program, + feed={'i': feed_i, + 'x': feed_x}, + fetch_list=[mean.name, i.grad_name]) + self.assertTrue(np.allclose(np.asarray(res[0]), data)) + self.assertTrue( + np.allclose(np.asarray(res[1]), i_grad), + msg=" \nres = \n{} \n\n ans = \n{}".format(res[1], i_grad)) class TestApiWhileLoop_NestedWithBackwardAndLoDTensorArray(unittest.TestCase): @@ -410,7 +457,7 @@ class TestApiWhileLoop_Error(unittest.TestCase): ten = layers.fill_constant(shape=[1], dtype='int64', value=10) ten_2d = layers.fill_constant(shape=[2, 2], dtype='int64', value=10) - # The type of `cond` in Op(while_loop) must be callable + # The type of `cond` in Op(while_loop) must be callable def type_error_cond(): out = layers.while_loop(data, body, [data_1d]) -- GitLab
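A minimal usage sketch of the behavior this patch targets, assuming the same fluid APIs exercised by the new tests (layers.while_loop, append_backward, fluid.data). It mirrors test_while_loop_backward2 but adds an auxiliary elementwise_mul inside the loop body whose result never reaches the loss; with the pruning added to _find_op_path_, such an op should be excluded from the while sub-block's backward op path instead of dragging unrelated grad ops into the program. The aux variable is illustrative only and does not appear in the patch.

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.backward import append_backward


def cond(i, x):
    return i < 5


def body(i, x):
    # Illustrative op: its output never feeds the loss, so the new pruning
    # in _find_op_path_ keeps it out of the while sub-block's backward path.
    aux = layers.elementwise_mul(x=i, y=i)
    x = x + i
    i = i + 1
    return [i, x]


main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
    i = fluid.data(name='i', shape=[1], dtype='float32')
    i.stop_gradient = False
    x = fluid.data(name='x', shape=[1], dtype='float32')
    x.stop_gradient = False

    out = layers.while_loop(cond, body, [i, x])
    mean = layers.mean(out[1])
    append_backward(mean)

exe = fluid.Executor(fluid.CPUPlace())
res = exe.run(main_program,
              feed={'i': np.ones(1).astype('float32'),
                    'x': np.ones(1).astype('float32')},
              fetch_list=[mean.name, i.grad_name])
# Expected to match the values pinned by test_while_loop_backward2:
# res[0] == [11.] and res[1] == [1.]

Note on the design this sketch relies on: with this patch the sub-block op path is computed once in _find_op_path_ and handed to _append_backward_ops_ through op_path_dict, rather than being recomputed by _get_sub_block_path while the backward ops are being appended.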