From 83df277ff123d7b102f405cdb512457841f11a32 Mon Sep 17 00:00:00 2001 From: QI JUN Date: Wed, 7 Feb 2018 17:33:27 +0800 Subject: [PATCH] Refine get_cfgs method of memory optimization transpiler (#8080) * refine get cfgs method in memory optimization transpiler * clean code --- .../fluid/memory_optimization_transpiler.py | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/python/paddle/v2/fluid/memory_optimization_transpiler.py b/python/paddle/v2/fluid/memory_optimization_transpiler.py index 2b00923f5e..11e2cfb3cc 100644 --- a/python/paddle/v2/fluid/memory_optimization_transpiler.py +++ b/python/paddle/v2/fluid/memory_optimization_transpiler.py @@ -145,7 +145,6 @@ class ControlFlowGraph(object): if op.type() == "while" or op.type() == "while_grad": continue block_desc = op.block() - self.current_block_desc = block_desc is_forward = i < self._forward_num if self.pool: defs_can_optimize = filter( @@ -208,17 +207,17 @@ def get_cfgs(input_program): while_sub_block_ids = [] while_grad_sub_block_ids = [] - while_op_output = set() while_block_id_pair = [] + while_op_dict = {} for i in range(op_size): op = block_desc.op(i) if op.type() == "while": while_sub_block_ids.append(op.attr("sub_block").id) - while_op_output.update(op.output_arg_names()) + while_op_dict[op.attr("sub_block").id] = op elif op.type() == "while_grad": while_grad_sub_block_ids.append(op.attr("sub_block").id) - while_op_output.update(op.output_arg_names()) + while_op_dict[op.attr("sub_block").id] = op # Find while/while_grad block pair for grad_id in while_grad_sub_block_ids: @@ -240,6 +239,10 @@ def get_cfgs(input_program): for i in range(while_grad_block_op_size): while_block_ops.append(while_grad_block.op(i)) + while_op_output = set() + while_op_output.update(while_op_dict[parent_id].output_arg_names()) + while_op_output.update(while_op_dict[grad_id].output_arg_names()) + ops_list.append((while_block_ops, while_block_op_size, while_op_output)) # Process rest while block ops @@ -250,9 +253,15 @@ def get_cfgs(input_program): for i in range(while_block_op_size): while_block_ops.append(while_block.op(i)) - ops_list.append((while_block_ops, while_block_op_size)) + while_op_output = set() + while_op_output.update(while_op_dict[parent_id].output_arg_names()) + + ops_list.append((while_block_ops, while_block_op_size, while_op_output)) - cfgs = [ControlFlowGraph(input_program, i, j, k) for i, j, k in ops_list] + cfgs = [ + ControlFlowGraph(input_program, ops, forward_num, skip_opt) + for ops, forward_num, skip_opt in ops_list + ] return cfgs -- GitLab