未验证 提交 660e4106 编写于 作者: D dzhwinter 提交者: GitHub

Merge pull request #15855 from dzhwinter/fix/nightly_test

accelerate memory optimize process
......@@ -461,11 +461,21 @@ void ControlFlowGraph::LiveVariableAnalysis() {
}
}
}
for (auto* op : ops_) {
unlived_vars_[op] = std::set<std::string>();
for (auto& var : this->LiveIn(op)) {
if (!this->LiveOut(op).count(var)) {
unlived_vars_[op].insert(var);
}
}
}
}
void ControlFlowGraph::RenameVarInCFGGraph(const std::string& old_node,
const std::string& new_node,
int begin_idx) {
std::vector<bool> need_update(ops_.size(), false);
// update graph from begin idx to the end
for (size_t i = begin_idx; i != ops_.size(); ++i) {
auto* op = ops_[i];
......@@ -480,15 +490,27 @@ void ControlFlowGraph::RenameVarInCFGGraph(const std::string& old_node,
if (live_in_[op].find(old_node) != live_in_[op].end()) {
live_in_[op].erase(old_node);
live_in_[op].insert(new_node);
need_update[i] = true;
}
if (live_out_[op].find(old_node) != live_out_[op].end()) {
live_out_[op].erase(old_node);
live_out_[op].insert(new_node);
need_update[i] = true;
}
}
for (size_t i = begin_idx; i < ops_.size(); ++i) {
if (!need_update[i]) continue;
auto* op = ops_[i];
for (auto& var : this->LiveIn(op)) {
if (!this->LiveOut(op).count(var)) {
unlived_vars_[op].insert(var);
}
}
}
}
const std::set<std::string> ControlFlowGraph::LiveIn(ir::Node* op) const {
const std::set<std::string>& ControlFlowGraph::LiveIn(ir::Node* op) const {
auto it = live_in_.find(op);
PADDLE_ENFORCE(
it != live_in_.end(),
......@@ -496,7 +518,7 @@ const std::set<std::string> ControlFlowGraph::LiveIn(ir::Node* op) const {
return it->second;
}
const std::set<std::string> ControlFlowGraph::LiveOut(ir::Node* op) const {
const std::set<std::string>& ControlFlowGraph::LiveOut(ir::Node* op) const {
auto it = live_out_.find(op);
PADDLE_ENFORCE(
it != live_out_.end(),
......@@ -504,15 +526,24 @@ const std::set<std::string> ControlFlowGraph::LiveOut(ir::Node* op) const {
return it->second;
}
const std::set<std::string> ControlFlowGraph::Use(ir::Node* op) const {
const std::set<std::string>& ControlFlowGraph::Use(ir::Node* op) const {
auto it = uses_.find(op);
PADDLE_ENFORCE(
it != uses_.end(),
string::Sprintf("Expect %s in live_out, but Not Found.", op->Name()));
string::Sprintf("Expect %s in use, but Not Found.", op->Name()));
return it->second;
}
const std::set<std::string>& ControlFlowGraph::Unlived(ir::Node* op) const {
auto it = unlived_vars_.find(op);
PADDLE_ENFORCE(
it != unlived_vars_.end(),
string::Sprintf("Expect %s in unlived_set, but Not Found.", op->Name()));
return it->second;
return it->second;
}
const std::vector<ir::Node*> ControlFlowGraph::Ops() const { return ops_; }
const std::vector<ir::Node*>& ControlFlowGraph::Ops() const { return ops_; }
std::vector<ir::Node*>& ControlFlowGraph::Ops() { return ops_; }
......
......@@ -92,10 +92,11 @@ class ControlFlowGraph {
void RenameVarInCFGGraph(const std::string& old_node,
const std::string& new_node, int begin_idx);
const std::set<std::string> LiveIn(ir::Node* op) const;
const std::set<std::string> LiveOut(ir::Node* op) const;
const std::set<std::string> Use(ir::Node* op) const;
const std::vector<ir::Node*> Ops() const;
const std::set<std::string>& LiveIn(ir::Node* op) const;
const std::set<std::string>& LiveOut(ir::Node* op) const;
const std::set<std::string>& Use(ir::Node* op) const;
const std::set<std::string>& Unlived(ir::Node* op) const;
const std::vector<ir::Node*>& Ops() const;
std::vector<ir::Node*>& Ops();
// for ssa-graph nodes
......@@ -117,6 +118,7 @@ class ControlFlowGraph {
VarSetMap live_out_;
VarSetMap uses_; // op inputs
VarSetMap defs_; // op outputs
std::unordered_map<ir::Node*, std::set<std::string>> unlived_vars_;
std::vector<ir::Node*> ops_; // op sequence by topology sort
};
......
......@@ -118,13 +118,11 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
}
}
// fill the pool
for (auto var : cfg_->LiveIn(op)) {
if (cfg_->LiveOut(op).count(var) == 0) {
ir::Node* var_node = cfg_->GetNodeByName(var, op);
if (var_node == nullptr || var_node->IsCtrlVar()) continue;
if (NodeCanReused(var_node) && !pool_.Has(var_node)) {
pool_.Insert(var_node);
}
for (auto& var : cfg_->Unlived(op)) {
ir::Node* var_node = cfg_->GetNodeByName(var, op);
if (var_node == nullptr || var_node->IsCtrlVar()) continue;
if (NodeCanReused(var_node) && !pool_.Has(var_node)) {
pool_.Insert(var_node);
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册