提交 76e3ec60 编写于 作者: Y Yancey1989

fix cloned op

上级 8cb494f7
......@@ -396,7 +396,7 @@ class DistributeTranspiler(object):
return varname
return ""
def __clone_lr_op_sub_block__(op, program, new_block):
def __clone_lr_op_sub_block__(op, program, lr_block):
if not op.has_attr('sub_block'):
return
......@@ -405,17 +405,17 @@ class DistributeTranspiler(object):
assert isinstance(origin_block, Block)
# we put the new sub block to new block to follow the block
# hierarchy of the original blocks
new_sub_block = program.create_block(new_block.idx)
new_sub_block = program.create_block(lr_block.idx)
# clone vars
for var in origin_block.vars:
new_sub_block.clone_variable(var)
# clone ops
for op in origin_block.ops:
self._clone_lr_op(program, new_sub_block, op)
for origin_op in origin_block.ops:
cloned_op = self._clone_lr_op(program, new_sub_block, origin_op)
# clone sub_block of op
__clone_lr_op_sub_block__(op, program, new_sub_block)
__clone_lr_op_sub_block__(cloned_op, program, new_sub_block)
# reset the block of op
op.set_attr('sub_block', new_sub_block)
......@@ -429,9 +429,10 @@ class DistributeTranspiler(object):
pserver_program.num_blocks - 1)
optimize_blocks.append(lr_decay_block)
for _, op in enumerate(lr_ops):
self._append_pserver_non_opt_ops(lr_decay_block, op)
cloned_op = self._append_pserver_non_opt_ops(lr_decay_block, op)
# append sub blocks to pserver_program in lr_decay_op
__clone_lr_op_sub_block__(op, pserver_program, lr_decay_block)
__clone_lr_op_sub_block__(cloned_op, pserver_program,
lr_decay_block)
# append op to the current block
grad_to_block_id = []
......@@ -1214,7 +1215,7 @@ class DistributeTranspiler(object):
if var not in program.global_block().vars:
block.clone_variable(var)
block.append_op(
return block.append_op(
type=op.type, inputs=inputs, outputs=outputs, attrs=op.attrs)
def _append_pserver_non_opt_ops(self, optimize_block, opt_op):
......@@ -1252,7 +1253,7 @@ class DistributeTranspiler(object):
elif not program.global_block().vars.has_key(var.name):
program.global_block().clone_variable(var)
optimize_block.append_op(
return optimize_block.append_op(
type=opt_op.type,
inputs=inputs,
outputs=outputs,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册