提交 a71f2fbe 编写于 作者: D dzhwinter

fix default value. test=develop

上级 c6bd434f
...@@ -461,11 +461,21 @@ void ControlFlowGraph::LiveVariableAnalysis() { ...@@ -461,11 +461,21 @@ void ControlFlowGraph::LiveVariableAnalysis() {
} }
} }
} }
for (auto* op : ops_) {
unlived_vars_[op] = std::set<std::string>();
for (auto& var : this->LiveIn(op)) {
if (!this->LiveOut(op).count(var)) {
unlived_vars_[op].insert(var);
}
}
}
} }
void ControlFlowGraph::RenameVarInCFGGraph(const std::string& old_node, void ControlFlowGraph::RenameVarInCFGGraph(const std::string& old_node,
const std::string& new_node, const std::string& new_node,
int begin_idx) { int begin_idx) {
std::vector<bool> need_update(ops_.size(), false);
// update graph from begin idx to the end // update graph from begin idx to the end
for (size_t i = begin_idx; i != ops_.size(); ++i) { for (size_t i = begin_idx; i != ops_.size(); ++i) {
auto* op = ops_[i]; auto* op = ops_[i];
...@@ -480,15 +490,27 @@ void ControlFlowGraph::RenameVarInCFGGraph(const std::string& old_node, ...@@ -480,15 +490,27 @@ void ControlFlowGraph::RenameVarInCFGGraph(const std::string& old_node,
if (live_in_[op].find(old_node) != live_in_[op].end()) { if (live_in_[op].find(old_node) != live_in_[op].end()) {
live_in_[op].erase(old_node); live_in_[op].erase(old_node);
live_in_[op].insert(new_node); live_in_[op].insert(new_node);
need_update[i] = true;
} }
if (live_out_[op].find(old_node) != live_out_[op].end()) { if (live_out_[op].find(old_node) != live_out_[op].end()) {
live_out_[op].erase(old_node); live_out_[op].erase(old_node);
live_out_[op].insert(new_node); live_out_[op].insert(new_node);
need_update[i] = true;
}
}
for (size_t i = begin_idx; i < ops_.size(); ++i) {
if (!need_update[i]) continue;
auto* op = ops_[i];
for (auto& var : this->LiveIn(op)) {
if (!this->LiveOut(op).count(var)) {
unlived_vars_[op].insert(var);
}
} }
} }
} }
const std::set<std::string> ControlFlowGraph::LiveIn(ir::Node* op) const { const std::set<std::string>& ControlFlowGraph::LiveIn(ir::Node* op) const {
auto it = live_in_.find(op); auto it = live_in_.find(op);
PADDLE_ENFORCE( PADDLE_ENFORCE(
it != live_in_.end(), it != live_in_.end(),
...@@ -496,7 +518,7 @@ const std::set<std::string> ControlFlowGraph::LiveIn(ir::Node* op) const { ...@@ -496,7 +518,7 @@ const std::set<std::string> ControlFlowGraph::LiveIn(ir::Node* op) const {
return it->second; return it->second;
} }
const std::set<std::string> ControlFlowGraph::LiveOut(ir::Node* op) const { const std::set<std::string>& ControlFlowGraph::LiveOut(ir::Node* op) const {
auto it = live_out_.find(op); auto it = live_out_.find(op);
PADDLE_ENFORCE( PADDLE_ENFORCE(
it != live_out_.end(), it != live_out_.end(),
...@@ -504,15 +526,24 @@ const std::set<std::string> ControlFlowGraph::LiveOut(ir::Node* op) const { ...@@ -504,15 +526,24 @@ const std::set<std::string> ControlFlowGraph::LiveOut(ir::Node* op) const {
return it->second; return it->second;
} }
const std::set<std::string> ControlFlowGraph::Use(ir::Node* op) const { const std::set<std::string>& ControlFlowGraph::Use(ir::Node* op) const {
auto it = uses_.find(op); auto it = uses_.find(op);
PADDLE_ENFORCE( PADDLE_ENFORCE(
it != uses_.end(), it != uses_.end(),
string::Sprintf("Expect %s in live_out, but Not Found.", op->Name())); string::Sprintf("Expect %s in use, but Not Found.", op->Name()));
return it->second;
}
const std::set<std::string>& ControlFlowGraph::Unlived(ir::Node* op) const {
auto it = unlived_vars_.find(op);
PADDLE_ENFORCE(
it != unlived_vars_.end(),
string::Sprintf("Expect %s in unlived_set, but Not Found.", op->Name()));
return it->second;
return it->second; return it->second;
} }
const std::vector<ir::Node*> ControlFlowGraph::Ops() const { return ops_; } const std::vector<ir::Node*>& ControlFlowGraph::Ops() const { return ops_; }
std::vector<ir::Node*>& ControlFlowGraph::Ops() { return ops_; } std::vector<ir::Node*>& ControlFlowGraph::Ops() { return ops_; }
......
...@@ -92,10 +92,11 @@ class ControlFlowGraph { ...@@ -92,10 +92,11 @@ class ControlFlowGraph {
void RenameVarInCFGGraph(const std::string& old_node, void RenameVarInCFGGraph(const std::string& old_node,
const std::string& new_node, int begin_idx); const std::string& new_node, int begin_idx);
const std::set<std::string> LiveIn(ir::Node* op) const; const std::set<std::string>& LiveIn(ir::Node* op) const;
const std::set<std::string> LiveOut(ir::Node* op) const; const std::set<std::string>& LiveOut(ir::Node* op) const;
const std::set<std::string> Use(ir::Node* op) const; const std::set<std::string>& Use(ir::Node* op) const;
const std::vector<ir::Node*> Ops() const; const std::set<std::string>& Unlived(ir::Node* op) const;
const std::vector<ir::Node*>& Ops() const;
std::vector<ir::Node*>& Ops(); std::vector<ir::Node*>& Ops();
// for ssa-graph nodes // for ssa-graph nodes
...@@ -117,6 +118,7 @@ class ControlFlowGraph { ...@@ -117,6 +118,7 @@ class ControlFlowGraph {
VarSetMap live_out_; VarSetMap live_out_;
VarSetMap uses_; // op inputs VarSetMap uses_; // op inputs
VarSetMap defs_; // op outputs VarSetMap defs_; // op outputs
std::unordered_map<ir::Node*, std::set<std::string>> unlived_vars_;
std::vector<ir::Node*> ops_; // op sequence by topology sort std::vector<ir::Node*> ops_; // op sequence by topology sort
}; };
......
...@@ -118,8 +118,7 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl( ...@@ -118,8 +118,7 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
} }
} }
// fill the pool // fill the pool
for (auto var : cfg_->LiveIn(op)) { for (auto& var : cfg_->Unlived(op)) {
if (cfg_->LiveOut(op).count(var) == 0) {
ir::Node* var_node = cfg_->GetNodeByName(var, op); ir::Node* var_node = cfg_->GetNodeByName(var, op);
if (var_node == nullptr || var_node->IsCtrlVar()) continue; if (var_node == nullptr || var_node->IsCtrlVar()) continue;
if (NodeCanReused(var_node) && !pool_.Has(var_node)) { if (NodeCanReused(var_node) && !pool_.Has(var_node)) {
...@@ -127,7 +126,6 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl( ...@@ -127,7 +126,6 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
} }
} }
} }
}
graph->ResolveHazard(var_nodes_); graph->ResolveHazard(var_nodes_);
return graph; return graph;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册