From 62af10d440473867c42cd9a8e2e5a3b8d854d500 Mon Sep 17 00:00:00 2001
From: Yancey1989
Date: Mon, 21 May 2018 18:52:32 +0800
Subject: [PATCH] support multiple devices

---
 paddle/fluid/framework/details/CMakeLists.txt |  3 +-
 .../details/multi_devices_graph_builder.cc    | 60 +++++++++++++++----
 .../details/multi_devices_graph_builder.h     |  5 ++
 .../fluid/framework/details/rpc_op_handle.cc  | 50 ++++++++++++++++
 .../fluid/framework/details/rpc_op_handle.h   | 52 ++++++++++++++++
 paddle/fluid/framework/variable.h             |  3 +
 paddle/fluid/operators/detail/grpc_client.cc  | 26 ++++----
 paddle/fluid/operators/detail/grpc_client.h   |  5 +-
 paddle/fluid/operators/fetch_barrier_op.cc    |  6 ++
 paddle/fluid/operators/recv_op.cc             | 10 +++-
 paddle/fluid/operators/send_barrier_op.cc     |  5 ++
 paddle/fluid/operators/send_recv_util.h       |  3 +
 paddle/fluid/operators/send_vars_op.cc        |  6 ++
 13 files changed, 208 insertions(+), 26 deletions(-)
 create mode 100644 paddle/fluid/framework/details/rpc_op_handle.cc
 create mode 100644 paddle/fluid/framework/details/rpc_op_handle.h

diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt
index 9de44beafbb..2c838f43614 100644
--- a/paddle/fluid/framework/details/CMakeLists.txt
+++ b/paddle/fluid/framework/details/CMakeLists.txt
@@ -4,6 +4,7 @@ cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_h
 cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
 cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
 cc_library(send_op_handle SRCS send_op_handle.cc DEPS framework_proto scope place operator op_registry)
+cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
 
 cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
 cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
@@ -26,7 +27,7 @@ endif()
 cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
 
 cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
-        scale_loss_grad_op_handle send_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
+        scale_loss_grad_op_handle send_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
 
 cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
 cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index 45bad58145a..50998fb8e0c 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -12,10 +12,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" +#include #include #include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/reduce_op_handle.h" +#include "paddle/fluid/framework/details/rpc_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" #include "paddle/fluid/framework/details/send_op_handle.h" #include "paddle/fluid/framework/scope.h" @@ -77,7 +79,6 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, CreateOpOutput(result, op_handle, each_var_name, p, place_id); } } - bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const { if (send_op == nullptr) { @@ -98,7 +99,7 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op, return false; }; - if (op.Type() == "split") { + if (op.Type() == "split" || op.Type() == "split_byref") { return checker(op.OutputArgumentNames(), send_op->InputArgumentNames()); } else if (op.Type() == "concat") { return checker(op.InputArgumentNames(), send_op->OutputArgumentNames()); @@ -106,6 +107,15 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op, return false; } +bool MultiDevSSAGraphBuilder::IsRPCOp(const OpDesc &op) const { + for (auto &name : op.OutputNames()) { + if (name == "RPCClient") { + return true; + } + } + return false; +} + std::unique_ptr MultiDevSSAGraphBuilder::Build( const ProgramDesc &program) const { std::unordered_map var_types; @@ -133,10 +143,10 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( bool is_forwarding = true; for (auto *op : program.Block(0).AllOps()) { - if (op->Type() == "send") { - // append send op if program is distributed trainer main program. + if (IsRPCOp(*op)) { + // append rpc op if program is distributed trainer main program. 
       // always use the first device
-      CreateSendOp(&result, *op);
+      CreateRPCOp(&result, *op);
     } else if (IsDistTrainOp(*op, send_op)) {
       CreateComputationalOps(&result, *op, 1);
     } else if (IsScaleLossOp(*op)) {
@@ -203,9 +213,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   AddOutputToLeafOps(&result);
 
   if (VLOG_IS_ON(10)) {
-    std::ostringstream sout;
-    PrintGraphviz(*graph, sout);
-    VLOG(10) << sout.str();
+    std::string filename = "/tmp/graph";
+    std::ofstream fout(filename);
+    PrintGraphviz(*graph, fout);
   }
 
   return std::unique_ptr<SSAGraph>(graph);
@@ -386,12 +396,40 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
   return var;
 }
 
-void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
-                                           const OpDesc &op) const {
+void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result,
+                                        std::string op_name) const {
+  for (auto &prev_op : result->ops_) {
+    if (prev_op->Name() == op_name) {
+      auto *dep_var = new DummyVarHandle();
+      prev_op->AddOutput(dep_var);
+      result->dep_vars_.emplace(dep_var);
+      result->ops_.back().get()->AddInput(dep_var);
+    }
+  }
+}
+
+void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
+                                          const OpDesc &op) const {
   auto &p = places_[0];
   auto *s = local_scopes_[0];
+  VLOG(3) << "create rpc op: " << op.Type();
+  result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
+  if (op.Type() == "send_barrier") {
+    ConnectOp(result, "send_vars");
+  } else if (op.Type() == "recv") {
+    ConnectOp(result, "send_barrier");
+  } else if (op.Type() == "fetch_barrier") {
+    ConnectOp(result, "recv");
+  } else if (op.Type() == "send" || op.Type() == "send_vars") {
+    // do nothing
+  } else {
+    PADDLE_THROW(
+        "rpc op should be in [send, "
+        "send_vars, send_barrier, recv, fetch_barrier]");
+  }
+
   // FIXME(wuyi): send op always copy from GPU 0
-  result->ops_.emplace_back(new SendOpHandle(op, s, p));
   // Create inputs for output on original place and no ssa output
   // is created for send op.
   CreateOpHandleIOs(result, op, 0);
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h
index 4f708521884..45713b0c4f6 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h
@@ -65,12 +65,17 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   bool IsScaleLossOp(const OpDesc &op) const;
 
   void CreateSendOp(SSAGraph *result, const OpDesc &op) const;
+  void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;
 
   /**
    * Is this operator as the end-point operator before/after send operator.
    */
   bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const;
 
+  bool IsRPCOp(const OpDesc &op) const;
+
+  void ConnectOp(SSAGraph *result, std::string op_name) const;
+
   void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
                               size_t num_places) const;
 
diff --git a/paddle/fluid/framework/details/rpc_op_handle.cc b/paddle/fluid/framework/details/rpc_op_handle.cc
new file mode 100644
index 00000000000..03f53421b1d
--- /dev/null
+++ b/paddle/fluid/framework/details/rpc_op_handle.cc
@@ -0,0 +1,50 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/details/rpc_op_handle.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+RPCOpHandle::RPCOpHandle(const framework::OpDesc &op_desc,
+                         const Scope *local_scope, const platform::Place &place,
+                         const std::string &name)
+    : op_(framework::OpRegistry::CreateOp(op_desc)),
+      local_scope_(local_scope),
+      place_(place),
+      name_(name) {}
+
+void RPCOpHandle::RunImpl() {
+  // TODO(wuyi): need further analysis on whether we also have to wait on
+  // DummyVarHandle inputs.
+  // Wait until all inputs are generated.
+  for (auto *in : inputs_) {
+    auto &p = static_cast<VarHandle *>(in)->place_;
+    if (in->DebugString() == "dummy") {  // HACK
+      continue;
+    }
+    if (in->generated_op_) {
+      in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[p]);
+    }
+  }
+  auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
+  // FIXME(wuyi): cannot use RunAndRecordEvent here, for it will cause a
+  // deadlock.
+  op_->Run(*tmp_scope, place_);
+}
+
+std::string RPCOpHandle::Name() const { return name_; }
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/details/rpc_op_handle.h b/paddle/fluid/framework/details/rpc_op_handle.h
new file mode 100644
index 00000000000..d28b7721720
--- /dev/null
+++ b/paddle/fluid/framework/details/rpc_op_handle.h
@@ -0,0 +1,52 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/details/op_handle_base.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/scope.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+struct RPCOpHandle : public OpHandleBase {
+  RPCOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
+              const platform::Place& place, const std::string& name);
+
+  std::string Name() const override;
+
+  // Delaying and buffering nccl_all_reduce together can significantly
+  // increase performance; disable that behavior for RPC ops by returning
+  // false.
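+  // An RPC op talks to remote endpoints and moves no data between local
+  // devices, so there is nothing to delay or batch here.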
+  bool IsMultiDeviceTransfer() override { return false; }
+
+ protected:
+  void RunImpl() override;
+
+ private:
+  std::unique_ptr<OperatorBase> op_;
+  const Scope* local_scope_;
+  const platform::Place& place_;
+  const std::string name_;
+};
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/variable.h b/paddle/fluid/framework/variable.h
index 067e0c2b838..387e06bca6e 100644
--- a/paddle/fluid/framework/variable.h
+++ b/paddle/fluid/framework/variable.h
@@ -14,6 +14,7 @@
 #pragma once
 
 #include <memory>
+#include <mutex>  // NOLINT
 #include <string>
 #include <typeindex>
 #include <typeinfo>
@@ -38,6 +39,7 @@ class Variable {
 
   template <typename T>
   T* GetMutable() {
+    std::unique_lock<std::mutex> lock(mutex_);
     if (!IsType<T>()) {
       holder_.reset(new PlaceholderImpl<T>(new T()));
     }
@@ -90,6 +92,7 @@ class Variable {
   // by its address but not the unreadable name.
   friend class Scope;
   const std::string* name_;
+  std::mutex mutex_;
 };
 
 }  // namespace framework
diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc
index ae60ab15325..ca0518d4dc9 100644
--- a/paddle/fluid/operators/detail/grpc_client.cc
+++ b/paddle/fluid/operators/detail/grpc_client.cc
@@ -33,7 +33,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
   const std::string ep_val = ep;
   const std::string var_name_val = var_name;
   const framework::Scope* p_scope = &scope;
-  const auto ch = GetChannel(ep_val);
+  const auto ch = GetChannel(ep_val, ep_val + ":" + var_name_val);
 
   framework::AsyncIO([var_name_val, p_ctx, ep_val, p_scope, time_out, ch,
                       this] {
@@ -88,7 +88,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
   const std::string ep_val = ep;
   const std::string var_name_val = var_name;
   const framework::Scope* p_scope = &scope;
-  const auto ch = GetChannel(ep_val);
+  const auto ch = GetChannel(ep_val, ep_val + ":" + var_name_val);
 
   framework::AsyncIO([var_name_val, ep_val, p_scope, p_ctx, time_out, ch,
                       this] {
@@ -132,7 +132,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
   const std::string in_var_name_val = in_var_name;
   const std::string out_var_name_val = out_var_name;
   const framework::Scope* p_scope = &scope;
-  const auto ch = GetChannel(ep_val);
+  const auto ch = GetChannel(ep_val, ep_val + ":" + in_var_name_val);
 
   framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope,
                       p_ctx, time_out, ch, this] {
@@ -165,7 +165,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
 }
 
 void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
-  const auto ch = GetChannel(ep);
+  const auto ch = GetChannel(ep, ep);
 
   BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
   s->Prepare(time_out);
@@ -178,7 +178,7 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
 }
 
 void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
-  const auto ch = GetChannel(ep);
+  const auto ch = GetChannel(ep, ep);
   FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
   s->Prepare(time_out);
@@ -243,12 +243,19 @@ bool RPCClient::Proceed() {
   delete c;
   return true;
 }
-
-std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
-  auto it = channels_.find(ep);
+std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep,
+                                                     const std::string& key) {
+  VLOG(3) << "this addr: " << this;
+  std::unique_lock<std::mutex> lock(mutex_);
+  auto it = channels_.find(key);
   if (it != channels_.end()) {
+    VLOG(3) << "found channel for ep: " << ep;
     return it->second;
   }
+  VLOG(3) << "cannot find channel for ep: " << ep;
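+  // Dump the keys of all cached channels so a lookup miss can be diagnosed
+  // from the V3 log.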
+  for (auto it = channels_.begin(); it != channels_.end(); ++it) {
+    VLOG(3) << "ep: " << it->first;
+  }
 
   grpc::ChannelArguments args;
   args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
@@ -257,8 +264,7 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
 
   auto ch =
       grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
-
-  channels_[ep] = ch;
+  channels_[key] = ch;
   return ch;
 }
diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h
index dabce7414d2..4e1d608549f 100644
--- a/paddle/fluid/operators/detail/grpc_client.h
+++ b/paddle/fluid/operators/detail/grpc_client.h
@@ -21,6 +21,7 @@ limitations under the License. */
 #include <functional>
 #include <iostream>
 #include <map>
+#include <mutex>  // NOLINT
 #include <string>
 #include <vector>
 
@@ -190,12 +191,14 @@ class RPCClient {
 
  private:
   bool Proceed();
-  std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
+  std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep,
+                                            const std::string& key);
 
  private:
   grpc::CompletionQueue cq_;
   std::map<std::string, std::shared_ptr<grpc::Channel>> channels_;
   int64_t req_count_ = 0;
+  std::mutex mutex_;
 };
 
 }  // namespace detail
diff --git a/paddle/fluid/operators/fetch_barrier_op.cc b/paddle/fluid/operators/fetch_barrier_op.cc
index 3dfdd135eee..5d2e5586991 100644
--- a/paddle/fluid/operators/fetch_barrier_op.cc
+++ b/paddle/fluid/operators/fetch_barrier_op.cc
@@ -21,6 +21,7 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/grpc_client.h"
+#include "paddle/fluid/platform/profiler.h"
 
 namespace paddle {
 namespace operators {
@@ -37,6 +38,11 @@ class FetchBarrierOp : public framework::OperatorBase {
                const platform::Place& place) const override {
     std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
 
+    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+    auto& ctx = *pool.Get(place);
+    // For profiling
+    platform::RecordEvent record_event(Type(), &ctx);
+
     auto client_var_name = Output("RPCClient");
     PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
                             "Can not find variable '%s' in the scope.",
diff --git a/paddle/fluid/operators/recv_op.cc b/paddle/fluid/operators/recv_op.cc
index 3b5459f3e3c..7ca3c20c7d2 100644
--- a/paddle/fluid/operators/recv_op.cc
+++ b/paddle/fluid/operators/recv_op.cc
@@ -21,6 +21,7 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/grpc_client.h"
+#include "paddle/fluid/platform/profiler.h"
 
 namespace paddle {
 namespace operators {
@@ -37,15 +38,18 @@ class RecvOp : public framework::OperatorBase {
     auto outs = Outputs("Out");
     std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
     auto client_var_name = Output("RPCClient");
+
+    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+    auto& ctx = *pool.Get(place);
+    // For profiling
+    platform::RecordEvent record_event(Type(), &ctx);
+
     PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
                             "Can not find variable '%s' in the scope.",
                             client_var_name);
     auto* client_var = scope.FindVar(client_var_name);
     detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
 
-    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
-    auto& ctx = *pool.Get(place);
-
     for (size_t i = 0; i < outs.size(); i++) {
       VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
       rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
diff --git a/paddle/fluid/operators/send_barrier_op.cc b/paddle/fluid/operators/send_barrier_op.cc
index 1ce0907f3a9..05e26236309 100644
--- a/paddle/fluid/operators/send_barrier_op.cc
+++ b/paddle/fluid/operators/send_barrier_op.cc
@@ -21,6 +21,7 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/grpc_client.h"
+#include "paddle/fluid/platform/profiler.h"
 
 namespace paddle {
 namespace operators {
@@ -37,6 +38,10 @@ class SendBarrierOp : public framework::OperatorBase {
                const platform::Place& place) const override {
     std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
 
+    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+    auto& ctx = *pool.Get(place);
+    // For profiling
+    platform::RecordEvent record_event(Type(), &ctx);
     auto client_var_name = Output("RPCClient");
     PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
                             "Can not find variable '%s' in the scope.",
diff --git a/paddle/fluid/operators/send_recv_util.h b/paddle/fluid/operators/send_recv_util.h
index 113513eb6b3..deab0051490 100644
--- a/paddle/fluid/operators/send_recv_util.h
+++ b/paddle/fluid/operators/send_recv_util.h
@@ -20,6 +20,9 @@ namespace operators {
 
 inline bool NeedSend(const framework::Scope& scope,
                      const std::string& varname) {
+  // The dummy variable is only used by the parallel executor to represent a
+  // dependency relationship; there is no need to send or receive it.
+  if (varname == "dummy") return false;
   auto* var = scope.FindVar(varname);
   PADDLE_ENFORCE_NOT_NULL(var, "Can not find variable '%s' in the send side.",
                           varname);
diff --git a/paddle/fluid/operators/send_vars_op.cc b/paddle/fluid/operators/send_vars_op.cc
index f11e84c176a..3caceba4e9c 100644
--- a/paddle/fluid/operators/send_vars_op.cc
+++ b/paddle/fluid/operators/send_vars_op.cc
@@ -20,6 +20,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/grpc_client.h"
 #include "paddle/fluid/operators/send_recv_util.h"
+#include "paddle/fluid/platform/profiler.h"
 
 namespace paddle {
 namespace operators {
@@ -41,12 +42,17 @@ class SendVarsOp : public framework::OperatorBase {
     platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
     auto& ctx = *pool.Get(place);
 
+    // For profiling
+    platform::RecordEvent record_event(Type(), &ctx);
+
     auto client_var_name = Output("RPCClient");
     PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
                             "Can not find variable '%s' in the scope.",
                             client_var_name);
     auto* client_var = scope.FindVar(client_var_name);
+    VLOG(3) << "client var addr: " << client_var;
     detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
+    VLOG(3) << "rpc_client addr: " << rpc_client;
 
     for (size_t i = 0; i < ins.size(); i++) {
       if (NeedSend(scope, ins[i])) {
--
GitLab