diff --git a/paddle/fluid/framework/details/analysis_var_pass.cc b/paddle/fluid/framework/details/analysis_var_pass.cc index 223b9da3cfba33fc32d1334cddccb9f503bd0bef..c6a9d08f7378bb1df4452bd250bbaf3d4a961c66 100644 --- a/paddle/fluid/framework/details/analysis_var_pass.cc +++ b/paddle/fluid/framework/details/analysis_var_pass.cc @@ -79,8 +79,7 @@ void FilterVariables(const Container& nodes, Callback callback) { std::unique_ptr AnalysisVarPass::ApplyImpl( std::unique_ptr graph) const { auto nodes = graph->Nodes(); - auto subblock_vars = GetSubBlockVars(nodes); - skip_set_.insert(subblock_vars.begin(), subblock_vars.end()); + CollectSkipSet(nodes); cfg_.reset(new details::ControlFlowGraph(*graph)); cfg_->LiveVariableAnalysis(); @@ -247,20 +246,21 @@ void AnalysisVarPass::SubGraphOptimize(OpDesc* op_desc) const { } } -std::unordered_set AnalysisVarPass::GetSubBlockVars( +void AnalysisVarPass::CollectSkipSet( const std::unordered_set& nodes) const { - std::unordered_set vars; + auto update_skip_set = [&](OpDesc* op_desc) { + auto inputs = op_desc->InputArgumentNames(); + auto outputs = op_desc->OutputArgumentNames(); + skip_set_.insert(inputs.begin(), inputs.end()); + skip_set_.insert(outputs.begin(), outputs.end()); + }; for (auto& op : nodes) { if (!op->IsOp() || op->Op() == nullptr) continue; auto* op_desc = op->Op(); - if (OpHasSubBlock(op_desc)) { - auto inputs = op_desc->InputArgumentNames(); - auto outputs = op_desc->OutputArgumentNames(); - vars.insert(inputs.begin(), inputs.end()); - vars.insert(outputs.begin(), outputs.end()); - } + if (OpHasSubBlock(op_desc)) update_skip_set(op_desc); + if (op_desc->Type() == "send") update_skip_set(op_desc); + if (op_desc->Type() == "recv") update_skip_set(op_desc); } - return vars; } void AnalysisVarPass::RenameVarInGraphDesc(const std::string& var, diff --git a/paddle/fluid/framework/details/analysis_var_pass.h b/paddle/fluid/framework/details/analysis_var_pass.h index 144204beafb341351172c29e3b4cd41db49be6f9..007bdd831182e73049f9d2f44c7cb44520c9862e 100644 --- a/paddle/fluid/framework/details/analysis_var_pass.h +++ b/paddle/fluid/framework/details/analysis_var_pass.h @@ -60,8 +60,8 @@ class AnalysisVarPass : public ir::Pass { // valid a tensor can be reuse or not bool NodeCanReused(ir::Node* node) const; // scan subblock and collect the output/input variables. - std::unordered_set GetSubBlockVars( - const std::unordered_set&) const; + // scan the dist 'send', 'recv' op inputs/outputs + void CollectSkipSet(const std::unordered_set&) const; // check op has subblock or not bool OpHasSubBlock(OpDesc* desc) const;