From ea522dabc9b20497945d157ce61844d5faadf3aa Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Thu, 3 May 2018 17:36:30 +0800 Subject: [PATCH] refine delete ops --- python/paddle/fluid/distribute_transpiler.py | 15 +++++++++++---- python/paddle/fluid/framework.py | 10 ---------- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py index c180e7b2104..ee17b11c8ba 100644 --- a/python/paddle/fluid/distribute_transpiler.py +++ b/python/paddle/fluid/distribute_transpiler.py @@ -317,8 +317,7 @@ class DistributeTranspiler: def get_trainer_program(self): # remove optimize ops and add a send op to main_program - self.origin_program.global_block().delete_ops(self.optimize_ops) - self.origin_program.sync_with_cpp() + self.delete_ops(self.origin_program.global_block(), self.optimize_ops) # FIXME(typhoonzero): serialize once will fix error occurs when clone. self.origin_program.__str__() return self.origin_program @@ -602,8 +601,7 @@ class DistributeTranspiler: attrs={"axis": 0}) # delete lookup_table_op - program.global_block().delete_ops([op]) - program.sync_with_cpp() + self.delete_ops(program.global_block(), [op]) # break for loop break @@ -1166,3 +1164,12 @@ class DistributeTranspiler: in_name.startswith("beta2_pow_acc"): return True return False + + def delete_ops(self, block, ops): + try: + start = list(block.ops).index(ops[0]) + end = list(block.ops).index(ops[-1]) + [block.remove_op(start) for _ in xrange(end - start + 1)] + except Exception, e: + raise e + block.program.sync_with_cpp() diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index c9a48ea8387..ce9b880aeb3 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -848,16 +848,6 @@ class Block(object): self.desc.remove_op(index, index + 1) del self.ops[index] - def delete_ops(self, ops): - # remove from cpp - # FIXME(typhoonzero): remove only the first occurrence. - try: - start = list(self.ops).index(ops[0]) - end = list(self.ops).index(ops[-1]) - [self.remove_op(start) for _ in xrange(end - start + 1)] - except Exception, e: - raise e - def slice_ops(self, start, end): return self.ops[start:end] -- GitLab