diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index 17b9c3920140557adc90c32c7e1dd635ac48eb6f..14b73b368117b4816e1aeee8bb5c73f64257c91e 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -84,8 +84,12 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
 std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
     const ProgramDesc &program) const {
   std::vector<std::string> send_vars;
+  // since parameters are all in block 0,
+  // it's enough to scan only the send ops in block 0
   for (auto *op : program.Block(0).AllOps()) {
-    if (op->Type() == "send_vars" || op->Type() == "send") {
+    // TODO(Yancey1989): use a graceful method to find send op,
+    // instead of the hard-coded string
+    if (op->Type() == "send_vars") {
       auto op_vars = op->InputArgumentNames();
       send_vars.reserve(send_vars.size() +
                         std::distance(op_vars.begin(), op_vars.end()));
@@ -99,7 +103,9 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
     const ProgramDesc &program) const {
   std::vector<std::string> recv_vars;
   for (auto *op : program.Block(0).AllOps()) {
-    if (op->Type() == "recv" || op->Type() == "send") {
+    // TODO(Yancey1989): use a graceful method to find recv op,
+    // instead of the hard-coded string
+    if (op->Type() == "recv") {
       auto op_vars = op->OutputArgumentNames();
       recv_vars.reserve(recv_vars.size() +
                         std::distance(op_vars.begin(), op_vars.end()));
@@ -122,6 +128,9 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
   auto checker = [](const std::vector<std::string> &opvars,
                     const std::vector<std::string> &rpc_vars) -> bool {
     for (auto &var : opvars) {
+      // a variable name with the suffix `.block` means it's a variable
+      // split by the DistributeTranspiler
+      // [python/paddle/fluid/transpiler/distribute_transpiler.py]
       if (var.find(".block") != std::string::npos &&
           std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
         return true;
@@ -130,13 +139,8 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
     return false;
   };
 
-  if (op.Type() == "split" || op.Type() == "split_byref" ||
-      op.Type() == "split_selected_rows") {
-    return checker(op.OutputArgumentNames(), send_vars);
-  } else if (op.Type() == "concat") {
-    return checker(op.InputArgumentNames(), recv_vars);
-  }
-
+  return checker(op.OutputArgumentNames(), send_vars) ||
+         checker(op.InputArgumentNames(), recv_vars);
   return false;
 }
 
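Reviewer note: the new `IsDistTrainOp` no longer dispatches on hard-coded op types (`split`, `split_byref`, `split_selected_rows`, `concat`); it only asks whether any output is a send var or any input is a recv var, relying on the `.block` suffix that the DistributeTranspiler stamps on parameter slices. Below is a minimal self-contained sketch of that logic; the free function and the variable names in `main` are illustrative stand-ins, not the real `OpDesc`-based API.

```cpp
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Illustrative stand-in for MultiDevSSAGraphBuilder::IsDistTrainOp: an op is
// part of the distributed-training graph if any of its outputs is a send var
// or any of its inputs is a recv var, regardless of the op type.
bool IsDistTrainOp(const std::vector<std::string>& inputs,
                   const std::vector<std::string>& outputs,
                   const std::vector<std::string>& send_vars,
                   const std::vector<std::string>& recv_vars) {
  auto checker = [](const std::vector<std::string>& opvars,
                    const std::vector<std::string>& rpc_vars) -> bool {
    for (const auto& var : opvars) {
      // Names like "w.block0" are slices produced by the DistributeTranspiler;
      // only those, when also registered as RPC vars, count as dist-train I/O.
      if (var.find(".block") != std::string::npos &&
          std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
        return true;
      }
    }
    return false;
  };
  return checker(outputs, send_vars) || checker(inputs, recv_vars);
}

int main() {
  std::vector<std::string> send_vars = {"w.block0", "w.block1"};
  std::vector<std::string> recv_vars = {"w.block0.recv"};
  // A split-like op producing parameter slices destined for send ops.
  std::cout << IsDistTrainOp({"w"}, {"w.block0", "w.block1"}, send_vars,
                             recv_vars)
            << "\n";  // prints 1
  // An unrelated op that touches no RPC variables.
  std::cout << IsDistTrainOp({"x"}, {"y"}, send_vars, recv_vars)
            << "\n";  // prints 0
}
```

One small observation on the patch itself: since the new `return checker(...) || checker(...);` is unconditional, the trailing `return false;` left as context in the real function is now unreachable and could be dropped in a follow-up.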
diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc
index f2385abed59a809e7fbea3569245af693f5842ad..51f0d2a7427a3923f038b3d85057fa0e5c8cf6a8 100644
--- a/paddle/fluid/operators/detail/grpc_client.cc
+++ b/paddle/fluid/operators/detail/grpc_client.cc
@@ -34,7 +34,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
   const std::string ep_val = ep;
   const std::string var_name_val = var_name;
   const framework::Scope* p_scope = &scope;
-  const auto ch = GetChannel(ep_val, ep_val + ":" + var_name_val);
+  const auto ch = GetChannel(ep_val);
 
   framework::AsyncIO([var_name_val, p_ctx, ep_val, p_scope, time_out, ch,
                       this] {
@@ -88,7 +88,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
   const std::string ep_val = ep;
   const std::string var_name_val = var_name;
   const framework::Scope* p_scope = &scope;
-  const auto ch = GetChannel(ep_val, ep_val + ":" + var_name_val);
+  const auto ch = GetChannel(ep_val);
 
   framework::AsyncIO([var_name_val, ep_val, p_scope, p_ctx, time_out, ch,
                       this] {
@@ -132,7 +132,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
   const std::string in_var_name_val = in_var_name;
   const std::string out_var_name_val = out_var_name;
   const framework::Scope* p_scope = &scope;
-  const auto ch = GetChannel(ep_val, ep_val + ":" + in_var_name_val);
+  const auto ch = GetChannel(ep_val);
 
   framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope,
                       p_ctx, time_out, ch, this] {
@@ -165,7 +165,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
 }
 
 void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
-  const auto ch = GetChannel(ep, ep);
+  const auto ch = GetChannel(ep);
 
   BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
   s->Prepare(time_out);
@@ -178,7 +178,7 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
 }
 
 void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
-  const auto ch = GetChannel(ep, ep);
+  const auto ch = GetChannel(ep);
 
   FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
   s->Prepare(time_out);
@@ -248,10 +248,9 @@ bool RPCClient::Proceed() {
   delete c;
   return true;
 }
-std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep,
-                                                     const std::string& key) {
+std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
   std::unique_lock<std::mutex> lock(mutex_);
-  auto it = channels_.find(key);
+  auto it = channels_.find(ep);
   if (it != channels_.end()) {
     return it->second;
   }
@@ -263,7 +262,7 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep,
   auto ch =
       grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
 
-  channels_[key] = ch;
+  channels_[ep] = ch;
   return ch;
 }
 
diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h
index 6f8b67be3e0f4bb506840d67b6717e3a2f861ac7..e5007b509a30a5251e78e8636d53d81022dae0d3 100644
--- a/paddle/fluid/operators/detail/grpc_client.h
+++ b/paddle/fluid/operators/detail/grpc_client.h
@@ -191,8 +191,7 @@ class RPCClient {
  private:
   bool Proceed();
 
-  std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep,
-                                            const std::string& key);
+  std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
 
  private:
   grpc::CompletionQueue cq_;
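Reviewer note: `GetChannel` previously cached channels under an `(endpoint + ":" + variable)` key, so every variable sent to or fetched from a server opened its own channel; since a gRPC channel already multiplexes many concurrent RPCs over one HTTP/2 connection, that only multiplied TCP connections. The patch keys the cache by endpoint alone. Below is a rough standalone sketch of the new caching behavior; `ChannelCache` is a hypothetical stand-in for the cache that lives inside `RPCClient`, and the `ChannelArguments` setup is an assumption mirroring the `args` visible in the hunk context.

```cpp
#include <grpc++/grpc++.h>

#include <limits>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for RPCClient's channel cache: one gRPC channel per
// endpoint, shared by every variable exchanged with that server.
class ChannelCache {
 public:
  std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep) {
    std::unique_lock<std::mutex> lock(mutex_);
    auto it = channels_.find(ep);  // keyed by endpoint only
    if (it != channels_.end()) {
      return it->second;
    }
    // Allow arbitrarily large tensors on the wire (assumed configuration).
    grpc::ChannelArguments args;
    args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
    args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
    auto ch =
        grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
    channels_[ep] = ch;
    return ch;
  }

 private:
  std::mutex mutex_;
  std::unordered_map<std::string, std::shared_ptr<grpc::Channel>> channels_;
};
```

With this scheme, two calls such as `GetChannel("127.0.0.1:6164")` for different variables return the same shared channel, where the old key scheme would have created a separate channel (and connection) per variable name.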