“179dda3ea41f0f603cdfd092babb3c82c6c3564e”上不存在“mobile/test/net/test_mobilenet_combine.cpp”
未验证 提交 ee7877e4 编写于 作者: L Leo Chen 提交者: GitHub

remove empty block program (#56355)

* remove empty block program

* update implementation
上级 1437ad06
......@@ -430,38 +430,75 @@ def _program_for_fthenb_and_1f1b(program):
bwd_prog = Program()
opt_prog = Program()
for idx, src_block in enumerate(program.blocks):
if idx == 0:
lr_block = lr_prog.block(0)
fwd_block = fwd_prog.block(0)
bwd_block = bwd_prog.block(0)
opt_block = opt_prog.block(0)
else:
lr_block = lr_prog._create_block(parent_idx=src_block.parent_idx)
fwd_block = fwd_prog._create_block(parent_idx=src_block.parent_idx)
bwd_block = bwd_prog._create_block(parent_idx=src_block.parent_idx)
opt_block = opt_prog._create_block(parent_idx=src_block.parent_idx)
lr_block._set_forward_block_idx(src_block.forward_block_idx)
fwd_block._set_forward_block_idx(src_block.forward_block_idx)
bwd_block._set_forward_block_idx(src_block.forward_block_idx)
opt_block._set_forward_block_idx(src_block.forward_block_idx)
# split the program based on the op_role
# split the program based on the op_role
def _split_ops(block):
lr_ops = []
fwd_ops = []
bwd_ops = []
opt_ops = []
for op in src_block.ops:
if is_lr_sched_op(op):
_create_program(src_block, lr_block, op)
lr_ops.append(op)
if is_forward_op(op):
_create_program(src_block, fwd_block, op)
fwd_ops.append(op)
elif is_backward_op(op):
_create_program(src_block, bwd_block, op)
bwd_ops.append(op)
elif is_optimize_op(op):
_create_program(src_block, opt_block, op)
opt_ops.append(op)
else:
raise ValueError(
"The op role: "
+ str(op.attr('op_role'))
+ " isn't one of LRSched, Forward, Backward or Optimizer."
)
return lr_ops, fwd_ops, bwd_ops, opt_ops
def _add_ops_into_block(src_block, dst_block, ops):
for op in ops:
_create_program(src_block, dst_block, op)
for idx, src_block in enumerate(program.blocks):
lr_ops, fwd_ops, bwd_ops, opt_ops = _split_ops(src_block)
if idx == 0:
lr_block = lr_prog.block(0)
_add_ops_into_block(src_block, lr_block, lr_ops)
fwd_block = fwd_prog.block(0)
_add_ops_into_block(src_block, fwd_block, fwd_ops)
bwd_block = bwd_prog.block(0)
_add_ops_into_block(src_block, bwd_block, bwd_ops)
opt_block = opt_prog.block(0)
_add_ops_into_block(src_block, opt_block, opt_ops)
else:
if len(lr_ops):
lr_block = lr_prog._create_block(
parent_idx=src_block.parent_idx
)
lr_block._set_forward_block_idx(src_block.forward_block_idx)
_add_ops_into_block(src_block, lr_block, lr_ops)
if len(fwd_ops):
fwd_block = fwd_prog._create_block(
parent_idx=src_block.parent_idx
)
fwd_block._set_forward_block_idx(src_block.forward_block_idx)
_add_ops_into_block(src_block, fwd_block, fwd_ops)
if len(bwd_ops):
bwd_block = bwd_prog._create_block(
parent_idx=src_block.parent_idx
)
bwd_block._set_forward_block_idx(src_block.forward_block_idx)
_add_ops_into_block(src_block, bwd_block, bwd_ops)
if len(opt_ops):
opt_block = opt_prog._create_block(
parent_idx=src_block.parent_idx
)
opt_block._set_forward_block_idx(src_block.forward_block_idx)
_add_ops_into_block(src_block, opt_block, opt_ops)
lr_prog._sync_with_cpp()
fwd_prog._sync_with_cpp()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册