diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt
index 9de44beafbb69b3510b97afcc43d4b489a029c35..2c838f4361422c1e088569bed987d1fd727a9dbc 100644
--- a/paddle/fluid/framework/details/CMakeLists.txt
+++ b/paddle/fluid/framework/details/CMakeLists.txt
@@ -4,6 +4,7 @@ cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_h
 cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
 cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
 cc_library(send_op_handle SRCS send_op_handle.cc DEPS framework_proto scope place operator op_registry)
+cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)

 cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
 cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
@@ -26,7 +27,7 @@ endif()
 cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)

 cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
-        scale_loss_grad_op_handle send_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
+        scale_loss_grad_op_handle send_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)

 cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
 cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index 45bad58145a1144dfabdd3e15b38d972d57b105e..50998fb8e0c9ce940108278464a0b3fe83425bf6 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -12,10 +12,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" +#include #include #include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/reduce_op_handle.h" +#include "paddle/fluid/framework/details/rpc_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" #include "paddle/fluid/framework/details/send_op_handle.h" #include "paddle/fluid/framework/scope.h" @@ -77,7 +79,6 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, CreateOpOutput(result, op_handle, each_var_name, p, place_id); } } - bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const { if (send_op == nullptr) { @@ -98,7 +99,7 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op, return false; }; - if (op.Type() == "split") { + if (op.Type() == "split" || op.Type() == "split_byref") { return checker(op.OutputArgumentNames(), send_op->InputArgumentNames()); } else if (op.Type() == "concat") { return checker(op.InputArgumentNames(), send_op->OutputArgumentNames()); @@ -106,6 +107,15 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op, return false; } +bool MultiDevSSAGraphBuilder::IsRPCOp(const OpDesc &op) const { + for (auto &name : op.OutputNames()) { + if (name == "RPCClient") { + return true; + } + } + return false; +} + std::unique_ptr MultiDevSSAGraphBuilder::Build( const ProgramDesc &program) const { std::unordered_map var_types; @@ -133,10 +143,10 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( bool is_forwarding = true; for (auto *op : program.Block(0).AllOps()) { - if (op->Type() == "send") { - // append send op if program is distributed trainer main program. + if (IsRPCOp(*op)) { + // append rpc op if program is distributed trainer main program. 
       // always use the first device
-      CreateSendOp(&result, *op);
+      CreateRPCOp(&result, *op);
     } else if (IsDistTrainOp(*op, send_op)) {
       CreateComputationalOps(&result, *op, 1);
     } else if (IsScaleLossOp(*op)) {
@@ -203,9 +213,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   AddOutputToLeafOps(&result);

   if (VLOG_IS_ON(10)) {
-    std::ostringstream sout;
-    PrintGraphviz(*graph, sout);
-    VLOG(10) << sout.str();
+    std::string filename = "/tmp/graph";
+    std::ofstream fout(filename);
+    PrintGraphviz(*graph, fout);
   }

   return std::unique_ptr<SSAGraph>(graph);
@@ -386,12 +396,40 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
   return var;
 }

-void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
-                                           const OpDesc &op) const {
+// Add a dependency edge, via a dummy variable, from every op named `op_name`
+// that is already in the graph to the most recently appended op handle.
+void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result,
+                                        std::string op_name) const {
+  for (auto &prev_op : result->ops_) {
+    if (prev_op->Name() == op_name) {
+      auto *dep_var = new DummyVarHandle();
+      prev_op->AddOutput(dep_var);
+      result->dep_vars_.emplace(dep_var);
+      result->ops_.back().get()->AddInput(dep_var);
+    }
+  }
+}
+
+void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
+                                          const OpDesc &op) const {
   auto &p = places_[0];
   auto *s = local_scopes_[0];
+  VLOG(3) << "create rpc op: " << op.Type();
+  result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
+  // Chain the RPC stages in order: send_vars -> send_barrier -> recv ->
+  // fetch_barrier, so each stage depends on the previous one.
+  if (op.Type() == "send_barrier") {
+    ConnectOp(result, "send_vars");
+  } else if (op.Type() == "recv") {
+    ConnectOp(result, "send_barrier");
+  } else if (op.Type() == "fetch_barrier") {
+    ConnectOp(result, "recv");
+  } else if (op.Type() == "send" || op.Type() == "send_vars") {
+    // do nothing
+  } else {
+    PADDLE_THROW(
+        "rpc op should be in [send, "
+        "send_vars, send_barrier, recv, fetch_barrier]");
+  }
+
   // FIXME(wuyi): send op always copy from GPU 0
-  result->ops_.emplace_back(new SendOpHandle(op, s, p));
   // Create inputs for output on original place and no ssa output
   // is created for send op.
   CreateOpHandleIOs(result, op, 0);
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h
index 4f708521884247fc013f0ae336ab683c3fe7ef2f..45713b0c4f67a34f5b97e294402d93d07cffca93 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h
@@ -65,12 +65,17 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   bool IsScaleLossOp(const OpDesc &op) const;

   void CreateSendOp(SSAGraph *result, const OpDesc &op) const;
+  void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;

   /**
    * Is this operator as the end-point operator before/after send operator.
    */
   bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const;

+  bool IsRPCOp(const OpDesc &op) const;
+
+  void ConnectOp(SSAGraph *result, std::string op_name) const;
+
   void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
                               size_t num_places) const;
diff --git a/paddle/fluid/framework/details/rpc_op_handle.cc b/paddle/fluid/framework/details/rpc_op_handle.cc
new file mode 100644
index 0000000000000000000000000000000000000000..03f53421b1d3cbe3d455d39e657b723b724f70c0
--- /dev/null
+++ b/paddle/fluid/framework/details/rpc_op_handle.cc
@@ -0,0 +1,50 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/details/rpc_op_handle.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+RPCOpHandle::RPCOpHandle(const framework::OpDesc &op_desc,
+                         const Scope *local_scope, const platform::Place &place,
+                         const std::string &name)
+    : op_(framework::OpRegistry::CreateOp(op_desc)),
+      local_scope_(local_scope),
+      place_(place),
+      name_(name) {}
+
+void RPCOpHandle::RunImpl() {
+  // TODO(wuyi): need further analysis on whether to wait on DummyVarHandle.
+  // Wait until all inputs are generated.
+  for (auto *in : inputs_) {
+    auto &p = static_cast<VarHandle *>(in)->place_;
+    if (in->DebugString() == "dummy") {  // HACK
+      continue;
+    }
+    if (in->generated_op_) {
+      in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[p]);
+    }
+  }
+  auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
+  // FIXME(wuyi): cannot use RunAndRecordEvent here; it would cause a
+  // deadlock.
+  op_->Run(*tmp_scope, place_);
+}
+
+std::string RPCOpHandle::Name() const { return name_; }
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/details/rpc_op_handle.h b/paddle/fluid/framework/details/rpc_op_handle.h
new file mode 100644
index 0000000000000000000000000000000000000000..d28b7721720d808a8d81701c3811eae16121fb41
--- /dev/null
+++ b/paddle/fluid/framework/details/rpc_op_handle.h
@@ -0,0 +1,52 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/details/op_handle_base.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/scope.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+struct RPCOpHandle : public OpHandleBase {
+  RPCOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
+              const platform::Place& place, const std::string& name);
+
+  std::string Name() const override;
+
+  // Delaying and buffering nccl_all_reduce together can significantly
+  // increase performance; returning false disables that feature for this op.
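+  // An RPC op runs on a single place (CreateRPCOp always uses the first
+  // device), so it never acts as a multi-device transfer.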
+  bool IsMultiDeviceTransfer() override { return false; }
+
+ protected:
+  void RunImpl() override;
+
+ private:
+  std::unique_ptr<OperatorBase> op_;
+  const Scope* local_scope_;
+  const platform::Place& place_;
+  const std::string name_;
+};
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/variable.h b/paddle/fluid/framework/variable.h
index 067e0c2b8389f88639fd9b95bd680702517efee1..387e06bca6e477dabbc073763344e560715738e4 100644
--- a/paddle/fluid/framework/variable.h
+++ b/paddle/fluid/framework/variable.h
@@ -14,6 +14,7 @@
 #pragma once

 #include <memory>
+#include <mutex>  // NOLINT
 #include <string>
 #include <typeindex>
 #include <typeinfo>
@@ -38,6 +39,7 @@ class Variable {

   template <typename T>
   T* GetMutable() {
+    std::unique_lock<std::mutex> lock(mutex_);
     if (!IsType<T>()) {
       holder_.reset(new PlaceholderImpl<T>(new T()));
     }
@@ -90,6 +92,7 @@ class Variable {
   // by its address but not the unreadable name.
   friend class Scope;
   const std::string* name_;
+  std::mutex mutex_;
 };

 }  // namespace framework
diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc
index ae60ab15325ef101feb7270a4f5d840cb2112be0..ca0518d4dc9554c2f77bc5ac890589d87f0f044a 100644
--- a/paddle/fluid/operators/detail/grpc_client.cc
+++ b/paddle/fluid/operators/detail/grpc_client.cc
@@ -33,7 +33,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
   const std::string ep_val = ep;
   const std::string var_name_val = var_name;
   const framework::Scope* p_scope = &scope;
-  const auto ch = GetChannel(ep_val);
+  const auto ch = GetChannel(ep_val, ep_val + ":" + var_name_val);

   framework::AsyncIO([var_name_val, p_ctx, ep_val, p_scope, time_out, ch,
                       this] {
@@ -88,7 +88,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
   const std::string ep_val = ep;
   const std::string var_name_val = var_name;
   const framework::Scope* p_scope = &scope;
-  const auto ch = GetChannel(ep_val);
+  const auto ch = GetChannel(ep_val, ep_val + ":" + var_name_val);

   framework::AsyncIO([var_name_val, ep_val, p_scope, p_ctx, time_out, ch,
                       this] {
@@ -132,7 +132,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
   const std::string in_var_name_val = in_var_name;
   const std::string out_var_name_val = out_var_name;
   const framework::Scope* p_scope = &scope;
-  const auto ch = GetChannel(ep_val);
+  const auto ch = GetChannel(ep_val, ep_val + ":" + in_var_name_val);

   framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope,
                       p_ctx, time_out, ch, this] {
@@ -165,7 +165,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
 }

 void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
-  const auto ch = GetChannel(ep);
+  const auto ch = GetChannel(ep, ep);

   BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
   s->Prepare(time_out);
@@ -178,7 +178,7 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
 }

 void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
-  const auto ch = GetChannel(ep);
+  const auto ch = GetChannel(ep, ep);

   FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
   s->Prepare(time_out);
@@ -243,12 +243,19 @@ bool RPCClient::Proceed() {
   delete c;
   return true;
 }
-
-std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
-  auto it = channels_.find(ep);
+std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep,
+                                                     const std::string& key) {
+  VLOG(3) << "this addr: " << this;
+  std::unique_lock<std::mutex> lock(mutex_);
+  auto it = channels_.find(key);
   if (it != channels_.end()) {
+    VLOG(3) << "find ep: " << ep;
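+    // Reuse the channel cached under this key; with per-variable keys, each
+    // variable sent to the same endpoint gets its own channel.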
     return it->second;
   }
+  VLOG(3) << "can not find ep: " << ep;
+  for (auto it = channels_.begin(); it != channels_.end(); ++it) {
+    VLOG(3) << "ep: " << it->first;
+  }

   grpc::ChannelArguments args;
   args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
@@ -257,8 +264,7 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {

   auto ch =
       grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
-
-  channels_[ep] = ch;
+  channels_[key] = ch;
   return ch;
 }
diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h
index dabce7414d2f0dca74193f1cd10c341793c10ec9..4e1d608549fea5ef74779821d2547630592e41ae 100644
--- a/paddle/fluid/operators/detail/grpc_client.h
+++ b/paddle/fluid/operators/detail/grpc_client.h
@@ -21,6 +21,7 @@ limitations under the License. */
 #include <functional>
 #include <iostream>
 #include <map>
+#include <mutex>  // NOLINT
 #include <string>
 #include <vector>
@@ -190,12 +191,14 @@ class RPCClient {

  private:
  bool Proceed();
-  std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
+  std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep,
+                                            const std::string& key);

 private:
  grpc::CompletionQueue cq_;
  std::map<std::string, std::shared_ptr<grpc::Channel>> channels_;
  int64_t req_count_ = 0;
+  std::mutex mutex_;
 };
diff --git a/paddle/fluid/operators/fetch_barrier_op.cc b/paddle/fluid/operators/fetch_barrier_op.cc
index 3dfdd135eeec74c0e19037bbff2d678ba9b78753..5d2e558699116c270dc603b0137dee7f16bee1f8 100644
--- a/paddle/fluid/operators/fetch_barrier_op.cc
+++ b/paddle/fluid/operators/fetch_barrier_op.cc
@@ -21,6 +21,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/grpc_client.h"
+#include "paddle/fluid/platform/profiler.h"

 namespace paddle {
 namespace operators {
@@ -37,6 +38,11 @@ class FetchBarrierOp : public framework::OperatorBase {
                const platform::Place& place) const override {
     std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");

+    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+    auto& ctx = *pool.Get(place);
+    // For profiling
+    platform::RecordEvent record_event(Type(), &ctx);
+
     auto client_var_name = Output("RPCClient");
     PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
                             "Can not find variable '%s' in the scope.",
diff --git a/paddle/fluid/operators/recv_op.cc b/paddle/fluid/operators/recv_op.cc
index 3b5459f3e3c091df3112530399cc2d79264c2841..7ca3c20c7d2b7c00dd1ac66432f9b79dd666987b 100644
--- a/paddle/fluid/operators/recv_op.cc
+++ b/paddle/fluid/operators/recv_op.cc
@@ -21,6 +21,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/grpc_client.h"
+#include "paddle/fluid/platform/profiler.h"

 namespace paddle {
 namespace operators {
@@ -37,15 +38,18 @@ class RecvOp : public framework::OperatorBase {
     auto outs = Outputs("Out");
     std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
     auto client_var_name = Output("RPCClient");
+
+    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+    auto& ctx = *pool.Get(place);
+    // For profiling
+    platform::RecordEvent record_event(Type(), &ctx);
+
     PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
                             "Can not find variable '%s' in the scope.",
                             client_var_name);
     auto* client_var = scope.FindVar(client_var_name);
     detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();

-    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
-    auto& ctx = *pool.Get(place);
-
     for (size_t i = 0; i < outs.size(); i++) {
       VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
       rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
diff --git a/paddle/fluid/operators/send_barrier_op.cc b/paddle/fluid/operators/send_barrier_op.cc
index 1ce0907f3a9473e37f53bf7b2d42cddcb629dfa6..05e262363095d0914d057ace353e32c6a6702413 100644
--- a/paddle/fluid/operators/send_barrier_op.cc
+++ b/paddle/fluid/operators/send_barrier_op.cc
@@ -21,6 +21,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/grpc_client.h"
+#include "paddle/fluid/platform/profiler.h"

 namespace paddle {
 namespace operators {
@@ -37,6 +38,10 @@ class SendBarrierOp : public framework::OperatorBase {
                const platform::Place& place) const override {
     std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");

+    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+    auto& ctx = *pool.Get(place);
+    // For profiling
+    platform::RecordEvent record_event(Type(), &ctx);
     auto client_var_name = Output("RPCClient");
     PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
                             "Can not find variable '%s' in the scope.",
diff --git a/paddle/fluid/operators/send_recv_util.h b/paddle/fluid/operators/send_recv_util.h
index 113513eb6b327773ab4a1c062fb8a3f06fddfbca..deab005149027caffa962783df944fad7110382f 100644
--- a/paddle/fluid/operators/send_recv_util.h
+++ b/paddle/fluid/operators/send_recv_util.h
@@ -20,6 +20,9 @@ namespace operators {

 inline bool NeedSend(const framework::Scope& scope,
                      const std::string& varname) {
+  // The dummy variable is only used by the parallel executor to express
+  // dependencies; it never needs to be sent or received.
+  if (varname == "dummy") return false;
   auto* var = scope.FindVar(varname);
   PADDLE_ENFORCE_NOT_NULL(var, "Can not find variable '%s' in the send side.",
                           varname);
diff --git a/paddle/fluid/operators/send_vars_op.cc b/paddle/fluid/operators/send_vars_op.cc
index f11e84c176ae97dff0fda560ce3ebe2ab72c7bcc..3caceba4e9c68912f05de66fe9139cad1aad6d3c 100644
--- a/paddle/fluid/operators/send_vars_op.cc
+++ b/paddle/fluid/operators/send_vars_op.cc
@@ -20,6 +20,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/grpc_client.h"
 #include "paddle/fluid/operators/send_recv_util.h"
+#include "paddle/fluid/platform/profiler.h"

 namespace paddle {
 namespace operators {
@@ -41,12 +42,17 @@ class SendVarsOp : public framework::OperatorBase {

     platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
     auto& ctx = *pool.Get(place);
+    // For profiling
+    platform::RecordEvent record_event(Type(), &ctx);
+
     auto client_var_name = Output("RPCClient");
     PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
                             "Can not find variable '%s' in the scope.",
                             client_var_name);
     auto* client_var = scope.FindVar(client_var_name);
+    VLOG(3) << "client var addr: " << client_var;
     detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
+    VLOG(3) << "rpc_client addr: " << rpc_client;

     for (size_t i = 0; i < ins.size(); i++) {
       if (NeedSend(scope, ins[i])) {