提交 76e3ec60 编写于 作者: Y Yancey1989

fix cloned op

上级 8cb494f7
...@@ -396,7 +396,7 @@ class DistributeTranspiler(object): ...@@ -396,7 +396,7 @@ class DistributeTranspiler(object):
return varname return varname
return "" return ""
def __clone_lr_op_sub_block__(op, program, new_block): def __clone_lr_op_sub_block__(op, program, lr_block):
if not op.has_attr('sub_block'): if not op.has_attr('sub_block'):
return return
...@@ -405,17 +405,17 @@ class DistributeTranspiler(object): ...@@ -405,17 +405,17 @@ class DistributeTranspiler(object):
assert isinstance(origin_block, Block) assert isinstance(origin_block, Block)
# we put the new sub block to new block to follow the block # we put the new sub block to new block to follow the block
# hierarchy of the original blocks # hierarchy of the original blocks
new_sub_block = program.create_block(new_block.idx) new_sub_block = program.create_block(lr_block.idx)
# clone vars # clone vars
for var in origin_block.vars: for var in origin_block.vars:
new_sub_block.clone_variable(var) new_sub_block.clone_variable(var)
# clone ops # clone ops
for op in origin_block.ops: for origin_op in origin_block.ops:
self._clone_lr_op(program, new_sub_block, op) cloned_op = self._clone_lr_op(program, new_sub_block, origin_op)
# clone sub_block of op # clone sub_block of op
__clone_lr_op_sub_block__(op, program, new_sub_block) __clone_lr_op_sub_block__(cloned_op, program, new_sub_block)
# reset the block of op # reset the block of op
op.set_attr('sub_block', new_sub_block) op.set_attr('sub_block', new_sub_block)
...@@ -429,9 +429,10 @@ class DistributeTranspiler(object): ...@@ -429,9 +429,10 @@ class DistributeTranspiler(object):
pserver_program.num_blocks - 1) pserver_program.num_blocks - 1)
optimize_blocks.append(lr_decay_block) optimize_blocks.append(lr_decay_block)
for _, op in enumerate(lr_ops): for _, op in enumerate(lr_ops):
self._append_pserver_non_opt_ops(lr_decay_block, op) cloned_op = self._append_pserver_non_opt_ops(lr_decay_block, op)
# append sub blocks to pserver_program in lr_decay_op # append sub blocks to pserver_program in lr_decay_op
__clone_lr_op_sub_block__(op, pserver_program, lr_decay_block) __clone_lr_op_sub_block__(cloned_op, pserver_program,
lr_decay_block)
# append op to the current block # append op to the current block
grad_to_block_id = [] grad_to_block_id = []
...@@ -1214,7 +1215,7 @@ class DistributeTranspiler(object): ...@@ -1214,7 +1215,7 @@ class DistributeTranspiler(object):
if var not in program.global_block().vars: if var not in program.global_block().vars:
block.clone_variable(var) block.clone_variable(var)
block.append_op( return block.append_op(
type=op.type, inputs=inputs, outputs=outputs, attrs=op.attrs) type=op.type, inputs=inputs, outputs=outputs, attrs=op.attrs)
def _append_pserver_non_opt_ops(self, optimize_block, opt_op): def _append_pserver_non_opt_ops(self, optimize_block, opt_op):
...@@ -1252,7 +1253,7 @@ class DistributeTranspiler(object): ...@@ -1252,7 +1253,7 @@ class DistributeTranspiler(object):
elif not program.global_block().vars.has_key(var.name): elif not program.global_block().vars.has_key(var.name):
program.global_block().clone_variable(var) program.global_block().clone_variable(var)
optimize_block.append_op( return optimize_block.append_op(
type=opt_op.type, type=opt_op.type,
inputs=inputs, inputs=inputs,
outputs=outputs, outputs=outputs,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册