From c4ded17e8cbcbf33e68145c1a4ffe777582bf3ab Mon Sep 17 00:00:00 2001
From: Qiao Longfei
Date: Mon, 11 Feb 2019 09:19:48 +0800
Subject: [PATCH] support dist train in async mode

---
 paddle/fluid/framework/details/build_strategy.cc           | 6 +++---
 paddle/fluid/framework/details/multi_devices_graph_pass.cc | 7 ++++++-
 paddle/fluid/framework/parallel_executor.cc                | 2 +-
 3 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc
index a286cb30a..e91739525 100644
--- a/paddle/fluid/framework/details/build_strategy.cc
+++ b/paddle/fluid/framework/details/build_strategy.cc
@@ -133,10 +133,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
   void AppendMultiDevPass(const BuildStrategy &strategy) {
     ir::Pass *multi_devices_pass;
-    if (strategy_.async_mode_) {
-      multi_devices_pass = AppendPass("async_multi_devices_pass").get();
-    } else if (strategy_.is_distribution_) {
+    if (strategy_.is_distribution_) {
       multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
+    } else if (strategy_.async_mode_) {
+      multi_devices_pass = AppendPass("async_multi_devices_pass").get();
     } else {
       if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
         multi_devices_pass =
diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
index a2bbfc91b..572d374b5 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
@@ -756,6 +756,11 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
     insert_op = true;
     need_broadcast_var_ = true;
   } else if (OpHaveRole(*node, OpRole::kDist)) {
+    // In async_mode each graph sends its own gradients, so there is no
+    // need to merge them.
+    if (strategy_.async_mode_ && node->Op()->Type() != "concat") {
+      return false;
+    }
     int op_dev_id = CreateDistTrainOp(result, node);
     if (node->Op()->Type() == "concat") {
       // the input(block of parameter) of concat is on different device,
@@ -827,7 +832,7 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const {
   }
   auto recv_param_grad = boost::get<std::vector<std::string>>(
       node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
-  if (recv_param_grad.size() == 2U) {
+  if (recv_param_grad.size() == 2U && !strategy_.async_mode_) {
     op_dev_id = GetVarDeviceID(recv_param_grad[1]);
     VLOG(10) << "recv param " << recv_param_grad[0]
              << " get grad place: " << recv_param_grad[1]
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index f0bc3accc..c85fe4f20 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -283,7 +283,7 @@ ParallelExecutor::ParallelExecutor(
     graphs.push_back(std::move(graph));
   }
 #else
-  if (build_strategy.async_mode_) {
+  if (build_strategy.async_mode_ && !build_strategy.is_distribution_) {
     for (size_t i = 0; i < member_->places_.size(); ++i) {
       std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
           main_program, {member_->places_[i]}, loss_var_name,
--
GitLab
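
For readers skimming the diff, the key control-flow change is the reordered branch in AppendMultiDevPass: is_distribution_ is now tested before async_mode_, so an async *distributed* build selects dist_multi_devices_pass rather than async_multi_devices_pass. The minimal standalone C++ sketch below mirrors that decision order; BuildFlags and ChooseMultiDevPass are hypothetical names invented for illustration, and only the two pass names visible in the hunk are taken from the source.

// Illustrative sketch only -- not part of the patch.
// BuildFlags and ChooseMultiDevPass are hypothetical names.
#include <iostream>
#include <string>

struct BuildFlags {
  bool is_distribution = false;
  bool async_mode = false;
};

std::string ChooseMultiDevPass(const BuildFlags &f) {
  if (f.is_distribution) return "dist_multi_devices_pass";  // now checked first
  if (f.async_mode) return "async_multi_devices_pass";      // local async only
  return "reduce-strategy pass";  // remaining branch, elided in the diff
}

int main() {
  BuildFlags async_dist;
  async_dist.is_distribution = true;
  async_dist.async_mode = true;
  // Before the patch this combination picked async_multi_devices_pass;
  // after it, the distributed pass wins.
  std::cout << ChooseMultiDevPass(async_dist) << "\n";
}

The other two hunks follow from the same precedence: in async mode each per-place graph sends its own gradients (so kDist gradient-merge ops other than concat are skipped), and a recv op's parameter is no longer pinned to its gradient's device.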