From cca71532eb6be8de79842b2bf7ece2ba7d80521b Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Thu, 31 Jan 2019 23:15:58 +0800 Subject: [PATCH] add skip send.recv test=develop --- .../framework/details/analysis_var_pass.cc | 22 +++++++++---------- .../framework/details/analysis_var_pass.h | 4 ++-- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/framework/details/analysis_var_pass.cc b/paddle/fluid/framework/details/analysis_var_pass.cc index 223b9da3cfb..c6a9d08f737 100644 --- a/paddle/fluid/framework/details/analysis_var_pass.cc +++ b/paddle/fluid/framework/details/analysis_var_pass.cc @@ -79,8 +79,7 @@ void FilterVariables(const Container& nodes, Callback callback) { std::unique_ptr AnalysisVarPass::ApplyImpl( std::unique_ptr graph) const { auto nodes = graph->Nodes(); - auto subblock_vars = GetSubBlockVars(nodes); - skip_set_.insert(subblock_vars.begin(), subblock_vars.end()); + CollectSkipSet(nodes); cfg_.reset(new details::ControlFlowGraph(*graph)); cfg_->LiveVariableAnalysis(); @@ -247,20 +246,21 @@ void AnalysisVarPass::SubGraphOptimize(OpDesc* op_desc) const { } } -std::unordered_set AnalysisVarPass::GetSubBlockVars( +void AnalysisVarPass::CollectSkipSet( const std::unordered_set& nodes) const { - std::unordered_set vars; + auto update_skip_set = [&](OpDesc* op_desc) { + auto inputs = op_desc->InputArgumentNames(); + auto outputs = op_desc->OutputArgumentNames(); + skip_set_.insert(inputs.begin(), inputs.end()); + skip_set_.insert(outputs.begin(), outputs.end()); + }; for (auto& op : nodes) { if (!op->IsOp() || op->Op() == nullptr) continue; auto* op_desc = op->Op(); - if (OpHasSubBlock(op_desc)) { - auto inputs = op_desc->InputArgumentNames(); - auto outputs = op_desc->OutputArgumentNames(); - vars.insert(inputs.begin(), inputs.end()); - vars.insert(outputs.begin(), outputs.end()); - } + if (OpHasSubBlock(op_desc)) update_skip_set(op_desc); + if (op_desc->Type() == "send") update_skip_set(op_desc); + if (op_desc->Type() == "recv") update_skip_set(op_desc); } - return vars; } void AnalysisVarPass::RenameVarInGraphDesc(const std::string& var, diff --git a/paddle/fluid/framework/details/analysis_var_pass.h b/paddle/fluid/framework/details/analysis_var_pass.h index 144204beafb..007bdd83118 100644 --- a/paddle/fluid/framework/details/analysis_var_pass.h +++ b/paddle/fluid/framework/details/analysis_var_pass.h @@ -60,8 +60,8 @@ class AnalysisVarPass : public ir::Pass { // valid a tensor can be reuse or not bool NodeCanReused(ir::Node* node) const; // scan subblock and collect the output/input variables. - std::unordered_set GetSubBlockVars( - const std::unordered_set&) const; + // scan the dist 'send', 'recv' op inputs/outputs + void CollectSkipSet(const std::unordered_set&) const; // check op has subblock or not bool OpHasSubBlock(OpDesc* desc) const; -- GitLab