diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index b69de2ced03569d5e9ffe313527ab776ee798496..1bcd8412eb2d618b923bcd0557d118af62271f4a 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -3,7 +3,7 @@ cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory) cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory) cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry) -cc_library(send_op_handle SRCS send_op_handle.cc DEPS framework_proto scope place operator op_registry) +cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry) cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base) cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph) @@ -26,7 +26,7 @@ endif() cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle - scale_loss_grad_op_handle send_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle) + scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto) cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 35d23d68c0dd26a05544a72316d5764129aa8d40..d8e711994c5dba15ce0a1c237558b121888902e3 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -12,12 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. 
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" +#include #include #include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/reduce_op_handle.h" +#include "paddle/fluid/framework/details/rpc_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" -#include "paddle/fluid/framework/details/send_op_handle.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/scope.h" @@ -28,6 +29,10 @@ #include #include +DEFINE_string(ssa_graph_path, "/tmp/ssa_graph.dot", + "the ssa graph path only print with GLOG_v=10," + "default /tmp/graph.dot"); + namespace paddle { namespace framework { namespace details { @@ -79,9 +84,44 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, } } -bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op, - OpDesc *send_op) const { - if (send_op == nullptr) { +std::vector MultiDevSSAGraphBuilder::FindDistTrainSendVars( + const ProgramDesc &program) const { + std::vector send_vars; + // since parameters are all in block 0, + // it's enough to only scan send ops in block 0 + for (auto *op : program.Block(0).AllOps()) { + // TODO(Yancey1989): use a graceful method to find send op, + // instead of the the hard code string + if (op->Type() == "send_vars") { + auto op_vars = op->InputArgumentNames(); + send_vars.reserve(send_vars.size() + + std::distance(op_vars.begin(), op_vars.end())); + send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end()); + } + } + return send_vars; +} + +std::vector MultiDevSSAGraphBuilder::FindDistTrainRecvVars( + const ProgramDesc &program) const { + std::vector recv_vars; + for (auto *op : program.Block(0).AllOps()) { + // TODO(Yancey1989): use a graceful method to find recv op, + // instead of the hard code string + if (op->Type() == "recv") { + auto op_vars = op->OutputArgumentNames(); + recv_vars.reserve(recv_vars.size() + + std::distance(op_vars.begin(), op_vars.end())); + recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end()); + } + } + return recv_vars; +} + +bool MultiDevSSAGraphBuilder::IsDistTrainOp( + const OpDesc &op, const std::vector &send_vars, + const std::vector &recv_vars) const { + if (send_vars.size() == 0 || recv_vars.size() == 0) { return false; } @@ -89,22 +129,21 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op, * Check any of opvars contains `.block` and in sendvars */ auto checker = [](const std::vector &opvars, - const std::vector &sendvars) -> bool { + const std::vector &rpc_vars) -> bool { for (auto &var : opvars) { + // a variable name with the suffix `.block` means it's a splited + // variable by (DistributeTranspiler) + // [python/paddle/fluid/transpiler/distribute_transpiler.py] if (var.find(".block") != std::string::npos && - std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) { + std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) { return true; } } return false; }; - if (op.Type() == "split" || op.Type() == "split_byref") { - return checker(op.OutputArgumentNames(), send_op->InputArgumentNames()); - } else if (op.Type() == "concat") { - return checker(op.InputArgumentNames(), send_op->OutputArgumentNames()); - } - return false; + return checker(op.OutputArgumentNames(), send_vars) || + checker(op.InputArgumentNames(), recv_vars); } std::unique_ptr MultiDevSSAGraphBuilder::Build( @@ -123,8 +162,10 @@ std::unique_ptr 
MultiDevSSAGraphBuilder::Build( std::unordered_map>>>( places_.size()); - // Find "send" op first for split is in front of send. - OpDesc *send_op = GetSendOpDesc(program); + // find send/recv vars so that we can place the distributed training + // realted op in the place 0 + auto send_vars = FindDistTrainSendVars(program); + auto recv_vars = FindDistTrainRecvVars(program); size_t cur_device_id = 0; std::vector> var_name_on_devices; @@ -134,12 +175,14 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( bool is_forwarding = true; for (auto *op : program.Block(0).AllOps()) { - if (op->Type() == "send") { - // append send op if program is distributed trainer main program. + if (boost::get( + op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == + static_cast(OpRole::kRPC)) { + // append rpc op if program is distributed trainer main program. // always use the first device - CreateSendOp(&result, *op); - } else if (IsDistTrainOp(*op, send_op)) { - CreateComputationalOps(&result, *op, 1); + CreateRPCOp(&result, *op); + } else if (IsDistTrainOp(*op, send_vars, recv_vars)) { + CreateDistTrainOp(&result, *op); } else if (IsScaleLossOp(*op)) { // user can customize loss@grad if not use_default_grad_scale_ if (strategy_.gradient_scale_ != @@ -218,9 +261,8 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( AddOutputToLeafOps(&result); if (VLOG_IS_ON(10)) { - std::ostringstream sout; - PrintGraphviz(*graph, sout); - VLOG(10) << sout.str(); + std::ofstream fout(FLAGS_ssa_graph_path); + PrintGraphviz(*graph, fout); } return std::unique_ptr(graph); @@ -270,15 +312,6 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result, CreateOpHandleIOs(result, op, dev_id); } -OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc( - const ProgramDesc &program) const { - for (auto *op : program.Block(0).AllOps()) { - if (op->Type() == "send") { - return op; - } - } - return nullptr; -} void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp( SSAGraph *result, const std::string &og) const { #ifdef PADDLE_WITH_CUDA @@ -401,14 +434,48 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result, return var; } -void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result, - const OpDesc &op) const { +void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op, + const std::string &prev_op_name) const { + for (auto &prev_op : result->ops_) { + if (prev_op->Name() == prev_op_name) { + auto *dep_var = new DummyVarHandle(); + prev_op->AddOutput(dep_var); + result->dep_vars_.emplace(dep_var); + op->AddInput(dep_var); + } + } +} + +void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result, + const OpDesc &op) const { + CreateComputationalOp(result, op, 0); + if (op.Type() == "concat") { + ConnectOp(result, result->ops_.back().get(), "fetch_barrier"); + } +} + +void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, + const OpDesc &op) const { auto &p = places_[0]; auto *s = local_scopes_[0]; - // FIXME(wuyi): send op always copy from GPU 0 - result->ops_.emplace_back(new SendOpHandle(op, s, p)); - // Create inputs for output on original place and no ssa output - // is created for send op. 
+  result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
+
+  if (op.Type() == "send_barrier") {
+    ConnectOp(result, result->ops_.back().get(), "send_vars");
+  } else if (op.Type() == "recv") {
+    ConnectOp(result, result->ops_.back().get(), "send_barrier");
+  } else if (op.Type() == "fetch_barrier") {
+    ConnectOp(result, result->ops_.back().get(), "recv");
+  } else if (op.Type() == "send_vars") {
+    // do nothing
+  } else {
+    PADDLE_THROW(
+        "rpc op should be in ["
+        "send_vars, send_barrier, recv, fetch_barrier]");
+  }
+
+  // TODO(Yancey1989): scheduling rpc ops on different places may
+  // increase throughput
   CreateOpHandleIOs(result, op, 0);
 }
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h
index 4f708521884247fc013f0ae336ab683c3fe7ef2f..e07597dbd80889c366babe79455beb12c9eb80d9 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h
@@ -64,12 +64,24 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   bool IsScaleLossOp(const OpDesc &op) const;
 
-  void CreateSendOp(SSAGraph *result, const OpDesc &op) const;
+  void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;
+  void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const;
 
   /**
    * Is this operator the end-point operator before/after the send operator?
    */
-  bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const;
+  bool IsDistTrainOp(const OpDesc &op,
+                     const std::vector &send_vars,
+                     const std::vector &recv_vars) const;
+
+  std::vector FindDistTrainSendVars(
+      const ProgramDesc &program) const;
+
+  std::vector FindDistTrainRecvVars(
+      const ProgramDesc &program) const;
+
+  void ConnectOp(SSAGraph *result, OpHandleBase *op,
+                 const std::string &prev_op_name) const;
 
   void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
                               size_t num_places) const;
@@ -93,12 +105,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
                          size_t src_dev_id) const;
 
-  /**
-   * Get send op in the global block of program.
-   * nullptr if not found.
-   */
-  OpDesc *GetSendOpDesc(const ProgramDesc &program) const;
-
   bool IsSparseGradient(
       const std::unordered_map &var_types,
       const std::string &og) const;
diff --git a/paddle/fluid/framework/details/send_op_handle.cc b/paddle/fluid/framework/details/rpc_op_handle.cc
similarity index 75%
rename from paddle/fluid/framework/details/send_op_handle.cc
rename to paddle/fluid/framework/details/rpc_op_handle.cc
index 7109659dd7001f91e7674ac7bebbe3a59794cfc0..7f4da4c01de1010467d839ee5490c5e0d02d8c24 100644
--- a/paddle/fluid/framework/details/send_op_handle.cc
+++ b/paddle/fluid/framework/details/rpc_op_handle.cc
@@ -12,24 +12,26 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
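Note on the CreateRPCOp/ConnectOp logic above: the trainer-side RPC ops are ordered purely through dependency edges. Each ConnectOp call creates a zero-data dummy variable that every op named prev_op_name emits and the new op consumes, which forces send_vars -> send_barrier -> recv -> fetch_barrier without adding real SSA outputs. Below is a minimal standalone sketch of that wiring; the struct and function names are illustrative placeholders, not the framework's OpHandleBase/DummyVarHandle API.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct DummyVar {};  // carries no data, only a dependency edge

struct OpHandle {
  std::string name;
  std::vector<DummyVar*> inputs;
  std::vector<DummyVar*> outputs;
};

// For every already-created op called `prev_op_name`, emit a dummy variable
// that `op` must wait for -- the same idea as ConnectOp above.
void Connect(std::vector<std::unique_ptr<OpHandle>>* ops,
             std::vector<std::unique_ptr<DummyVar>>* dep_vars,
             OpHandle* op, const std::string& prev_op_name) {
  for (auto& prev : *ops) {
    if (prev->name != prev_op_name) continue;
    dep_vars->emplace_back(new DummyVar());
    DummyVar* dep = dep_vars->back().get();
    prev->outputs.push_back(dep);
    op->inputs.push_back(dep);
  }
}

int main() {
  std::vector<std::unique_ptr<OpHandle>> ops;
  std::vector<std::unique_ptr<DummyVar>> deps;
  for (const char* name : {"send_vars", "send_barrier", "recv", "fetch_barrier"})
    ops.emplace_back(new OpHandle{name, {}, {}});
  // Chain them the same way CreateRPCOp does.
  Connect(&ops, &deps, ops[1].get(), "send_vars");
  Connect(&ops, &deps, ops[2].get(), "send_barrier");
  Connect(&ops, &deps, ops[3].get(), "recv");
  for (const auto& op : ops)
    std::cout << op->name << " waits on " << op->inputs.size() << " dep var(s)\n";
}

Each op ends up waiting on one dummy variable per predecessor of the named type, which mirrors how the SSA executor is told to finish the earlier RPC op before the next one runs.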
-#include "paddle/fluid/framework/details/send_op_handle.h" +#include "paddle/fluid/framework/details/rpc_op_handle.h" namespace paddle { namespace framework { namespace details { -SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc, - const Scope *local_scope, - const platform::Place &place) +RPCOpHandle::RPCOpHandle(const framework::OpDesc &op_desc, + const Scope *local_scope, const platform::Place &place, + const std::string &name) : op_(framework::OpRegistry::CreateOp(op_desc)), local_scope_(local_scope), - place_(place) {} + place_(place), + name_(name) {} -void SendOpHandle::RunImpl() { +void RPCOpHandle::RunImpl() { // TODO(wuyi): need further analysis whether wait VarDummyHandle. // Wait input done for (auto *in : inputs_) { auto &p = static_cast(in)->place_; + // FIXME(Yancey1989): need a better solution instead of use DebugString() if (in->DebugString() == "dummy") { // HACK continue; } @@ -43,7 +45,7 @@ void SendOpHandle::RunImpl() { op_->Run(*tmp_scope, place_); } -std::string SendOpHandle::Name() const { return "send"; } +std::string RPCOpHandle::Name() const { return name_; } } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/send_op_handle.h b/paddle/fluid/framework/details/rpc_op_handle.h similarity index 87% rename from paddle/fluid/framework/details/send_op_handle.h rename to paddle/fluid/framework/details/rpc_op_handle.h index 2f78811fad50642b5e45776c41910df6f4cc48f6..d28b7721720d808a8d81701c3811eae16121fb41 100644 --- a/paddle/fluid/framework/details/send_op_handle.h +++ b/paddle/fluid/framework/details/rpc_op_handle.h @@ -27,9 +27,9 @@ namespace paddle { namespace framework { namespace details { -struct SendOpHandle : public OpHandleBase { - SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope, - const platform::Place& place); +struct RPCOpHandle : public OpHandleBase { + RPCOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope, + const platform::Place& place, const std::string& name); std::string Name() const override; @@ -44,6 +44,7 @@ struct SendOpHandle : public OpHandleBase { std::unique_ptr op_; const Scope* local_scope_; const platform::Place& place_; + const std::string name_; }; } // namespace details diff --git a/paddle/fluid/framework/op_proto_maker.cc b/paddle/fluid/framework/op_proto_maker.cc index 5a4380a83a2e5bf492098032cd9de7bf274fe47e..ae9f4efd44acdcdff2806deea6826e4089459a78 100644 --- a/paddle/fluid/framework/op_proto_maker.cc +++ b/paddle/fluid/framework/op_proto_maker.cc @@ -66,7 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, .InEnum( {static_cast(OpRole::kForward), static_cast(OpRole::kBackward), - static_cast(OpRole::kOptimize), + static_cast(OpRole::kOptimize), static_cast(OpRole::kRPC), static_cast(OpRole::kLoss) | static_cast(OpRole::kForward), static_cast(OpRole::kLoss) | static_cast(OpRole::kBackward), diff --git a/paddle/fluid/framework/op_proto_maker.h b/paddle/fluid/framework/op_proto_maker.h index 9bd6ca6ea32734707a5c37b3ecfe449436c04c8c..8493b9d8b326c71a33b95bf95e5fc1743c686eb7 100644 --- a/paddle/fluid/framework/op_proto_maker.h +++ b/paddle/fluid/framework/op_proto_maker.h @@ -24,6 +24,7 @@ enum class OpRole { kForward = 0x0000, kBackward = 0x0001, kOptimize = 0x0002, + kRPC = 0x0003, kLoss = 0x0100, // The default value of op's role. 
This should be only used for unittests and diff --git a/paddle/fluid/inference/analysis/data_flow_graph_tester.cc b/paddle/fluid/inference/analysis/data_flow_graph_tester.cc index 51d38d6251d853fa8a02a4e22f819cfc44294453..9d7cceeb65888b8ba3fdf39e88fc2877abd82d11 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph_tester.cc +++ b/paddle/fluid/inference/analysis/data_flow_graph_tester.cc @@ -35,7 +35,7 @@ TEST(DataFlowGraph, BFS) { GraphTraits trait(&dfg); auto nodes = trait.nodes(); - int count = 0; + size_t count = 0; for (auto it = nodes.begin(); it != nodes.end(); ++it) { LOG(INFO) << "visiting " << it->name(); ++count; @@ -49,7 +49,7 @@ TEST(DataFlowGraph, DFS) { dfg.Build(); GraphTraits trait(&dfg); auto nodes = trait.nodes_in_DFS(); - int count = 0; + size_t count = 0; for (auto it = nodes.begin(); it != nodes.end(); ++it) { LOG(INFO) << "visiting " << it->name(); ++count; diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index f72997ca24ed837f761b52cbecdc05998424a675..e00cc73565fc98615090367606b6ba4f58feacfd 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -200,7 +200,9 @@ if(WITH_DISTRIBUTE) op_library(send_vars_op DEPS ${DISTRIBUTE_DEPS}) set_source_files_properties(send_vars_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) op_library(send_barrier_op DEPS ${DISTRIBUTE_DEPS}) + op_library(fetch_barrier_op DEPS ${DISTRIBUTE_DEPS}) set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + set_source_files_properties(fetch_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) #set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) #cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op # listen_and_serv_op sum_op executor SERIAL) @@ -214,7 +216,7 @@ if(WITH_DISTRIBUTE) set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op) endif() else() - set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op gen_nccl_id_op) + set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op fetch_barrier_op gen_nccl_id_op) endif() op_library(cross_entropy_op DEPS cross_entropy) diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index 47892b1bcc073d24ea617ea1c680138a88925177..f7ce7786874285795878b655365974f082c00b44 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -25,6 +25,21 @@ namespace paddle { namespace operators { namespace detail { +std::once_flag RPCClient::init_flag_; + +std::unique_ptr RPCClient::rpc_client_(nullptr); + +RPCClient* RPCClient::GetInstance() { + std::call_once(init_flag_, &RPCClient::Init); + return rpc_client_.get(); +} + +void RPCClient::Init() { + if (rpc_client_.get() == nullptr) { + rpc_client_.reset(new RPCClient()); + } +} + bool RPCClient::AsyncSendVariable(const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, @@ -60,7 +75,6 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, call->StartCall(); call->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); }); - req_count_++; return true; @@ -249,8 +263,9 @@ bool RPCClient::Proceed() { delete c; return true; } - std::shared_ptr RPCClient::GetChannel(const std::string& ep) { + // TODO(Yancey1989): make grpc client completely thread-safe + std::unique_lock 
lock(mutex_); auto it = channels_.find(ep); if (it != channels_.end()) { return it->second; @@ -263,7 +278,6 @@ std::shared_ptr RPCClient::GetChannel(const std::string& ep) { auto ch = grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args); - channels_[ep] = ch; return ch; } diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h index dabce7414d2f0dca74193f1cd10c341793c10ec9..449d5105afb8c02294a0ef57610e7de1b1631b35 100644 --- a/paddle/fluid/operators/detail/grpc_client.h +++ b/paddle/fluid/operators/detail/grpc_client.h @@ -21,6 +21,7 @@ limitations under the License. */ #include #include #include +#include // NOLINT #include #include @@ -35,6 +36,7 @@ limitations under the License. */ #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/operators/detail/sendrecvop_utils.h" +#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN namespace paddle { namespace operators { @@ -161,6 +163,10 @@ class FetchBarrierProcessor : public BaseProcessor { class RPCClient { public: + RPCClient() {} + + static RPCClient* GetInstance(); + bool AsyncSendVariable(const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, @@ -191,11 +197,17 @@ class RPCClient { private: bool Proceed(); std::shared_ptr GetChannel(const std::string& ep); + // Init is called by GetInstance. + static void Init(); private: grpc::CompletionQueue cq_; std::map> channels_; - int64_t req_count_ = 0; + std::atomic req_count_{0}; + std::mutex mutex_; + static std::unique_ptr rpc_client_; + static std::once_flag init_flag_; + DISABLE_COPY_AND_ASSIGN(RPCClient); }; } // namespace detail diff --git a/paddle/fluid/operators/detail/grpc_server_test.cc b/paddle/fluid/operators/detail/grpc_server_test.cc index 73e75c9087fef756840c76db249f8996253ced64..350a7ee1234da5b88d09ea955ce14b7c161d804e 100644 --- a/paddle/fluid/operators/detail/grpc_server_test.cc +++ b/paddle/fluid/operators/detail/grpc_server_test.cc @@ -121,10 +121,10 @@ TEST(PREFETCH, DISABLED_CPU) { std::string in_var_name("ids"); std::string out_var_name("out"); - detail::RPCClient client; - client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name, - out_var_name); - client.Wait(); + auto client = detail::RPCClient::GetInstance(); + client->AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name, + out_var_name); + client->Wait(); auto var = scope.Var(out_var_name); auto value = var->GetMutable()->value(); diff --git a/paddle/fluid/operators/fetch_barrier_op.cc b/paddle/fluid/operators/fetch_barrier_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..79ec02f52094121d01c6bda2a5d99d2211893e89 --- /dev/null +++ b/paddle/fluid/operators/fetch_barrier_op.cc @@ -0,0 +1,87 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include // NOLINT +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/framework.pb.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" + +#include "paddle/fluid/operators/detail/grpc_client.h" +#include "paddle/fluid/platform/profiler.h" + +namespace paddle { +namespace operators { + +class FetchBarrierOp : public framework::OperatorBase { + public: + FetchBarrierOp(const std::string& type, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + void RunImpl(const framework::Scope& scope, + const platform::Place& place) const override { + std::vector eps = Attr>("endpoints"); + + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto& ctx = *pool.Get(place); + // For profiling + platform::RecordEvent record_event(Type(), &ctx); + + auto rpc_client = detail::RPCClient::GetInstance(); + + PADDLE_ENFORCE(rpc_client->Wait()); + + for (auto& ep : eps) { + VLOG(3) << "fetch barrier, ep: " << ep; + rpc_client->AsyncSendFetchBarrier(ep); + } + PADDLE_ENFORCE(rpc_client->Wait()); + } +}; + +class FetchBarrierOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() { + AddComment(R"DOC( +SendBarrier operator + +This operator will send a send barrier signal to list_and_serv op, so that +the Parameter Server would knew all variables have been sent. +)DOC"); + + AddAttr>("endpoints", + "(string vector, default 127.0.0.1:6164)" + "Server endpoints to send variables to.") + .SetDefault({"127.0.0.1:6164"}); + } +}; + +class FetchBarrierOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override {} +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(fetch_barrier, ops::FetchBarrierOp, + paddle::framework::EmptyGradOpMaker, ops::FetchBarrierOpMaker, + ops::FetchBarrierOpShapeInference); diff --git a/paddle/fluid/operators/prefetch_op.cc b/paddle/fluid/operators/prefetch_op.cc index 4cfea958e8e50156c90af8806414b043e15f8a9c..e0a9b24ac8978418a1a4ece62286e022bec8b834 100644 --- a/paddle/fluid/operators/prefetch_op.cc +++ b/paddle/fluid/operators/prefetch_op.cc @@ -41,12 +41,7 @@ class PrefetchOp : public framework::OperatorBase { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& ctx = *pool.Get(place); - auto client_var_name = Output("RPCClient"); - PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name), - "Can not find variable '%s' in the scope.", - client_var_name); - auto* client_var = scope.FindVar(client_var_name); - detail::RPCClient* rpc_client = client_var->GetMutable(); + auto rpc_client = detail::RPCClient::GetInstance(); for (size_t i = 0; i < ins.size(); i++) { if (NeedSend(scope, ins[i])) { @@ -66,9 +61,6 @@ class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() { AddInput("X", "(LoDTensor) Input Id variables to be sent").AsDuplicable(); - AddOutput("RPCClient", - "(RPCClient) The RPC client object which will be" - "initialized at most once."); AddOutput("Out", "(LoDTensor) result " "to be fetched from parameter server") @@ -87,17 +79,6 @@ the parameter server and fetch result back. 
} }; -class PrefetchOpVarTypeInference : public framework::VarTypeInference { - public: - void operator()(const framework::OpDesc& op_desc, - framework::BlockDesc* block) const override { - auto out_var_name = op_desc.Output("RPCClient").front(); - auto& out_var = block->FindRecursiveOrCreateVar(out_var_name); - auto var_type = framework::proto::VarType::RAW; - out_var.SetType(var_type); - } -}; - class PrefetchOpShapeInference : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext* ctx) const override {} @@ -110,5 +91,4 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(prefetch, ops::PrefetchOp, paddle::framework::EmptyGradOpMaker, ops::PrefetchOpMaker, - ops::PrefetchOpVarTypeInference, ops::PrefetchOpShapeInference); diff --git a/paddle/fluid/operators/recv_op.cc b/paddle/fluid/operators/recv_op.cc index 7148bd0e363a71b58581a6c3c5f245d98d5b9d02..d8ddb7b448910b5e0e6e71742eb2fdc6a225c919 100644 --- a/paddle/fluid/operators/recv_op.cc +++ b/paddle/fluid/operators/recv_op.cc @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/grpc_client.h" +#include "paddle/fluid/platform/profiler.h" namespace paddle { namespace operators { @@ -36,19 +37,23 @@ class RecvOp : public framework::OperatorBase { const platform::Place& place) const override { auto outs = Outputs("Out"); std::vector epmap = Attr>("epmap"); + int sync_mode = Attr("sync_mode"); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& ctx = *pool.Get(place); + // For profiling + platform::RecordEvent record_event(Type(), &ctx); + + auto rpc_client = detail::RPCClient::GetInstance(); for (size_t i = 0; i < outs.size(); i++) { - VLOG(3) << "getting " << outs[i]; - client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]); + VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; + rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]); + } + if (sync_mode) { + PADDLE_ENFORCE(rpc_client->Wait()); } - PADDLE_ENFORCE(client_.Wait()); } - - private: - mutable detail::RPCClient client_; }; class RecvOpMaker : public framework::OpProtoAndCheckerMaker { @@ -65,6 +70,10 @@ This operator can get variables from server side. "Server endpoints in the order of input " "variables for mapping") .SetDefault({}); + AddAttr("sync_mode", + "(int, default 0)" + "sync recv or async recv.") + .SetDefault(0); } }; diff --git a/paddle/fluid/operators/send_barrier_op.cc b/paddle/fluid/operators/send_barrier_op.cc index 1ce0907f3a9473e37f53bf7b2d42cddcb629dfa6..2c77ee2e2792d6fdd76bacd68b6c3b4a296b2e3a 100644 --- a/paddle/fluid/operators/send_barrier_op.cc +++ b/paddle/fluid/operators/send_barrier_op.cc @@ -21,6 +21,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/grpc_client.h" +#include "paddle/fluid/platform/profiler.h" namespace paddle { namespace operators { @@ -36,31 +37,30 @@ class SendBarrierOp : public framework::OperatorBase { void RunImpl(const framework::Scope& scope, const platform::Place& place) const override { std::vector eps = Attr>("endpoints"); + bool sync_mode = Attr("sync_mode"); - auto client_var_name = Output("RPCClient"); - PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name), - "Can not find variable '%s' in the scope.", - client_var_name); - auto* client_var = scope.FindVar(client_var_name); - detail::RPCClient* rpc_client = client_var->GetMutable(); + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto& ctx = *pool.Get(place); + // For profiling + platform::RecordEvent record_event(Type(), &ctx); + + auto rpc_client = detail::RPCClient::GetInstance(); // need to wait before sending send_barrier message PADDLE_ENFORCE(rpc_client->Wait()); - - for (auto& ep : eps) { - VLOG(3) << "send barrier, ep: " << ep; - rpc_client->AsyncSendBatchBarrier(ep); + if (sync_mode) { + for (auto& ep : eps) { + VLOG(3) << "send barrier, ep: " << ep; + rpc_client->AsyncSendBatchBarrier(ep); + } + PADDLE_ENFORCE(rpc_client->Wait()); } - PADDLE_ENFORCE(rpc_client->Wait()); } }; class SendBarrierOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() { - AddOutput("RPCClient", - "(RPCClient) The RPC client object which is" - "initialized at most once."); AddComment(R"DOC( SendBarrier operator @@ -72,17 +72,7 @@ the Parameter Server would knew all variables have been sent. "(string vector, default 127.0.0.1:6164)" "Server endpoints to send variables to.") .SetDefault({"127.0.0.1:6164"}); - } -}; - -class SendBarrierOpVarTypeInference : public framework::VarTypeInference { - public: - void operator()(const framework::OpDesc& op_desc, - framework::BlockDesc* block) const override { - auto out_var_name = op_desc.Output("RPCClient").front(); - auto& out_var = block->FindRecursiveOrCreateVar(out_var_name); - auto var_type = framework::proto::VarType::RAW; - out_var.SetType(var_type); + AddAttr("sync_mode", "work in sync_mode or not").SetDefault(true); } }; @@ -98,5 +88,4 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(send_barrier, ops::SendBarrierOp, paddle::framework::EmptyGradOpMaker, ops::SendBarrierOpMaker, - ops::SendBarrierOpVarTypeInference, ops::SendBarrierOpShapeInference); diff --git a/paddle/fluid/operators/send_op.cc b/paddle/fluid/operators/send_op.cc index 95bb1f3c695297e6d8134a647925310207118a9b..a5150f242ca3b0befafa2443f0bc466e2aea85e4 100644 --- a/paddle/fluid/operators/send_op.cc +++ b/paddle/fluid/operators/send_op.cc @@ -49,12 +49,7 @@ class SendOp : public framework::OperatorBase { // For profiling platform::RecordEvent record_event(Type(), &ctx); - auto client_var_name = Output("RPCClient"); - PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name), - "Can not find variable '%s' in the scope.", - client_var_name); - auto* client_var = scope.FindVar(client_var_name); - detail::RPCClient* rpc_client = client_var->GetMutable(); + auto rpc_client = detail::RPCClient::GetInstance(); for (size_t i = 0; i < ins.size(); i++) { if (NeedSend(scope, ins[i])) { @@ -96,9 +91,6 @@ class SendOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "(Tensor) Input tensor to be sent").AsDuplicable(); AddOutput("Out", "(Tensor) Output tensor to be received from server") .AsDuplicable(); - 
AddOutput("RPCClient", - "(RPCClient) The RPC client object which is" - "initialized at most once."); AddComment(R"DOC( Send operator @@ -119,17 +111,6 @@ This operator will send tensor to recv_op at the parameter server. } }; -class SendOpVarTypeInference : public framework::VarTypeInference { - public: - void operator()(const framework::OpDesc& op_desc, - framework::BlockDesc* block) const override { - auto out_var_name = op_desc.Output("RPCClient").front(); - auto& out_var = block->FindRecursiveOrCreateVar(out_var_name); - auto var_type = framework::proto::VarType::RAW; - out_var.SetType(var_type); - } -}; - class SendOpShapeInference : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext* ctx) const override {} @@ -141,5 +122,4 @@ class SendOpShapeInference : public framework::InferShapeBase { namespace ops = paddle::operators; REGISTER_OPERATOR(send, ops::SendOp, paddle::framework::EmptyGradOpMaker, - ops::SendOpMaker, ops::SendOpVarTypeInference, - ops::SendOpShapeInference); + ops::SendOpMaker, ops::SendOpShapeInference); diff --git a/paddle/fluid/operators/send_recv_op_test.cc b/paddle/fluid/operators/send_recv_op_test.cc index d5303eaf50722234d205264e56892b1723104d53..e550552b195b768d68ec64e9c3b5889b56ca719f 100644 --- a/paddle/fluid/operators/send_recv_op_test.cc +++ b/paddle/fluid/operators/send_recv_op_test.cc @@ -156,6 +156,7 @@ TEST(SendRecvOp, CPUDense) { std::thread server_thread(StartServerNet, false, &initialized); while (!initialized) { } + static_cast(listen_and_serv_op.get()) ->WaitServerReady(); @@ -175,9 +176,10 @@ TEST(SendRecvOp, CPUDense) { std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port); attrs.insert({"endpoints", std::vector({endpoint})}); attrs.insert({"epmap", std::vector({endpoint})}); - auto send_op = f::OpRegistry::CreateOp( - "send", {{"X", {"x1"}}}, - {{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs); + const f::VariableNameMap &inputs = {{"X", {"x1"}}}; + const f::VariableNameMap &outputs = {{"Out", {"Out"}}}; + + auto send_op = f::OpRegistry::CreateOp("send", inputs, outputs, attrs); send_op->Run(scope, place); auto in_var = scope.Var("x1"); @@ -220,9 +222,8 @@ TEST(SendRecvOp, CPUSparse) { std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port); attrs.insert({"endpoints", std::vector({endpoint})}); attrs.insert({"epmap", std::vector({endpoint})}); - auto send_op = f::OpRegistry::CreateOp( - "send", {{"X", {"x1"}}}, - {{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs); + auto send_op = f::OpRegistry::CreateOp("send", {{"X", {"x1"}}}, + {{"Out", {"Out"}}}, attrs); send_op->Run(scope, place); auto x0 = scope.Var("x0")->GetMutable(); diff --git a/paddle/fluid/operators/send_recv_util.h b/paddle/fluid/operators/send_recv_util.h index 113513eb6b327773ab4a1c062fb8a3f06fddfbca..deab005149027caffa962783df944fad7110382f 100644 --- a/paddle/fluid/operators/send_recv_util.h +++ b/paddle/fluid/operators/send_recv_util.h @@ -20,6 +20,9 @@ namespace operators { inline bool NeedSend(const framework::Scope& scope, const std::string& varname) { + // dummy variable is only used in parallel executor to represent + // some dependency relationship, we don't need to send/recv it. 
+ if (varname == "dummy") return false; auto* var = scope.FindVar(varname); PADDLE_ENFORCE_NOT_NULL(var, "Can not find variable '%s' in the send side.", varname); diff --git a/paddle/fluid/operators/send_vars_op.cc b/paddle/fluid/operators/send_vars_op.cc index f11e84c176ae97dff0fda560ce3ebe2ab72c7bcc..fe839dab6924618c8a4c39868d9bf86056a0be40 100644 --- a/paddle/fluid/operators/send_vars_op.cc +++ b/paddle/fluid/operators/send_vars_op.cc @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/send_recv_util.h" +#include "paddle/fluid/platform/profiler.h" namespace paddle { namespace operators { @@ -41,12 +42,10 @@ class SendVarsOp : public framework::OperatorBase { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& ctx = *pool.Get(place); - auto client_var_name = Output("RPCClient"); - PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name), - "Can not find variable '%s' in the scope.", - client_var_name); - auto* client_var = scope.FindVar(client_var_name); - detail::RPCClient* rpc_client = client_var->GetMutable(); + // For profiling + platform::RecordEvent record_event(Type(), &ctx); + + auto rpc_client = detail::RPCClient::GetInstance(); for (size_t i = 0; i < ins.size(); i++) { if (NeedSend(scope, ins[i])) { @@ -69,9 +68,6 @@ class SendVarsOpMaker : public framework::OpProtoAndCheckerMaker { void Make() { AddInput("X", "(Tensor, SelectedRows) Input variables to be sent") .AsDuplicable(); - AddOutput("RPCClient", - "(RPCClient) The RPC client object which will be" - "initialized at most once."); AddComment(R"DOC( Send operator @@ -89,17 +85,6 @@ This operator will send variables to listen_and_serve op at the parameter server } }; -class SendVarsOpVarTypeInference : public framework::VarTypeInference { - public: - void operator()(const framework::OpDesc& op_desc, - framework::BlockDesc* block) const override { - auto out_var_name = op_desc.Output("RPCClient").front(); - auto& out_var = block->FindRecursiveOrCreateVar(out_var_name); - auto var_type = framework::proto::VarType::RAW; - out_var.SetType(var_type); - } -}; - class SendVarsOpShapeInference : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext* ctx) const override {} @@ -112,5 +97,4 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(send_vars, ops::SendVarsOp, paddle::framework::EmptyGradOpMaker, ops::SendVarsOpMaker, - ops::SendVarsOpVarTypeInference, ops::SendVarsOpShapeInference); diff --git a/paddle/fluid/pybind/const_value.cc b/paddle/fluid/pybind/const_value.cc index 9111abca5aac97e9d5c7b00ce5173f08e49cda12..76aa7d2010682416f68e982e9b89da9813abb078 100644 --- a/paddle/fluid/pybind/const_value.cc +++ b/paddle/fluid/pybind/const_value.cc @@ -32,7 +32,8 @@ void BindConstValue(pybind11::module* m) { .value("Forward", framework::OpRole::kForward) .value("Backward", framework::OpRole::kBackward) .value("Optimize", framework::OpRole::kOptimize) - .value("Loss", framework::OpRole::kLoss); + .value("Loss", framework::OpRole::kLoss) + .value("RPC", framework::OpRole::kRPC); op_proto_and_checker_maker.def( "kOpRoleAttrName", framework::OpProtoAndCheckerMaker::OpRoleAttrName); diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index 03d4602f7a99dc335260cffdcdc30a839f3988cd..8758ac9f94ab91b5be5fc70917c64db38997d1c1 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ 
-195,21 +195,23 @@ def Send(endpoints, send_vars, get_vars=None): endpoints = list(set(epmap)) helper = LayerHelper("Send", **locals()) - rpc_client_var = default_main_program().global_block().create_var( - name="RPC_CLIENT_VAR", persistable=True, type=core.VarDesc.VarType.RAW) if not get_vars: get_vars = [] for s in send_vars: v = helper.create_tmp_variable(dtype=s.dtype, stop_gradient=True) get_vars.append(v) + rpc_op_role_name = core.op_proto_and_checker_maker.kOpRoleAttrName() helper.append_op( type="send", inputs={"X": send_vars}, - outputs={"Out": get_vars, - "RPCClient": rpc_client_var}, - attrs={"endpoints": endpoints, - "epmap": epmap}) + outputs={"Out": get_vars}, + attrs={ + "endpoints": endpoints, + "epmap": epmap, + rpc_op_role_name: core.op_proto_and_checker_maker.OpRole.RPC + }) + return get_vars diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py index 10f8c4f3f0167632bb4a3d454ab026ba73a8f305..fa49bd41a5876847d046682dce5c3d3868a18500 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -49,7 +49,6 @@ class TestDistTranspiler(unittest.TestCase): def test_transpiler(self): trainer = self.get_trainer() pserver, startup = self.get_pserver(self.current_pserver_ep) - self.assertEqual([op.type for op in trainer.global_block().ops], self.get_expect_trainer_ops()) @@ -67,7 +66,7 @@ class TestDistTranspiler(unittest.TestCase): "fill_constant", "fill_constant", "uniform_random", "uniform_random" ]) - # the variable #fc_w will be split into two blocks + # the variable #fc_w will be split into two blocks fc_w_var = startup.global_block().var("fc_w.block1") self.assertEqual(fc_w_var.shape, (500, 1000)) @@ -86,8 +85,12 @@ class TestDistTranspiler(unittest.TestCase): optimize_ops, params_grads = self.net_conf() delete_ops(trainer.global_block(), optimize_ops) - return [op.type for op in trainer.global_block().ops - ] + ["split_byref", "send", "concat"] + ops = [op.type for op in trainer.global_block().ops] + [ + "split_byref", "send_vars", "send_barrier", "recv", "recv", + "fetch_barrier", "concat" + ] + ops.insert(ops.index("elementwise_add_grad") + 1, "send_vars") + return ops def get_trainer(self): return self._transpiler_instance().get_trainer_program() diff --git a/python/paddle/fluid/transpiler/__init__.py b/python/paddle/fluid/transpiler/__init__.py index 413c36c5c41bbe0169f1c050ccdac040202d66df..045ca537b2e84c02298d6375a7ef5bdbb5517380 100644 --- a/python/paddle/fluid/transpiler/__init__.py +++ b/python/paddle/fluid/transpiler/__init__.py @@ -16,8 +16,9 @@ from distribute_transpiler import DistributeTranspiler from inference_transpiler import InferenceTranspiler from memory_optimization_transpiler import memory_optimize, release_memory from distribute_transpiler_simple import SimpleDistributeTranspiler +from ps_dispatcher import HashName, RoundRobin __all__ = [ "DistributeTranspiler", "InferenceTranspiler", "SimpleDistributeTranspiler", - "memory_optimize", "release_memory" + "memory_optimize", "release_memory", "HashName", "RoundRobin" ] diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 42ff0a9eb1112ed5709749e3867794c80be8f1d1..4e17fdb16b6c2eb9846fd27ccde36e532d600a7e 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -16,7 +16,7 @@ from __future__ 
import print_function import math -import distributed_splitter as splitter +from ps_dispatcher import RoundRobin, HashName, PSDispatcher from .. import core, framework from ..framework import Program, default_main_program, \ default_startup_program, \ @@ -24,7 +24,9 @@ from ..framework import Program, default_main_program, \ LOOKUP_TABLE_TYPE = "lookup_table" LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad" -RPC_CLIENT_VAR_NAME = "RPC_CLIENT_VAR" +RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName( +) +RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC class VarBlock: @@ -149,13 +151,27 @@ def delete_ops(block, ops): block.program.sync_with_cpp() +def find_op_by_input_arg(block, arg_name): + for index, op in enumerate(block.ops): + if arg_name in op.input_arg_names: + return index + return -1 + + +def find_op_by_output_arg(block, arg_name): + for index, op in enumerate(block.ops): + if arg_name in op.output_arg_names: + return index + return -1 + + class DistributeTranspiler: def transpile(self, trainer_id, program=None, pservers="127.0.0.1:6174", trainers=1, - split_method=splitter.round_robin, + split_method=RoundRobin, sync_mode=True): """ Transpile the program to distributed data-parallelism programs. @@ -196,7 +212,7 @@ class DistributeTranspiler: :param sync_mode: if sync_mode is set True, it means that dist transpiler will transpile the program into sync_mode pserver and trainer program. """ - assert (callable(split_method)) + assert (split_method.__bases__[0] == PSDispatcher) if program is None: program = default_main_program() self.origin_program = program @@ -209,6 +225,7 @@ class DistributeTranspiler: pserver_endpoints = pservers.split(",") self.pserver_endpoints = pserver_endpoints self.optimize_ops, params_grads = self._get_optimize_pass() + ps_dispatcher = split_method(pserver_endpoints) # process lookup_table_op # 1. check all lookup_table_op is distributed @@ -268,54 +285,110 @@ class DistributeTranspiler: grad_blocks = split_dense_variable(grad_list, len(pserver_endpoints)) param_blocks = split_dense_variable(param_list, len(pserver_endpoints)) + assert (len(grad_blocks) == len(param_blocks)) # step2: Create new vars for the parameters and gradients blocks and # add ops to do the split. - grad_var_mapping = self._append_split_op(program, grad_blocks) param_var_mapping = self._create_vars_from_blocklist(program, param_blocks) + grad_var_mapping = self._create_vars_from_blocklist( + program, grad_blocks, add_trainer_suffix=self.trainer_num > 1) + grad_param_mapping = dict() + for g, p in zip(grad_blocks, param_blocks): + g_name, g_bid, _ = g.split(":") + p_name, p_bid, _ = p.split(":") + grad_param_mapping[grad_var_mapping[g_name][int(g_bid)]] = \ + param_var_mapping[p_name][int(p_bid)] + + # step 3: transpile trainer side program, insert recv op and send op. - # step3: Add gradients as send op inputs and parameters as send - # op outputs. - send_inputs = [] - send_outputs = [] - for b in grad_blocks: # append by order - varname, block_id, _ = b.split(":") - send_inputs.append(grad_var_mapping[varname][int(block_id)]) - - for b in param_blocks: - varname, block_id, _ = b.split(":") - send_outputs.append(param_var_mapping[varname][int(block_id)]) - - # let send_op know which endpoint to send which var to, eplist has the same - # order as send_inputs. 
- eplist = split_method(send_inputs, pserver_endpoints) # create mapping of endpoint -> split var to create pserver side program self.param_grad_ep_mapping = dict() + [ + self.param_grad_ep_mapping.update({ + ep: { + "params": [], + "grads": [] + } + }) for ep in self.pserver_endpoints + ] + + # step 3.1: insert send op to send gradient vars to parameter servers + ps_dispatcher.reset() + send_vars = [] + for orig_varname, splited_vars in grad_var_mapping.items(): + eplist = ps_dispatcher.dispatch(splited_vars) + if len(splited_vars) == 1: + orig_varname = splited_vars[0].name + index = find_op_by_output_arg(program.global_block(), + orig_varname) + elif len(splited_vars) > 1: + orig_var = program.global_block().vars[orig_varname] + index = find_op_by_output_arg(program.global_block(), + orig_varname) + self._insert_split_op(program, orig_var, index, splited_vars) + index += 1 + else: + AssertionError("Can not insert the send op by original " + "variable name :", orig_varname) + + program.global_block().insert_op( + index=index + 1, + type="send_vars", + inputs={"X": splited_vars}, + outputs={}, + attrs={ + "epmap": eplist, + RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE + }) + for _, var in enumerate(splited_vars): + send_vars.append(var) + + if self.sync_mode: + program.global_block().append_op( + type="send_barrier", + inputs={}, + outputs={}, + attrs={ + "endpoints": pserver_endpoints, + "sync_mode": self.sync_mode, + RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE + }) + + # step 3.2: insert recv op to receive parameters from parameter server + recv_vars = [] + for _, var in enumerate(send_vars): + recv_vars.append(grad_param_mapping[var]) + ps_dispatcher.reset() + eplist = ps_dispatcher.dispatch(recv_vars) + for i, ep in enumerate(eplist): - param = send_outputs[i] - grad = send_inputs[i] - if not self.param_grad_ep_mapping.has_key(ep): - self.param_grad_ep_mapping[ep] = {"params": [], "grads": []} - self.param_grad_ep_mapping[ep]["params"].append(param) - self.param_grad_ep_mapping[ep]["grads"].append(grad) - - rpc_client_var = program.global_block().create_var( - name=RPC_CLIENT_VAR_NAME, - persistable=True, - type=core.VarDesc.VarType.RAW) - - # create send_op + self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i]) + self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i]) + # step4: Concat the parameters splits together after recv. + for varname, splited_var in param_var_mapping.iteritems(): + eps = [] + for var in splited_var: + index = [v.name for v in recv_vars].index(var.name) + eps.append(eplist[index]) + + program.global_block().append_op( + type="recv", + inputs={}, + outputs={"Out": splited_var}, + attrs={ + "epmap": eps, + RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE + }) + program.global_block().append_op( - type="send", - inputs={"X": send_inputs}, - outputs={"Out": send_outputs, - "RPCClient": rpc_client_var}, + type="fetch_barrier", + inputs={}, + outputs={}, attrs={ "endpoints": pserver_endpoints, - "epmap": eplist, - "sync_mode": self.sync_mode + RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE }) - # step4: Concat the parameters splits together after recv. 
+ for varname, splited_var in param_var_mapping.iteritems(): if len(splited_var) <= 1: continue @@ -327,10 +400,8 @@ class DistributeTranspiler: attrs={"axis": 0}) if self.has_distributed_lookup_table: - self._replace_lookup_table_op_with_prefetch(program, rpc_client_var, - eplist) - self._split_table_grad_and_add_send_vars(program, rpc_client_var, - pserver_endpoints) + self._replace_lookup_table_op_with_prefetch(program, eplist) + self._split_table_grad_and_add_send_vars(program, pserver_endpoints) def get_trainer_program(self): # remove optimize ops and add a send op to main_program @@ -550,8 +621,7 @@ class DistributeTranspiler: return s_prog # transpiler function for dis lookup_table - def _replace_lookup_table_op_with_prefetch(self, program, rpc_client_var, - eplist): + def _replace_lookup_table_op_with_prefetch(self, program, eplist): # 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op self.prefetch_input_vars = None self.prefetch_output_vars = None @@ -598,11 +668,11 @@ class DistributeTranspiler: index=op_index + 1, type="prefetch", inputs={'X': self.prefetch_input_vars}, - outputs={ - "Out": self.prefetch_output_vars, - "RPCClient": rpc_client_var - }, - attrs={"epmap": eplist}) + outputs={"Out": self.prefetch_output_vars}, + attrs={ + "epmap": eplist, + RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE + }) # insert concat_op program.global_block().insert_op( @@ -622,8 +692,7 @@ class DistributeTranspiler: # break for loop break - def _split_table_grad_and_add_send_vars(self, program, rpc_client_var, - pserver_endpoints): + def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints): # 2. add split_ids_op and send_vars_op to send gradient to pservers # there should only be one table_name all_ops = program.global_block().ops @@ -643,9 +712,12 @@ class DistributeTranspiler: index=op_index + 2, type="send_vars", inputs={'X': self.table_grad_list}, - outputs={"RPCClient": rpc_client_var}, - attrs={"sync_send": True, - "epmap": pserver_endpoints}) + outputs={}, + attrs={ + "sync_send": True, + "epmap": pserver_endpoints, + RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE + }) break def _create_prefetch_block(self, pserver_index, pserver_program, @@ -838,50 +910,31 @@ class DistributeTranspiler: lod_level=var.lod_level, persistable=persistable) - def _append_split_op(self, program, gradblocks): - """ - Split variables that need to be split and append respective ops - Args: - program (ProgramDesc): ProgramDesc that gradients blong. - gradblocks (list[(varname, block_id, block_size)]): List of gradient blocks. - Returns: - var_mapping (dict(varname->[new_splitted_variable])):A dict mapping - from original var name to each var split. 
- """ - - add_suffix = False - if self.trainer_num > 1: - add_suffix = True - var_mapping = self._create_vars_from_blocklist( - program, gradblocks, add_trainer_suffix=add_suffix) - for varname, splited_vars in var_mapping.iteritems(): - # variable that don't need to split have empty splited_vars - if len(splited_vars) <= 1: - continue - orig_var = program.global_block().vars[varname] - if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS: - height_sections = [] - for v in splited_vars: - height_sections.append(v.shape[0]) - program.global_block().append_op( - type="split_selected_rows", - inputs={"X": orig_var}, - outputs={"Out": splited_vars}, - attrs={"height_sections": height_sections}) - elif orig_var.type == core.VarDesc.VarType.LOD_TENSOR: - sections = [] - for v in splited_vars: - sections.append(v.shape[0]) - program.global_block().append_op( - type="split_byref", - inputs={"X": orig_var}, - outputs={"Out": splited_vars}, - attrs={"sections": sections} # assume split evenly - ) - else: - AssertionError("Variable type should be in set " - "[LOD_TENSOR, SELECTED_ROWS]") - return var_mapping + def _insert_split_op(self, program, orig_var, index, splited_vars): + if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS: + height_sections = [] + for v in splited_vars: + height_sections.append(v.shape[0]) + program.global_block().insert_op( + index=index + 1, + type="split_selected_rows", + inputs={"X": orig_var}, + outputs={"Out": splited_vars}, + attrs={"height_sections": height_sections}) + elif orig_var.type == core.VarDesc.VarType.LOD_TENSOR: + sections = [] + for v in splited_vars: + sections.append(v.shape[0]) + program.global_block().insert_op( + index=index + 1, + type="split_byref", + inputs={"X": orig_var}, + outputs={"Out": splited_vars}, + attrs={"sections": sections} # assume split evenly + ) + else: + AssertionError("Variable type should be in set " + "[LOD_TENSOR, SELECTED_ROWS]") def _get_optimizer_input_shape(self, op_type, varkey, orig_shape, param_shape): diff --git a/python/paddle/fluid/transpiler/distributed_splitter.py b/python/paddle/fluid/transpiler/distributed_splitter.py deleted file mode 100644 index 060c1df8ad2badc5132f45ff0f44d136d828faa1..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/transpiler/distributed_splitter.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -def hash_name(varlist, pserver_endpoints): - """ - hash variable names to several endpoints. - - Args: - varlist(list): a list of Variables - - Returns(dict): a map of pserver endpoint -> varname - """ - - def _hash_block(block_str, total): - return hash(block_str) % total - - eplist = [] - for var in varlist: - server_id = _hash_block(var.name(), len(pserver_endpoints)) - server_for_param = pserver_endpoints[server_id] - eplist.append(server_for_param) - return eplist - - -def round_robin(varlist, pserver_endpoints): - """ - Distribute variables to several endpoints. 
- Args: - varlist(list): a list of variables - pserver_endpoints(list): a list of pserver endpoints - - Returns(list[int]): the endpoint for each variable - """ - assert (len(varlist) >= len(pserver_endpoints)) - - eplist = [] - pserver_idx = 0 - for var in varlist: - server_for_param = pserver_endpoints[pserver_idx] - eplist.append(server_for_param) - - pserver_idx += 1 - if pserver_idx >= len(pserver_endpoints): - pserver_idx = 0 - return eplist diff --git a/python/paddle/fluid/transpiler/ps_dispatcher.py b/python/paddle/fluid/transpiler/ps_dispatcher.py new file mode 100644 index 0000000000000000000000000000000000000000..d6a68677527deb09ace0e3a23cbc093d6d7b4349 --- /dev/null +++ b/python/paddle/fluid/transpiler/ps_dispatcher.py @@ -0,0 +1,78 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class PSDispatcher(object): + """ + PSDispatcher is the base class for dispatching vars + into different pserver instance. + You need to implement the `dispatch` inferface. + """ + + def __init__(self, pserver_endpoints): + self._eps = pserver_endpoints + self._step = 0 + + @property + def eps(self): + return self._eps + + def reset(self): + self._step = 0 + + def dispatch(self, varlist): + """ + :param varlist: a list of Variables + :return: a map of pserver endpoint -> varname + """ + AssertionError("Interface has not been implemented.") + + +class HashName(PSDispatcher): + """ + Hash variable names to several endpoints + """ + + def __init__(self, pserver_endpoints): + super(self.__class__, self).__init__(pserver_endpoints) + + def _hash_block(self, block_str, total): + return hash(block_str) % total + + def dispatch(self, varlist): + eplist = [] + for var in varlist: + server_id = self._hash_block(var.name(), len(self._eps)) + server_for_param = self._eps[server_id] + eplist.append(server_for_param) + return eplist + + +class RoundRobin(PSDispatcher): + """ + Distribute variables to serveral endpoints. + """ + + def __init__(self, pserver_endpoints): + super(self.__class__, self).__init__(pserver_endpoints) + + def dispatch(self, varlist): + eplist = [] + for var in varlist: + server_for_param = self._eps[self._step] + eplist.append(server_for_param) + self._step += 1 + if self._step >= len(self._eps): + self._step = 0 + return eplist
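The new ps_dispatcher.py replaces the free functions of distributed_splitter.py with stateful dispatcher classes, and the transpiler calls reset() between dispatching the split gradients (send side) and the matching parameters (recv side) so both passes walk the endpoint list in the same order. For comparison with the C++ half of this change, here is a hedged standalone C++ sketch of the two policies; the class names and the use of std::hash are illustrative only and are not part of the framework.

#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Hash a variable name to a fixed endpoint (the HashName policy).
std::vector<std::string> HashDispatch(const std::vector<std::string>& vars,
                                      const std::vector<std::string>& eps) {
  std::vector<std::string> out;
  for (const auto& v : vars)
    out.push_back(eps[std::hash<std::string>{}(v) % eps.size()]);
  return out;
}

// Walk the endpoint list in order (the RoundRobin policy); Reset() lets a
// second pass reuse the same order, e.g. for the matching recv variables.
class RoundRobin {
 public:
  explicit RoundRobin(std::vector<std::string> eps) : eps_(std::move(eps)) {}
  void Reset() { step_ = 0; }
  std::vector<std::string> Dispatch(const std::vector<std::string>& vars) {
    std::vector<std::string> out;
    for (size_t i = 0; i < vars.size(); ++i) {
      out.push_back(eps_[step_]);
      step_ = (step_ + 1) % eps_.size();
    }
    return out;
  }

 private:
  std::vector<std::string> eps_;
  size_t step_ = 0;
};

int main() {
  std::vector<std::string> eps = {"127.0.0.1:6174", "127.0.0.1:6175"};
  RoundRobin rr(eps);
  for (const auto& ep : rr.Dispatch({"fc_w.block0", "fc_w.block1", "fc_b"}))
    std::cout << ep << "\n";
  rr.Reset();  // a second pass would now see the endpoints in the same order
  for (const auto& ep : HashDispatch({"fc_w.block0", "fc_b"}, eps))
    std::cout << ep << "\n";
  return 0;
}

Hash dispatch pins a given variable name to the same endpoint across runs, while round-robin balances the variable count per pserver; the Reset() between the two dispatch passes is what keeps the round-robin assignment consistent for related variable lists.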