提交 ccfec1bc 编写于 作者: L Luo Tao

remove vars when remove ops

上级 54a85b7b
...@@ -148,14 +148,40 @@ void BlockDesc::RemoveOp(size_t s, size_t e) { ...@@ -148,14 +148,40 @@ void BlockDesc::RemoveOp(size_t s, size_t e) {
return; return;
} }
need_update_ = true; need_update_ = true;
std::vector<std::string> vars1; // input vars from delete ops
for (auto it = ops_.begin() + s; it != ops_.begin() + e; it++) { for (auto it = ops_.begin() + s; it != ops_.begin() + e; it++) {
auto names = (*it)->InputArgumentNames(); // delete all output vars
for (auto n : names) { auto out_names = (*it)->OutputArgumentNames();
// TODO(typhoonzero): delete vars if no other op use it. for (auto n : out_names) {
VLOG(3) << "deleting var " << n; vars_.erase(vars_.find(n));
} }
// collect all input vars from remove ops
auto in_names = (*it)->InputArgumentNames();
vars1.insert(vars1.end(), in_names.begin(), in_names.end());
} }
ops_.erase(ops_.begin() + s, ops_.begin() + e); ops_.erase(ops_.begin() + s, ops_.begin() + e);
// collect input and output vars from remain ops
std::vector<std::string> vars2;
for (auto it = ops_.begin(); it != ops_.end(); it++) {
auto in_names = (*it)->InputArgumentNames();
auto out_names = (*it)->OutputArgumentNames();
vars2.insert(vars2.end(), in_names.begin(), in_names.end());
vars2.insert(vars2.end(), out_names.begin(), out_names.end());
}
// delete input vars if no other op use it.
std::vector<std::string> del_vars;
std::sort(vars1.begin(), vars1.end());
std::unique(vars1.begin(), vars1.end());
std::sort(vars2.begin(), vars2.end());
std::unique(vars2.begin(), vars2.end());
// del_vars = vars1 - vars1 ^ vars2
std::set_difference(vars1.begin(), vars1.end(), vars2.begin(), vars2.end(),
std::inserter(del_vars, del_vars.end()));
for (auto it = del_vars.begin(); it != del_vars.end(); it++) {
vars_.erase(vars_.find(*it));
}
} }
std::vector<OpDesc *> BlockDesc::AllOps() const { std::vector<OpDesc *> BlockDesc::AllOps() const {
......
...@@ -186,6 +186,33 @@ class TestBlockDesc(unittest.TestCase): ...@@ -186,6 +186,33 @@ class TestBlockDesc(unittest.TestCase):
all_ops.append(block.op(idx)) all_ops.append(block.op(idx))
self.assertEqual(all_ops, [op0, op1, op2]) self.assertEqual(all_ops, [op0, op1, op2])
def test_remove_op(self):
prog = core.ProgramDesc()
self.assertIsNotNone(prog)
block = prog.block(0)
self.assertIsNotNone(block)
op1 = block.append_op()
op2 = block.append_op()
var1 = block.var("var1")
var2 = block.var("var2")
var3 = block.var("var3")
var4 = block.var("var4")
op1.set_input("X", ["var1", "var2"])
op1.set_output("Y", ["var3"])
op2.set_input("X", ["var1"])
op2.set_output("Y", ["var4"])
# remove op1, its input var2 and output var3 will be removed at the same time,
# but its input var1 will not be removed since var1 is also an input for op2.
block.remove_op(0, 1)
all_ops = []
for idx in xrange(0, block.op_size()):
all_ops.append(block.op(idx))
self.assertEqual(all_ops, [op2])
all_vars = block.all_vars()
self.assertEqual(set(all_vars), {var1, var4})
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册