diff --git a/python/paddle/distributed/passes/pass_utils.py b/python/paddle/distributed/passes/pass_utils.py
index afb935008167a16bf765f2b78d4247701b89cc2f..7fba8b76b4ac264facadc7f02bc69f0ded219df8 100644
--- a/python/paddle/distributed/passes/pass_utils.py
+++ b/python/paddle/distributed/passes/pass_utils.py
@@ -430,38 +430,83 @@ def _program_for_fthenb_and_1f1b(program):
     bwd_prog = Program()
     opt_prog = Program()
 
-    for idx, src_block in enumerate(program.blocks):
-        if idx == 0:
-            lr_block = lr_prog.block(0)
-            fwd_block = fwd_prog.block(0)
-            bwd_block = bwd_prog.block(0)
-            opt_block = opt_prog.block(0)
-        else:
-            lr_block = lr_prog._create_block(parent_idx=src_block.parent_idx)
-            fwd_block = fwd_prog._create_block(parent_idx=src_block.parent_idx)
-            bwd_block = bwd_prog._create_block(parent_idx=src_block.parent_idx)
-            opt_block = opt_prog._create_block(parent_idx=src_block.parent_idx)
-            lr_block._set_forward_block_idx(src_block.forward_block_idx)
-            fwd_block._set_forward_block_idx(src_block.forward_block_idx)
-            bwd_block._set_forward_block_idx(src_block.forward_block_idx)
-            opt_block._set_forward_block_idx(src_block.forward_block_idx)
-
-        # split the program based on the op_role
-        for op in src_block.ops:
+    # split the program based on the op_role
+    def _split_ops(block):
+        """Bucket the ops of ``block`` by op role: (lr, fwd, bwd, opt)."""
+        lr_ops = []
+        fwd_ops = []
+        bwd_ops = []
+        opt_ops = []
+        # Single if/elif chain so an LRSched op never reaches the ValueError.
+        for op in block.ops:
             if is_lr_sched_op(op):
-                _create_program(src_block, lr_block, op)
-            if is_forward_op(op):
-                _create_program(src_block, fwd_block, op)
+                lr_ops.append(op)
+            elif is_forward_op(op):
+                fwd_ops.append(op)
             elif is_backward_op(op):
-                _create_program(src_block, bwd_block, op)
+                bwd_ops.append(op)
             elif is_optimize_op(op):
-                _create_program(src_block, opt_block, op)
+                opt_ops.append(op)
             else:
                 raise ValueError(
                     "The op role: "
                     + str(op.attr('op_role'))
                     + " isn't one of LRSched, Forward, Backward or Optimizer."
                 )
+        return lr_ops, fwd_ops, bwd_ops, opt_ops
+
+    def _add_ops_into_block(src_block, dst_block, ops):
+        """Run _create_program on every op in ``ops``, targeting dst_block."""
+        for op in ops:
+            _create_program(src_block, dst_block, op)
+
+    for idx, src_block in enumerate(program.blocks):
+        lr_ops, fwd_ops, bwd_ops, opt_ops = _split_ops(src_block)
+        if idx == 0:
+            lr_block = lr_prog.block(0)
+            _add_ops_into_block(src_block, lr_block, lr_ops)
+
+            fwd_block = fwd_prog.block(0)
+            _add_ops_into_block(src_block, fwd_block, fwd_ops)
+
+            bwd_block = bwd_prog.block(0)
+            _add_ops_into_block(src_block, bwd_block, bwd_ops)
+
+            opt_block = opt_prog.block(0)
+            _add_ops_into_block(src_block, opt_block, opt_ops)
+        else:
+            # Only create a sub-block in a program that receives ops from
+            # this source block.
+            # NOTE(review): parent_idx / forward_block_idx are indices taken
+            # from the source program; confirm they remain valid when some
+            # sub-blocks are skipped here.
+            if lr_ops:
+                lr_block = lr_prog._create_block(
+                    parent_idx=src_block.parent_idx
+                )
+                lr_block._set_forward_block_idx(src_block.forward_block_idx)
+                _add_ops_into_block(src_block, lr_block, lr_ops)
+
+            if fwd_ops:
+                fwd_block = fwd_prog._create_block(
+                    parent_idx=src_block.parent_idx
+                )
+                fwd_block._set_forward_block_idx(src_block.forward_block_idx)
+                _add_ops_into_block(src_block, fwd_block, fwd_ops)
+
+            if bwd_ops:
+                bwd_block = bwd_prog._create_block(
+                    parent_idx=src_block.parent_idx
+                )
+                bwd_block._set_forward_block_idx(src_block.forward_block_idx)
+                _add_ops_into_block(src_block, bwd_block, bwd_ops)
+
+            if opt_ops:
+                opt_block = opt_prog._create_block(
+                    parent_idx=src_block.parent_idx
+                )
+                opt_block._set_forward_block_idx(src_block.forward_block_idx)
+                _add_ops_into_block(src_block, opt_block, opt_ops)
 
     lr_prog._sync_with_cpp()
     fwd_prog._sync_with_cpp()