未验证 提交 ee7877e4 编写于 作者: L Leo Chen 提交者: GitHub

remove empty block program (#56355)

* remove empty block program

* update implementation
上级 1437ad06
...@@ -430,38 +430,75 @@ def _program_for_fthenb_and_1f1b(program): ...@@ -430,38 +430,75 @@ def _program_for_fthenb_and_1f1b(program):
bwd_prog = Program() bwd_prog = Program()
opt_prog = Program() opt_prog = Program()
for idx, src_block in enumerate(program.blocks): # split the program based on the op_role
if idx == 0: def _split_ops(block):
lr_block = lr_prog.block(0) lr_ops = []
fwd_block = fwd_prog.block(0) fwd_ops = []
bwd_block = bwd_prog.block(0) bwd_ops = []
opt_block = opt_prog.block(0) opt_ops = []
else:
lr_block = lr_prog._create_block(parent_idx=src_block.parent_idx)
fwd_block = fwd_prog._create_block(parent_idx=src_block.parent_idx)
bwd_block = bwd_prog._create_block(parent_idx=src_block.parent_idx)
opt_block = opt_prog._create_block(parent_idx=src_block.parent_idx)
lr_block._set_forward_block_idx(src_block.forward_block_idx)
fwd_block._set_forward_block_idx(src_block.forward_block_idx)
bwd_block._set_forward_block_idx(src_block.forward_block_idx)
opt_block._set_forward_block_idx(src_block.forward_block_idx)
# split the program based on the op_role
for op in src_block.ops: for op in src_block.ops:
if is_lr_sched_op(op): if is_lr_sched_op(op):
_create_program(src_block, lr_block, op) lr_ops.append(op)
if is_forward_op(op): if is_forward_op(op):
_create_program(src_block, fwd_block, op) fwd_ops.append(op)
elif is_backward_op(op): elif is_backward_op(op):
_create_program(src_block, bwd_block, op) bwd_ops.append(op)
elif is_optimize_op(op): elif is_optimize_op(op):
_create_program(src_block, opt_block, op) opt_ops.append(op)
else: else:
raise ValueError( raise ValueError(
"The op role: " "The op role: "
+ str(op.attr('op_role')) + str(op.attr('op_role'))
+ " isn't one of LRSched, Forward, Backward or Optimizer." + " isn't one of LRSched, Forward, Backward or Optimizer."
) )
return lr_ops, fwd_ops, bwd_ops, opt_ops
def _add_ops_into_block(src_block, dst_block, ops):
for op in ops:
_create_program(src_block, dst_block, op)
for idx, src_block in enumerate(program.blocks):
lr_ops, fwd_ops, bwd_ops, opt_ops = _split_ops(src_block)
if idx == 0:
lr_block = lr_prog.block(0)
_add_ops_into_block(src_block, lr_block, lr_ops)
fwd_block = fwd_prog.block(0)
_add_ops_into_block(src_block, fwd_block, fwd_ops)
bwd_block = bwd_prog.block(0)
_add_ops_into_block(src_block, bwd_block, bwd_ops)
opt_block = opt_prog.block(0)
_add_ops_into_block(src_block, opt_block, opt_ops)
else:
if len(lr_ops):
lr_block = lr_prog._create_block(
parent_idx=src_block.parent_idx
)
lr_block._set_forward_block_idx(src_block.forward_block_idx)
_add_ops_into_block(src_block, lr_block, lr_ops)
if len(fwd_ops):
fwd_block = fwd_prog._create_block(
parent_idx=src_block.parent_idx
)
fwd_block._set_forward_block_idx(src_block.forward_block_idx)
_add_ops_into_block(src_block, fwd_block, fwd_ops)
if len(bwd_ops):
bwd_block = bwd_prog._create_block(
parent_idx=src_block.parent_idx
)
bwd_block._set_forward_block_idx(src_block.forward_block_idx)
_add_ops_into_block(src_block, bwd_block, bwd_ops)
if len(opt_ops):
opt_block = opt_prog._create_block(
parent_idx=src_block.parent_idx
)
opt_block._set_forward_block_idx(src_block.forward_block_idx)
_add_ops_into_block(src_block, opt_block, opt_ops)
lr_prog._sync_with_cpp() lr_prog._sync_with_cpp()
fwd_prog._sync_with_cpp() fwd_prog._sync_with_cpp()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册