提交 bf92706c 编写于 作者: Q qijun

fix bug in memory optimization transpiler

上级 14f83707
......@@ -223,15 +223,15 @@ def get_cfgs(input_program):
# Find while/while_grad block pair
for grad_id in while_grad_sub_block_ids:
parent_id = pdesc.block(grad_id).parent
if parent_id in while_sub_block_ids:
while_block_id_pair.append((parent_id, grad_id))
while_sub_block_ids.remove(parent_id)
forward_id = pdesc.block(grad_id).get_forward_block_idx()
if forward_id in while_sub_block_ids:
while_block_id_pair.append((forward_id, grad_id))
while_sub_block_ids.remove(forward_id)
# Get while/while_grad block ops
for parent_id, grad_id in while_block_id_pair:
for forward_id, grad_id in while_block_id_pair:
while_block_ops = []
while_block = pdesc.block(parent_id)
while_block = pdesc.block(forward_id)
while_block_op_size = while_block.op_size()
for i in range(while_block_op_size):
while_block_ops.append(while_block.op(i))
......@@ -242,21 +242,21 @@ def get_cfgs(input_program):
while_block_ops.append(while_grad_block.op(i))
while_op_output = set()
while_op_output.update(while_op_dict[parent_id].output_arg_names())
while_op_output.update(while_op_dict[forward_id].output_arg_names())
while_op_output.update(while_op_dict[grad_id].output_arg_names())
ops_list.append((while_block_ops, while_block_op_size, while_op_output))
# Process rest while block ops
for parent_id in while_sub_block_ids:
for forward_id in while_sub_block_ids:
while_block_ops = []
while_block = pdesc.block(parent_id)
while_block = pdesc.block(forward_id)
while_block_op_size = while_block.op_size()
for i in range(while_block_op_size):
while_block_ops.append(while_block.op(i))
while_op_output = set()
while_op_output.update(while_op_dict[parent_id].output_arg_names())
while_op_output.update(while_op_dict[forward_id].output_arg_names())
ops_list.append((while_block_ops, while_block_op_size, while_op_output))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册