提交 cca71532 编写于 作者: D dzhwinter

add skip send.recv test=develop

上级 546eefae
...@@ -79,8 +79,7 @@ void FilterVariables(const Container& nodes, Callback callback) { ...@@ -79,8 +79,7 @@ void FilterVariables(const Container& nodes, Callback callback) {
std::unique_ptr<ir::Graph> AnalysisVarPass::ApplyImpl( std::unique_ptr<ir::Graph> AnalysisVarPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
auto nodes = graph->Nodes(); auto nodes = graph->Nodes();
auto subblock_vars = GetSubBlockVars(nodes); CollectSkipSet(nodes);
skip_set_.insert(subblock_vars.begin(), subblock_vars.end());
cfg_.reset(new details::ControlFlowGraph(*graph)); cfg_.reset(new details::ControlFlowGraph(*graph));
cfg_->LiveVariableAnalysis(); cfg_->LiveVariableAnalysis();
...@@ -247,20 +246,21 @@ void AnalysisVarPass::SubGraphOptimize(OpDesc* op_desc) const { ...@@ -247,20 +246,21 @@ void AnalysisVarPass::SubGraphOptimize(OpDesc* op_desc) const {
} }
} }
std::unordered_set<std::string> AnalysisVarPass::GetSubBlockVars( void AnalysisVarPass::CollectSkipSet(
const std::unordered_set<ir::Node*>& nodes) const { const std::unordered_set<ir::Node*>& nodes) const {
std::unordered_set<std::string> vars; auto update_skip_set = [&](OpDesc* op_desc) {
auto inputs = op_desc->InputArgumentNames();
auto outputs = op_desc->OutputArgumentNames();
skip_set_.insert(inputs.begin(), inputs.end());
skip_set_.insert(outputs.begin(), outputs.end());
};
for (auto& op : nodes) { for (auto& op : nodes) {
if (!op->IsOp() || op->Op() == nullptr) continue; if (!op->IsOp() || op->Op() == nullptr) continue;
auto* op_desc = op->Op(); auto* op_desc = op->Op();
if (OpHasSubBlock(op_desc)) { if (OpHasSubBlock(op_desc)) update_skip_set(op_desc);
auto inputs = op_desc->InputArgumentNames(); if (op_desc->Type() == "send") update_skip_set(op_desc);
auto outputs = op_desc->OutputArgumentNames(); if (op_desc->Type() == "recv") update_skip_set(op_desc);
vars.insert(inputs.begin(), inputs.end());
vars.insert(outputs.begin(), outputs.end());
}
} }
return vars;
} }
void AnalysisVarPass::RenameVarInGraphDesc(const std::string& var, void AnalysisVarPass::RenameVarInGraphDesc(const std::string& var,
......
...@@ -60,8 +60,8 @@ class AnalysisVarPass : public ir::Pass { ...@@ -60,8 +60,8 @@ class AnalysisVarPass : public ir::Pass {
// valid a tensor can be reuse or not // valid a tensor can be reuse or not
bool NodeCanReused(ir::Node* node) const; bool NodeCanReused(ir::Node* node) const;
// scan subblock and collect the output/input variables. // scan subblock and collect the output/input variables.
std::unordered_set<std::string> GetSubBlockVars( // scan the dist 'send', 'recv' op inputs/outputs
const std::unordered_set<ir::Node*>&) const; void CollectSkipSet(const std::unordered_set<ir::Node*>&) const;
// check op has subblock or not // check op has subblock or not
bool OpHasSubBlock(OpDesc* desc) const; bool OpHasSubBlock(OpDesc* desc) const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册