提交 103407aa 编写于 作者: L Luo Tao

refine sync_with_cpp when remove ops or remove vars

上级 faa752a4
...@@ -847,6 +847,11 @@ class Block(object): ...@@ -847,6 +847,11 @@ class Block(object):
if not self.has_var(var.name()): if not self.has_var(var.name()):
self.create_var(name=var.name(), desc=var, type=var.type()) self.create_var(name=var.name(), desc=var, type=var.type())
# sync variables removed from c++ end
for var in self.vars.keys():
if not self.desc.find_var(var):
self.vars.pop(var)
# sync operators from cpp # sync operators from cpp
ops_in_cpp = [] ops_in_cpp = []
for op_idx in range(0, self.desc.op_size()): for op_idx in range(0, self.desc.op_size()):
...@@ -881,6 +886,19 @@ class Block(object): ...@@ -881,6 +886,19 @@ class Block(object):
op = Operator(self, op_desc) op = Operator(self, op_desc)
self.ops.append(op) self.ops.append(op)
# sync ops removed from c++ end
if end_index != -1 and end_index < len(self.ops):
ops_in_cpp_index = 0
ops_in_python_index = 0
while ops_in_python_index < len(
self.ops) and ops_in_cpp_index < len(ops_in_cpp):
if self.ops[ops_in_python_index].desc != ops_in_cpp[
ops_in_cpp_index]:
del self.ops[ops_in_python_index]
else:
ops_in_cpp_index += 1
ops_in_python_index += 1
assert len(self.ops) == len(ops_in_cpp) assert len(self.ops) == len(ops_in_cpp)
for index in range(len(self.ops)): for index in range(len(self.ops)):
assert self.ops[index].desc == ops_in_cpp[index] assert self.ops[index].desc == ops_in_cpp[index]
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import unittest import unittest
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid.framework import Program
class TestOpDesc(unittest.TestCase): class TestOpDesc(unittest.TestCase):
...@@ -187,32 +188,46 @@ class TestBlockDesc(unittest.TestCase): ...@@ -187,32 +188,46 @@ class TestBlockDesc(unittest.TestCase):
self.assertEqual(all_ops, [op0, op1, op2]) self.assertEqual(all_ops, [op0, op1, op2])
def test_remove_op(self): def test_remove_op(self):
prog = core.ProgramDesc() program = Program()
prog = program.desc
self.assertIsNotNone(prog) self.assertIsNotNone(prog)
block = prog.block(0) block = prog.block(0)
self.assertIsNotNone(block) self.assertIsNotNone(block)
op0 = block.append_op()
op1 = block.append_op() op1 = block.append_op()
op2 = block.append_op() op2 = block.append_op()
op0.set_type("test")
op1.set_type("test")
op2.set_type("test")
var0 = block.var("var0")
var1 = block.var("var1") var1 = block.var("var1")
var2 = block.var("var2") var2 = block.var("var2")
var3 = block.var("var3") var3 = block.var("var3")
var4 = block.var("var4") var4 = block.var("var4")
var5 = block.var("var5") var5 = block.var("var5")
op0.set_input("X", ["var0"])
op0.set_output("Y", ["var0"])
op1.set_input("X", ["var1", "var2"]) op1.set_input("X", ["var1", "var2"])
op1.set_output("Y", ["var3", "var4"]) op1.set_output("Y", ["var3", "var4"])
op2.set_input("X", ["var1"]) op2.set_input("X", ["var1"])
op2.set_output("Y", ["var4", "var5"]) op2.set_output("Y", ["var4", "var5"])
program.sync_with_cpp()
# remove op1, its input var2 and output var3 will be removed at the same time, # remove op1, its input var2 and output var3 will be removed at the same time,
# but its input var1 and output var4 will not be removed since they are used for op2. # but its input var1 and output var4 will not be removed since they are used for op2.
block.remove_op(0, 1) block.remove_op(1, 2)
program.sync_with_cpp()
all_ops = [] all_ops = []
for idx in xrange(0, block.op_size()): for idx in xrange(0, block.op_size()):
all_ops.append(block.op(idx)) all_ops.append(block.op(idx))
self.assertEqual(all_ops, [op2]) self.assertEqual(all_ops, [op0, op2])
all_vars = block.all_vars() all_vars = block.all_vars()
self.assertEqual(set(all_vars), {var1, var4, var5}) self.assertEqual(set(all_vars), {var0, var1, var4, var5})
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册