diff --git a/paddle/fluid/framework/block_desc.cc b/paddle/fluid/framework/block_desc.cc index 3693bc25d81a8309df1a6ddf3d9b08d484596ea9..4faf9dcf375e28fbaf64b099a77fb761201ed9a3 100644 --- a/paddle/fluid/framework/block_desc.cc +++ b/paddle/fluid/framework/block_desc.cc @@ -148,14 +148,40 @@ void BlockDesc::RemoveOp(size_t s, size_t e) { return; } need_update_ = true; + std::vector vars1; // input vars from delete ops for (auto it = ops_.begin() + s; it != ops_.begin() + e; it++) { - auto names = (*it)->InputArgumentNames(); - for (auto n : names) { - // TODO(typhoonzero): delete vars if no other op use it. - VLOG(3) << "deleting var " << n; + // delete all output vars + auto out_names = (*it)->OutputArgumentNames(); + for (auto n : out_names) { + vars_.erase(vars_.find(n)); } + // collect all input vars from remove ops + auto in_names = (*it)->InputArgumentNames(); + vars1.insert(vars1.end(), in_names.begin(), in_names.end()); } ops_.erase(ops_.begin() + s, ops_.begin() + e); + + // collect input and output vars from remain ops + std::vector vars2; + for (auto it = ops_.begin(); it != ops_.end(); it++) { + auto in_names = (*it)->InputArgumentNames(); + auto out_names = (*it)->OutputArgumentNames(); + vars2.insert(vars2.end(), in_names.begin(), in_names.end()); + vars2.insert(vars2.end(), out_names.begin(), out_names.end()); + } + + // delete input vars if no other op use it. + std::vector del_vars; + std::sort(vars1.begin(), vars1.end()); + std::unique(vars1.begin(), vars1.end()); + std::sort(vars2.begin(), vars2.end()); + std::unique(vars2.begin(), vars2.end()); + // del_vars = vars1 - vars1 ^ vars2 + std::set_difference(vars1.begin(), vars1.end(), vars2.begin(), vars2.end(), + std::inserter(del_vars, del_vars.end())); + for (auto it = del_vars.begin(); it != del_vars.end(); it++) { + vars_.erase(vars_.find(*it)); + } } std::vector BlockDesc::AllOps() const { diff --git a/python/paddle/fluid/tests/unittests/test_protobuf_descs.py b/python/paddle/fluid/tests/unittests/test_protobuf_descs.py index 309ea2b9b7ede442da3ac897ce8d1a4b9aa68233..871cb76fffcb738f1f44f0be4fef1ff166f9adf7 100644 --- a/python/paddle/fluid/tests/unittests/test_protobuf_descs.py +++ b/python/paddle/fluid/tests/unittests/test_protobuf_descs.py @@ -186,6 +186,33 @@ class TestBlockDesc(unittest.TestCase): all_ops.append(block.op(idx)) self.assertEqual(all_ops, [op0, op1, op2]) + def test_remove_op(self): + prog = core.ProgramDesc() + self.assertIsNotNone(prog) + block = prog.block(0) + self.assertIsNotNone(block) + op1 = block.append_op() + op2 = block.append_op() + var1 = block.var("var1") + var2 = block.var("var2") + var3 = block.var("var3") + var4 = block.var("var4") + op1.set_input("X", ["var1", "var2"]) + op1.set_output("Y", ["var3"]) + op2.set_input("X", ["var1"]) + op2.set_output("Y", ["var4"]) + + # remove op1, its input var2 and output var3 will be removed at the same time, + # but its input var1 will not be removed since var1 is also an input for op2. + block.remove_op(0, 1) + + all_ops = [] + for idx in xrange(0, block.op_size()): + all_ops.append(block.op(idx)) + self.assertEqual(all_ops, [op2]) + all_vars = block.all_vars() + self.assertEqual(set(all_vars), {var1, var4}) + if __name__ == '__main__': unittest.main()