diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py index c180e7b21042a1ceb2651d8f7a48a02c1abfd02c..ee17b11c8baaa1da0669ee55dfbeae4f3a0a3620 100644 --- a/python/paddle/fluid/distribute_transpiler.py +++ b/python/paddle/fluid/distribute_transpiler.py @@ -317,8 +317,7 @@ class DistributeTranspiler: def get_trainer_program(self): # remove optimize ops and add a send op to main_program - self.origin_program.global_block().delete_ops(self.optimize_ops) - self.origin_program.sync_with_cpp() + self.delete_ops(self.origin_program.global_block(), self.optimize_ops) # FIXME(typhoonzero): serialize once will fix error occurs when clone. self.origin_program.__str__() return self.origin_program @@ -602,8 +601,7 @@ class DistributeTranspiler: attrs={"axis": 0}) # delete lookup_table_op - program.global_block().delete_ops([op]) - program.sync_with_cpp() + self.delete_ops(program.global_block(), [op]) # break for loop break @@ -1166,3 +1164,12 @@ class DistributeTranspiler: in_name.startswith("beta2_pow_acc"): return True return False + + def delete_ops(self, block, ops): + try: + start = list(block.ops).index(ops[0]) + end = list(block.ops).index(ops[-1]) + [block.remove_op(start) for _ in xrange(end - start + 1)] + except Exception, e: + raise e + block.program.sync_with_cpp() diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index c9a48ea838787c04d1d8185e9a1272213613f906..ce9b880aeb31f7706820b89fb7e1edc6c51ba69b 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -848,16 +848,6 @@ class Block(object): self.desc.remove_op(index, index + 1) del self.ops[index] - def delete_ops(self, ops): - # remove from cpp - # FIXME(typhoonzero): remove only the first occurrence. - try: - start = list(self.ops).index(ops[0]) - end = list(self.ops).index(ops[-1]) - [self.remove_op(start) for _ in xrange(end - start + 1)] - except Exception, e: - raise e - def slice_ops(self, start, end): return self.ops[start:end]