From ec6519e806add60ed1838572887e018b8f478c13 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Tue, 26 Mar 2019 17:48:42 +0800 Subject: [PATCH] Fix allreducedep bug (#16443) --- .../framework/details/all_reduce_deps_pass.cc | 26 +++++++++++++++---- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/framework/details/all_reduce_deps_pass.cc b/paddle/fluid/framework/details/all_reduce_deps_pass.cc index ff223e616f7..c084410864b 100644 --- a/paddle/fluid/framework/details/all_reduce_deps_pass.cc +++ b/paddle/fluid/framework/details/all_reduce_deps_pass.cc @@ -13,6 +13,7 @@ // limitations under the License. #include +#include #include #include #include @@ -52,13 +53,28 @@ std::unique_ptr AllReduceDepsPass::ApplyImpl( // Note that must assert topology sort is stable auto& ops = graph->Get>(kStaleProgramOpDescs); for (auto* op_desc : ops) { - auto outputs = op_desc->Outputs(); - for (auto& o_it : outputs) { - for (auto& v : o_it.second) { // values - vars[v] = order; + try { + bool is_bk_op = + static_cast(boost::get(op_desc->GetAttr( + OpProtoAndCheckerMaker::OpRoleAttrName())) & + static_cast(OpRole::kBackward)); + if (!is_bk_op) continue; + + auto backward_vars = + boost::get>(op_desc->GetNullableAttr( + OpProtoAndCheckerMaker::OpRoleVarAttrName())); + PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0); + + auto outputs = op_desc->Outputs(); + for (auto& o_it : outputs) { + for (auto& v : o_it.second) { // values + vars[v] = order; + VLOG(1) << "in all_reduce_deps_pass:" << v; + } } + order++; + } catch (boost::bad_get e) { } - order++; } std::vector dist_ops; -- GitLab