提交 7ffd50b9 编写于 作者: Y Yu Yang

Merge branch 'feature/add_fwd_block_id' of github.com:reyoung/Paddle into feature/add_fwd_block_id

...@@ -223,15 +223,15 @@ def get_cfgs(input_program): ...@@ -223,15 +223,15 @@ def get_cfgs(input_program):
# Find while/while_grad block pair # Find while/while_grad block pair
for grad_id in while_grad_sub_block_ids: for grad_id in while_grad_sub_block_ids:
parent_id = pdesc.block(grad_id).parent forward_id = pdesc.block(grad_id).get_forward_block_idx()
if parent_id in while_sub_block_ids: if forward_id in while_sub_block_ids:
while_block_id_pair.append((parent_id, grad_id)) while_block_id_pair.append((forward_id, grad_id))
while_sub_block_ids.remove(parent_id) while_sub_block_ids.remove(forward_id)
# Get while/while_grad block ops # Get while/while_grad block ops
for parent_id, grad_id in while_block_id_pair: for forward_id, grad_id in while_block_id_pair:
while_block_ops = [] while_block_ops = []
while_block = pdesc.block(parent_id) while_block = pdesc.block(forward_id)
while_block_op_size = while_block.op_size() while_block_op_size = while_block.op_size()
for i in range(while_block_op_size): for i in range(while_block_op_size):
while_block_ops.append(while_block.op(i)) while_block_ops.append(while_block.op(i))
...@@ -242,21 +242,21 @@ def get_cfgs(input_program): ...@@ -242,21 +242,21 @@ def get_cfgs(input_program):
while_block_ops.append(while_grad_block.op(i)) while_block_ops.append(while_grad_block.op(i))
while_op_output = set() while_op_output = set()
while_op_output.update(while_op_dict[parent_id].output_arg_names()) while_op_output.update(while_op_dict[forward_id].output_arg_names())
while_op_output.update(while_op_dict[grad_id].output_arg_names()) while_op_output.update(while_op_dict[grad_id].output_arg_names())
ops_list.append((while_block_ops, while_block_op_size, while_op_output)) ops_list.append((while_block_ops, while_block_op_size, while_op_output))
# Process rest while block ops # Process rest while block ops
for parent_id in while_sub_block_ids: for forward_id in while_sub_block_ids:
while_block_ops = [] while_block_ops = []
while_block = pdesc.block(parent_id) while_block = pdesc.block(forward_id)
while_block_op_size = while_block.op_size() while_block_op_size = while_block.op_size()
for i in range(while_block_op_size): for i in range(while_block_op_size):
while_block_ops.append(while_block.op(i)) while_block_ops.append(while_block.op(i))
while_op_output = set() while_op_output = set()
while_op_output.update(while_op_dict[parent_id].output_arg_names()) while_op_output.update(while_op_dict[forward_id].output_arg_names())
ops_list.append((while_block_ops, while_block_op_size, while_op_output)) ops_list.append((while_block_ops, while_block_op_size, while_op_output))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册