diff --git a/paddle/fluid/framework/block_desc.cc b/paddle/fluid/framework/block_desc.cc index 4faf9dcf375e28fbaf64b099a77fb761201ed9a3..fbe08349c37c4fde115ceea954ba2b84880088d7 100644 --- a/paddle/fluid/framework/block_desc.cc +++ b/paddle/fluid/framework/block_desc.cc @@ -147,40 +147,51 @@ void BlockDesc::RemoveOp(size_t s, size_t e) { if (ops_.begin() + s == ops_.end() || ops_.begin() + e == ops_.end()) { return; } + auto get_vars = [](std::deque>::iterator &op, + std::vector &v) { + auto in_names = (*op)->InputArgumentNames(); + v.insert(v.end(), in_names.begin(), in_names.end()); + auto out_names = (*op)->OutputArgumentNames(); + v.insert(v.end(), out_names.begin(), out_names.end()); + std::sort(v.begin(), v.end()); + auto last = std::unique(v.begin(), v.end()); + v.erase(last, v.end()); + }; need_update_ = true; - std::vector vars1; // input vars from delete ops - for (auto it = ops_.begin() + s; it != ops_.begin() + e; it++) { - // delete all output vars - auto out_names = (*it)->OutputArgumentNames(); - for (auto n : out_names) { - vars_.erase(vars_.find(n)); + + for (size_t i = s; i < e; i++) { + // since remove op one by one, every time remove the first op. + auto op = ops_.begin() + s; + + // collect input and output variables from current delete op + std::vector cur_vars; + get_vars(op, cur_vars); + + // remove current op + ops_.erase(ops_.begin() + s); + + // collect input and output variables from other ops + std::vector other_vars; + for (auto it = ops_.begin(); it != ops_.end(); it++) { + get_vars(it, other_vars); } - // collect all input vars from remove ops - auto in_names = (*it)->InputArgumentNames(); - vars1.insert(vars1.end(), in_names.begin(), in_names.end()); - } - ops_.erase(ops_.begin() + s, ops_.begin() + e); - - // collect input and output vars from remain ops - std::vector vars2; - for (auto it = ops_.begin(); it != ops_.end(); it++) { - auto in_names = (*it)->InputArgumentNames(); - auto out_names = (*it)->OutputArgumentNames(); - vars2.insert(vars2.end(), in_names.begin(), in_names.end()); - vars2.insert(vars2.end(), out_names.begin(), out_names.end()); - } - // delete input vars if no other op use it. - std::vector del_vars; - std::sort(vars1.begin(), vars1.end()); - std::unique(vars1.begin(), vars1.end()); - std::sort(vars2.begin(), vars2.end()); - std::unique(vars2.begin(), vars2.end()); - // del_vars = vars1 - vars1 ^ vars2 - std::set_difference(vars1.begin(), vars1.end(), vars2.begin(), vars2.end(), - std::inserter(del_vars, del_vars.end())); - for (auto it = del_vars.begin(); it != del_vars.end(); it++) { - vars_.erase(vars_.find(*it)); + // variables should be deleted + std::vector delete_vars; + // delete_vars = cur_vars - cur_vars ^ other_input_vars + std::set_difference(cur_vars.begin(), cur_vars.end(), other_vars.begin(), + other_vars.end(), + std::inserter(delete_vars, delete_vars.end())); + // remove variables + for (size_t i = 0; i < delete_vars.size(); i++) { + auto name = delete_vars[i]; + auto it = vars_.find(name); + PADDLE_ENFORCE(it != vars_.end(), + "%s is not in variable list, it should not be deleted", + name); + vars_.erase(it); + VLOG(3) << "deleting variable " << name; + } } } diff --git a/paddle/fluid/framework/block_desc.h b/paddle/fluid/framework/block_desc.h index 185f018ac1b5863e0ee86fdaa17df1ccbc6e030e..468423e0e8e7b8c9ebc14b7568c9c3bd21645ea7 100644 --- a/paddle/fluid/framework/block_desc.h +++ b/paddle/fluid/framework/block_desc.h @@ -89,6 +89,11 @@ class BlockDesc { OpDesc *InsertOp(size_t index); + /* + * Remove Op and its input/output variables. + * Note that for either input or ouput variable, if it is also an input or + * output variable of other ops, we should remain it. + */ void RemoveOp(size_t s, size_t e); std::vector AllOps() const; diff --git a/python/paddle/fluid/tests/unittests/test_protobuf_descs.py b/python/paddle/fluid/tests/unittests/test_protobuf_descs.py index 871cb76fffcb738f1f44f0be4fef1ff166f9adf7..da85786d0c085a4e97d9ac272feed251296ad52d 100644 --- a/python/paddle/fluid/tests/unittests/test_protobuf_descs.py +++ b/python/paddle/fluid/tests/unittests/test_protobuf_descs.py @@ -197,13 +197,14 @@ class TestBlockDesc(unittest.TestCase): var2 = block.var("var2") var3 = block.var("var3") var4 = block.var("var4") + var5 = block.var("var5") op1.set_input("X", ["var1", "var2"]) - op1.set_output("Y", ["var3"]) + op1.set_output("Y", ["var3", "var4"]) op2.set_input("X", ["var1"]) - op2.set_output("Y", ["var4"]) + op2.set_output("Y", ["var4", "var5"]) # remove op1, its input var2 and output var3 will be removed at the same time, - # but its input var1 will not be removed since var1 is also an input for op2. + # but its input var1 and output var4 will not be removed since they are used for op2. block.remove_op(0, 1) all_ops = [] @@ -211,7 +212,7 @@ class TestBlockDesc(unittest.TestCase): all_ops.append(block.op(idx)) self.assertEqual(all_ops, [op2]) all_vars = block.all_vars() - self.assertEqual(set(all_vars), {var1, var4}) + self.assertEqual(set(all_vars), {var1, var4, var5}) if __name__ == '__main__':