提交 7f401224 编写于 作者: L Luo Tao

adjust remove rule for variables

上级 ccfec1bc
...@@ -147,40 +147,51 @@ void BlockDesc::RemoveOp(size_t s, size_t e) { ...@@ -147,40 +147,51 @@ void BlockDesc::RemoveOp(size_t s, size_t e) {
if (ops_.begin() + s == ops_.end() || ops_.begin() + e == ops_.end()) { if (ops_.begin() + s == ops_.end() || ops_.begin() + e == ops_.end()) {
return; return;
} }
auto get_vars = [](std::deque<std::unique_ptr<OpDesc>>::iterator &op,
std::vector<std::string> &v) {
auto in_names = (*op)->InputArgumentNames();
v.insert(v.end(), in_names.begin(), in_names.end());
auto out_names = (*op)->OutputArgumentNames();
v.insert(v.end(), out_names.begin(), out_names.end());
std::sort(v.begin(), v.end());
auto last = std::unique(v.begin(), v.end());
v.erase(last, v.end());
};
need_update_ = true; need_update_ = true;
std::vector<std::string> vars1; // input vars from delete ops
for (auto it = ops_.begin() + s; it != ops_.begin() + e; it++) {
// delete all output vars
auto out_names = (*it)->OutputArgumentNames();
for (auto n : out_names) {
vars_.erase(vars_.find(n));
}
// collect all input vars from remove ops
auto in_names = (*it)->InputArgumentNames();
vars1.insert(vars1.end(), in_names.begin(), in_names.end());
}
ops_.erase(ops_.begin() + s, ops_.begin() + e);
// collect input and output vars from remain ops for (size_t i = s; i < e; i++) {
std::vector<std::string> vars2; // since remove op one by one, every time remove the first op.
auto op = ops_.begin() + s;
// collect input and output variables from current delete op
std::vector<std::string> cur_vars;
get_vars(op, cur_vars);
// remove current op
ops_.erase(ops_.begin() + s);
// collect input and output variables from other ops
std::vector<std::string> other_vars;
for (auto it = ops_.begin(); it != ops_.end(); it++) { for (auto it = ops_.begin(); it != ops_.end(); it++) {
auto in_names = (*it)->InputArgumentNames(); get_vars(it, other_vars);
auto out_names = (*it)->OutputArgumentNames();
vars2.insert(vars2.end(), in_names.begin(), in_names.end());
vars2.insert(vars2.end(), out_names.begin(), out_names.end());
} }
// delete input vars if no other op use it. // variables should be deleted
std::vector<std::string> del_vars; std::vector<std::string> delete_vars;
std::sort(vars1.begin(), vars1.end()); // delete_vars = cur_vars - cur_vars ^ other_input_vars
std::unique(vars1.begin(), vars1.end()); std::set_difference(cur_vars.begin(), cur_vars.end(), other_vars.begin(),
std::sort(vars2.begin(), vars2.end()); other_vars.end(),
std::unique(vars2.begin(), vars2.end()); std::inserter(delete_vars, delete_vars.end()));
// del_vars = vars1 - vars1 ^ vars2 // remove variables
std::set_difference(vars1.begin(), vars1.end(), vars2.begin(), vars2.end(), for (size_t i = 0; i < delete_vars.size(); i++) {
std::inserter(del_vars, del_vars.end())); auto name = delete_vars[i];
for (auto it = del_vars.begin(); it != del_vars.end(); it++) { auto it = vars_.find(name);
vars_.erase(vars_.find(*it)); PADDLE_ENFORCE(it != vars_.end(),
"%s is not in variable list, it should not be deleted",
name);
vars_.erase(it);
VLOG(3) << "deleting variable " << name;
}
} }
} }
......
...@@ -89,6 +89,11 @@ class BlockDesc { ...@@ -89,6 +89,11 @@ class BlockDesc {
OpDesc *InsertOp(size_t index); OpDesc *InsertOp(size_t index);
/*
* Remove Op and its input/output variables.
* Note that for either input or ouput variable, if it is also an input or
* output variable of other ops, we should remain it.
*/
void RemoveOp(size_t s, size_t e); void RemoveOp(size_t s, size_t e);
std::vector<OpDesc *> AllOps() const; std::vector<OpDesc *> AllOps() const;
......
...@@ -197,13 +197,14 @@ class TestBlockDesc(unittest.TestCase): ...@@ -197,13 +197,14 @@ class TestBlockDesc(unittest.TestCase):
var2 = block.var("var2") var2 = block.var("var2")
var3 = block.var("var3") var3 = block.var("var3")
var4 = block.var("var4") var4 = block.var("var4")
var5 = block.var("var5")
op1.set_input("X", ["var1", "var2"]) op1.set_input("X", ["var1", "var2"])
op1.set_output("Y", ["var3"]) op1.set_output("Y", ["var3", "var4"])
op2.set_input("X", ["var1"]) op2.set_input("X", ["var1"])
op2.set_output("Y", ["var4"]) op2.set_output("Y", ["var4", "var5"])
# remove op1, its input var2 and output var3 will be removed at the same time, # remove op1, its input var2 and output var3 will be removed at the same time,
# but its input var1 will not be removed since var1 is also an input for op2. # but its input var1 and output var4 will not be removed since they are used for op2.
block.remove_op(0, 1) block.remove_op(0, 1)
all_ops = [] all_ops = []
...@@ -211,7 +212,7 @@ class TestBlockDesc(unittest.TestCase): ...@@ -211,7 +212,7 @@ class TestBlockDesc(unittest.TestCase):
all_ops.append(block.op(idx)) all_ops.append(block.op(idx))
self.assertEqual(all_ops, [op2]) self.assertEqual(all_ops, [op2])
all_vars = block.all_vars() all_vars = block.all_vars()
self.assertEqual(set(all_vars), {var1, var4}) self.assertEqual(set(all_vars), {var1, var4, var5})
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册