From 103407aa6027c676cd1bcc13e6f35b2a51f1cc6a Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Tue, 3 Apr 2018 16:09:27 +0800 Subject: [PATCH] refine sync_with_cpp when remove ops or remove vars --- python/paddle/fluid/framework.py | 18 +++++++++++++++ .../tests/unittests/test_protobuf_descs.py | 23 +++++++++++++++---- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 3e78788f47..e15456bfc0 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -847,6 +847,11 @@ class Block(object): if not self.has_var(var.name()): self.create_var(name=var.name(), desc=var, type=var.type()) + # sync variables removed from c++ end + for var in self.vars.keys(): + if not self.desc.find_var(var): + self.vars.pop(var) + # sync operators from cpp ops_in_cpp = [] for op_idx in range(0, self.desc.op_size()): @@ -881,6 +886,19 @@ class Block(object): op = Operator(self, op_desc) self.ops.append(op) + # sync ops removed from c++ end + if end_index != -1 and end_index < len(self.ops): + ops_in_cpp_index = 0 + ops_in_python_index = 0 + while ops_in_python_index < len( + self.ops) and ops_in_cpp_index < len(ops_in_cpp): + if self.ops[ops_in_python_index].desc != ops_in_cpp[ + ops_in_cpp_index]: + del self.ops[ops_in_python_index] + else: + ops_in_cpp_index += 1 + ops_in_python_index += 1 + assert len(self.ops) == len(ops_in_cpp) for index in range(len(self.ops)): assert self.ops[index].desc == ops_in_cpp[index] diff --git a/python/paddle/fluid/tests/unittests/test_protobuf_descs.py b/python/paddle/fluid/tests/unittests/test_protobuf_descs.py index da85786d0c..e4cf4a8bce 100644 --- a/python/paddle/fluid/tests/unittests/test_protobuf_descs.py +++ b/python/paddle/fluid/tests/unittests/test_protobuf_descs.py @@ -14,6 +14,7 @@ import unittest import paddle.fluid.core as core +from paddle.fluid.framework import Program class TestOpDesc(unittest.TestCase): @@ -187,32 +188,46 @@ class TestBlockDesc(unittest.TestCase): self.assertEqual(all_ops, [op0, op1, op2]) def test_remove_op(self): - prog = core.ProgramDesc() + program = Program() + prog = program.desc self.assertIsNotNone(prog) block = prog.block(0) self.assertIsNotNone(block) + + op0 = block.append_op() op1 = block.append_op() op2 = block.append_op() + op0.set_type("test") + op1.set_type("test") + op2.set_type("test") + + var0 = block.var("var0") var1 = block.var("var1") var2 = block.var("var2") var3 = block.var("var3") var4 = block.var("var4") var5 = block.var("var5") + + op0.set_input("X", ["var0"]) + op0.set_output("Y", ["var0"]) op1.set_input("X", ["var1", "var2"]) op1.set_output("Y", ["var3", "var4"]) op2.set_input("X", ["var1"]) op2.set_output("Y", ["var4", "var5"]) + program.sync_with_cpp() + # remove op1, its input var2 and output var3 will be removed at the same time, # but its input var1 and output var4 will not be removed since they are used for op2. - block.remove_op(0, 1) + block.remove_op(1, 2) + program.sync_with_cpp() all_ops = [] for idx in xrange(0, block.op_size()): all_ops.append(block.op(idx)) - self.assertEqual(all_ops, [op2]) + self.assertEqual(all_ops, [op0, op2]) all_vars = block.all_vars() - self.assertEqual(set(all_vars), {var1, var4, var5}) + self.assertEqual(set(all_vars), {var0, var1, var4, var5}) if __name__ == '__main__': -- GitLab