diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc
index a286cb30a239a430b2f9e035693a04f93b5d3b17..e917395259cfba6d205101087fd44c6c46b24ee7 100644
--- a/paddle/fluid/framework/details/build_strategy.cc
+++ b/paddle/fluid/framework/details/build_strategy.cc
@@ -133,10 +133,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
   void AppendMultiDevPass(const BuildStrategy &strategy) {
     ir::Pass *multi_devices_pass;
-    if (strategy_.async_mode_) {
-      multi_devices_pass = AppendPass("async_multi_devices_pass").get();
-    } else if (strategy_.is_distribution_) {
+    if (strategy_.is_distribution_) {
       multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
+    } else if (strategy_.async_mode_) {
+      multi_devices_pass = AppendPass("async_multi_devices_pass").get();
     } else {
       if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
         multi_devices_pass =
diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
index a2bbfc91b7374523fd6f27d3fc65c9ef090600c1..572d374b501a767f22b12d6cc3e7ea3b28004a18 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
@@ -756,6 +756,11 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
     insert_op = true;
     need_broadcast_var_ = true;
   } else if (OpHaveRole(*node, OpRole::kDist)) {
+    // In async_mode, each graph will send its own gradient, so there is no
+    // need to merge gradients.
+    if (strategy_.async_mode_ && node->Op()->Type() != "concat") {
+      return false;
+    }
     int op_dev_id = CreateDistTrainOp(result, node);
     if (node->Op()->Type() == "concat") {
       // the input(block of parameter) of concat is on different device,
@@ -827,7 +832,7 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const {
   }
   auto recv_param_grad = boost::get<std::vector<std::string>>(
       node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
-  if (recv_param_grad.size() == 2U) {
+  if (recv_param_grad.size() == 2U && !strategy_.async_mode_) {
     op_dev_id = GetVarDeviceID(recv_param_grad[1]);
     VLOG(10) << "recv param " << recv_param_grad[0]
              << " get grad place: " << recv_param_grad[1]
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index f0bc3acccc22e6b74ac87ff98357cfe157c675c8..c85fe4f2006817b81f3c7e5d1a1371fc90a90ab2 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -283,7 +283,7 @@ ParallelExecutor::ParallelExecutor(
       graphs.push_back(std::move(graph));
     }
 #else
-  if (build_strategy.async_mode_) {
+  if (build_strategy.async_mode_ && !build_strategy.is_distribution_) {
     for (size_t i = 0; i < member_->places_.size(); ++i) {
       std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
           main_program, {member_->places_[i]}, loss_var_name,