未验证 提交 83df277f 编写于 作者: Q QI JUN 提交者: GitHub

Refine get_cfgs method of memory optimization transpiler (#8080)

* refine get cfgs method in memory optimization transpiler

* clean code
上级 b41205d9
...@@ -145,7 +145,6 @@ class ControlFlowGraph(object): ...@@ -145,7 +145,6 @@ class ControlFlowGraph(object):
if op.type() == "while" or op.type() == "while_grad": if op.type() == "while" or op.type() == "while_grad":
continue continue
block_desc = op.block() block_desc = op.block()
self.current_block_desc = block_desc
is_forward = i < self._forward_num is_forward = i < self._forward_num
if self.pool: if self.pool:
defs_can_optimize = filter( defs_can_optimize = filter(
...@@ -208,17 +207,17 @@ def get_cfgs(input_program): ...@@ -208,17 +207,17 @@ def get_cfgs(input_program):
while_sub_block_ids = [] while_sub_block_ids = []
while_grad_sub_block_ids = [] while_grad_sub_block_ids = []
while_op_output = set()
while_block_id_pair = [] while_block_id_pair = []
while_op_dict = {}
for i in range(op_size): for i in range(op_size):
op = block_desc.op(i) op = block_desc.op(i)
if op.type() == "while": if op.type() == "while":
while_sub_block_ids.append(op.attr("sub_block").id) while_sub_block_ids.append(op.attr("sub_block").id)
while_op_output.update(op.output_arg_names()) while_op_dict[op.attr("sub_block").id] = op
elif op.type() == "while_grad": elif op.type() == "while_grad":
while_grad_sub_block_ids.append(op.attr("sub_block").id) while_grad_sub_block_ids.append(op.attr("sub_block").id)
while_op_output.update(op.output_arg_names()) while_op_dict[op.attr("sub_block").id] = op
# Find while/while_grad block pair # Find while/while_grad block pair
for grad_id in while_grad_sub_block_ids: for grad_id in while_grad_sub_block_ids:
...@@ -240,6 +239,10 @@ def get_cfgs(input_program): ...@@ -240,6 +239,10 @@ def get_cfgs(input_program):
for i in range(while_grad_block_op_size): for i in range(while_grad_block_op_size):
while_block_ops.append(while_grad_block.op(i)) while_block_ops.append(while_grad_block.op(i))
while_op_output = set()
while_op_output.update(while_op_dict[parent_id].output_arg_names())
while_op_output.update(while_op_dict[grad_id].output_arg_names())
ops_list.append((while_block_ops, while_block_op_size, while_op_output)) ops_list.append((while_block_ops, while_block_op_size, while_op_output))
# Process rest while block ops # Process rest while block ops
...@@ -250,9 +253,15 @@ def get_cfgs(input_program): ...@@ -250,9 +253,15 @@ def get_cfgs(input_program):
for i in range(while_block_op_size): for i in range(while_block_op_size):
while_block_ops.append(while_block.op(i)) while_block_ops.append(while_block.op(i))
ops_list.append((while_block_ops, while_block_op_size)) while_op_output = set()
while_op_output.update(while_op_dict[parent_id].output_arg_names())
ops_list.append((while_block_ops, while_block_op_size, while_op_output))
cfgs = [ControlFlowGraph(input_program, i, j, k) for i, j, k in ops_list] cfgs = [
ControlFlowGraph(input_program, ops, forward_num, skip_opt)
for ops, forward_num, skip_opt in ops_list
]
return cfgs return cfgs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册