提交 ea522dab 编写于 作者: Y Yancey1989

refine delete ops

上级 41452582
...@@ -317,8 +317,7 @@ class DistributeTranspiler: ...@@ -317,8 +317,7 @@ class DistributeTranspiler:
def get_trainer_program(self): def get_trainer_program(self):
# remove optimize ops and add a send op to main_program # remove optimize ops and add a send op to main_program
self.origin_program.global_block().delete_ops(self.optimize_ops) self.delete_ops(self.origin_program.global_block(), self.optimize_ops)
self.origin_program.sync_with_cpp()
# FIXME(typhoonzero): serialize once will fix error occurs when clone. # FIXME(typhoonzero): serialize once will fix error occurs when clone.
self.origin_program.__str__() self.origin_program.__str__()
return self.origin_program return self.origin_program
...@@ -602,8 +601,7 @@ class DistributeTranspiler: ...@@ -602,8 +601,7 @@ class DistributeTranspiler:
attrs={"axis": 0}) attrs={"axis": 0})
# delete lookup_table_op # delete lookup_table_op
program.global_block().delete_ops([op]) self.delete_ops(program.global_block(), [op])
program.sync_with_cpp()
# break for loop # break for loop
break break
...@@ -1166,3 +1164,12 @@ class DistributeTranspiler: ...@@ -1166,3 +1164,12 @@ class DistributeTranspiler:
in_name.startswith("beta2_pow_acc"): in_name.startswith("beta2_pow_acc"):
return True return True
return False return False
def delete_ops(self, block, ops):
try:
start = list(block.ops).index(ops[0])
end = list(block.ops).index(ops[-1])
[block.remove_op(start) for _ in xrange(end - start + 1)]
except Exception, e:
raise e
block.program.sync_with_cpp()
...@@ -848,16 +848,6 @@ class Block(object): ...@@ -848,16 +848,6 @@ class Block(object):
self.desc.remove_op(index, index + 1) self.desc.remove_op(index, index + 1)
del self.ops[index] del self.ops[index]
def delete_ops(self, ops):
# remove from cpp
# FIXME(typhoonzero): remove only the first occurrence.
try:
start = list(self.ops).index(ops[0])
end = list(self.ops).index(ops[-1])
[self.remove_op(start) for _ in xrange(end - start + 1)]
except Exception, e:
raise e
def slice_ops(self, start, end): def slice_ops(self, start, end):
return self.ops[start:end] return self.ops[start:end]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册