From abe32a49126d69fbbdbc8c31d210f4fe301a47f9 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Tue, 24 Jul 2018 12:44:46 +0800 Subject: [PATCH] "fix memory optimize bug in lr decay" (#12299) --- .../transpiler/memory_optimization_transpiler.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py index dd90d66110e..353c82f7163 100644 --- a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py +++ b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py @@ -324,6 +324,8 @@ def _process_sub_block_pair(pdesc, sub_block_pair): sub_op_output = set() sub_op_output.update(sub_op_dict[fwd_id].output_arg_names()) sub_op_output.update(sub_op_dict[grad_id].output_arg_names()) + sub_op_output.update(sub_op_dict[fwd_id].input_arg_names()) + sub_op_output.update(sub_op_dict[grad_id].input_arg_names()) ops_list.append((sub_block_ops, block_op_size, sub_op_output)) # Process rest fwd_op block ops @@ -335,6 +337,7 @@ def _process_sub_block_pair(pdesc, sub_block_pair): sub_block_ops.append(sub_block.op(i)) sub_op_output = set() sub_op_output.update(sub_op_dict[fwd_id].output_arg_names()) + sub_op_output.update(sub_op_dict[fwd_id].input_arg_names()) ops_list.append((sub_block_ops, sub_block_op_size, sub_op_output)) return ops_list @@ -349,13 +352,17 @@ def _get_cfgs(input_program): pdesc = input_program.get_desc() block_desc = pdesc.block(0) op_size = block_desc.op_size() - # Get global block ops - ops_list.append( - ([block_desc.op(i) for i in range(op_size)], op_size, set())) # Only process one level of nested subblock. ops_list.extend(_process_sub_block_pair(pdesc, SUB_BLOCK_PAIR)) + skip_opt_set = set() + for _, _, skip_opt in ops_list: + skip_opt_set.update(skip_opt) + + # Get global block ops + ops_list.insert( + 0, ([block_desc.op(i) for i in range(op_size)], op_size, skip_opt_set)) cfgs = [ ControlFlowGraph(input_program, ops, forward_num, skip_opt) for ops, forward_num, skip_opt in ops_list -- GitLab