未验证 提交 95710456 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #9600 from luotao1/sync_with_cpp

refine sync_with_cpp when remove ops or remove vars
......@@ -847,6 +847,11 @@ class Block(object):
if not self.has_var(var.name()):
self.create_var(name=var.name(), desc=var, type=var.type())
# sync variables removed from c++ end
for var in self.vars.keys():
if not self.desc.find_var(var):
self.vars.pop(var)
# sync operators from cpp
ops_in_cpp = []
for op_idx in range(0, self.desc.op_size()):
......@@ -881,6 +886,19 @@ class Block(object):
op = Operator(self, op_desc)
self.ops.append(op)
# sync ops removed from c++ end
if end_index != -1 and end_index < len(self.ops):
ops_in_cpp_index = 0
ops_in_python_index = 0
while ops_in_python_index < len(
self.ops) and ops_in_cpp_index < len(ops_in_cpp):
if self.ops[ops_in_python_index].desc != ops_in_cpp[
ops_in_cpp_index]:
del self.ops[ops_in_python_index]
else:
ops_in_cpp_index += 1
ops_in_python_index += 1
assert len(self.ops) == len(ops_in_cpp)
for index in range(len(self.ops)):
assert self.ops[index].desc == ops_in_cpp[index]
......
......@@ -14,6 +14,7 @@
import unittest
import paddle.fluid.core as core
from paddle.fluid.framework import Program
class TestOpDesc(unittest.TestCase):
......@@ -187,32 +188,46 @@ class TestBlockDesc(unittest.TestCase):
self.assertEqual(all_ops, [op0, op1, op2])
def test_remove_op(self):
prog = core.ProgramDesc()
program = Program()
prog = program.desc
self.assertIsNotNone(prog)
block = prog.block(0)
self.assertIsNotNone(block)
op0 = block.append_op()
op1 = block.append_op()
op2 = block.append_op()
op0.set_type("test")
op1.set_type("test")
op2.set_type("test")
var0 = block.var("var0")
var1 = block.var("var1")
var2 = block.var("var2")
var3 = block.var("var3")
var4 = block.var("var4")
var5 = block.var("var5")
op0.set_input("X", ["var0"])
op0.set_output("Y", ["var0"])
op1.set_input("X", ["var1", "var2"])
op1.set_output("Y", ["var3", "var4"])
op2.set_input("X", ["var1"])
op2.set_output("Y", ["var4", "var5"])
program.sync_with_cpp()
# remove op1, its input var2 and output var3 will be removed at the same time,
# but its input var1 and output var4 will not be removed since they are used for op2.
block.remove_op(0, 1)
block.remove_op(1, 2)
program.sync_with_cpp()
all_ops = []
for idx in xrange(0, block.op_size()):
all_ops.append(block.op(idx))
self.assertEqual(all_ops, [op2])
self.assertEqual(all_ops, [op0, op2])
all_vars = block.all_vars()
self.assertEqual(set(all_vars), {var1, var4, var5})
self.assertEqual(set(all_vars), {var0, var1, var4, var5})
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册