From bf92706c58f8c89db9b670523e8aa4fcd2c067a7 Mon Sep 17 00:00:00 2001 From: qijun Date: Fri, 23 Feb 2018 11:40:30 +0800 Subject: [PATCH] fix bug in memory optimization transpiler --- .../fluid/memory_optimization_transpiler.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/python/paddle/v2/fluid/memory_optimization_transpiler.py b/python/paddle/v2/fluid/memory_optimization_transpiler.py index ee56ccdcf1..6952ca7fe4 100644 --- a/python/paddle/v2/fluid/memory_optimization_transpiler.py +++ b/python/paddle/v2/fluid/memory_optimization_transpiler.py @@ -223,15 +223,15 @@ def get_cfgs(input_program): # Find while/while_grad block pair for grad_id in while_grad_sub_block_ids: - parent_id = pdesc.block(grad_id).parent - if parent_id in while_sub_block_ids: - while_block_id_pair.append((parent_id, grad_id)) - while_sub_block_ids.remove(parent_id) + forward_id = pdesc.block(grad_id).get_forward_block_idx() + if forward_id in while_sub_block_ids: + while_block_id_pair.append((forward_id, grad_id)) + while_sub_block_ids.remove(forward_id) # Get while/while_grad block ops - for parent_id, grad_id in while_block_id_pair: + for forward_id, grad_id in while_block_id_pair: while_block_ops = [] - while_block = pdesc.block(parent_id) + while_block = pdesc.block(forward_id) while_block_op_size = while_block.op_size() for i in range(while_block_op_size): while_block_ops.append(while_block.op(i)) @@ -242,21 +242,21 @@ def get_cfgs(input_program): while_block_ops.append(while_grad_block.op(i)) while_op_output = set() - while_op_output.update(while_op_dict[parent_id].output_arg_names()) + while_op_output.update(while_op_dict[forward_id].output_arg_names()) while_op_output.update(while_op_dict[grad_id].output_arg_names()) ops_list.append((while_block_ops, while_block_op_size, while_op_output)) # Process rest while block ops - for parent_id in while_sub_block_ids: + for forward_id in while_sub_block_ids: while_block_ops = [] - while_block = pdesc.block(parent_id) + while_block = pdesc.block(forward_id) while_block_op_size = while_block.op_size() for i in range(while_block_op_size): while_block_ops.append(while_block.op(i)) while_op_output = set() - while_op_output.update(while_op_dict[parent_id].output_arg_names()) + while_op_output.update(while_op_dict[forward_id].output_arg_names()) ops_list.append((while_block_ops, while_block_op_size, while_op_output)) -- GitLab