diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc index 44e7b16bef96619e10802481f5adf3a972429cba..1bdfec4c4cfbe4473037b841a5548208af2a85e8 100644 --- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc @@ -62,8 +62,16 @@ void ProcessGraph(std::vector graphs, Scope *scope) { node->Op()->GetNullableAttr("sections")); auto trainer_id = boost::get(node->Op()->GetNullableAttr("trainer_id")); + auto merge_add = + boost::get(node->Op()->GetNullableAttr("merge_add")); + if (!merge_add) { + merge_add = FLAGS_communicator_is_sgd_optimizer; + } + auto use_send_handler = + boost::get(node->Op()->GetNullableAttr("use_send_handler")); send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext( - send_var_name, send_varnames, epmap, height_section, trainer_id); + send_var_name, send_varnames, epmap, height_section, trainer_id, + merge_add, use_send_handler); VLOG(3) << "find and init an send op: " << send_varname_to_ctx[send_var_name]; } else if (node->Name() == "recv") { diff --git a/paddle/fluid/operators/distributed/communicator.cc b/paddle/fluid/operators/distributed/communicator.cc index 790c0cce4cfc83915f2d2a966d3af196012273c3..880290ee3947a005b0c681451251c6d6e29ca0d2 100644 --- a/paddle/fluid/operators/distributed/communicator.cc +++ b/paddle/fluid/operators/distributed/communicator.cc @@ -130,8 +130,15 @@ void AsyncCommunicator::InitImpl(const paddle::framework::ProgramDesc &program, auto height_section = boost::get>(op->GetNullableAttr("sections")); auto trainer_id = boost::get(op->GetNullableAttr("trainer_id")); + auto merge_add = boost::get(op->GetNullableAttr("merge_add")); + if (!merge_add) { + merge_add = FLAGS_communicator_is_sgd_optimizer; + } + auto use_send_handler = + boost::get(op->GetNullableAttr("use_send_handler")); send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext( - send_var_name, send_varnames, epmap, height_section, trainer_id); + send_var_name, send_varnames, epmap, height_section, trainer_id, + merge_add, use_send_handler); VLOG(3) << "find and init an send op: " << send_varname_to_ctx[send_var_name]; } else if (op->Type() == "recv") { @@ -208,12 +215,17 @@ void AsyncCommunicator::SendThread() { } } auto before_merge = GetCurrentUS(); - MergeVars(var_name, vars, send_scope_.get()); + auto &ctx = send_varname_to_ctx_.at(var_name); + if (ctx.use_send_handler) { + MergeVars(var_name, vars, send_scope_.get(), ctx.merge_add); + } else { + MergeVars(var_name, vars, send_scope_.get(), + ctx.merge_add); + } auto after_merge = GetCurrentUS(); VLOG(3) << "merge " << merged_var_num << " " << var_name << " use time " << after_merge - before_merge; auto send_functor = distributed::ParameterSend(); - auto &ctx = send_varname_to_ctx_.at(var_name); if (!FLAGS_communicator_fake_rpc) { send_functor(ctx, *send_scope_, true, 1); } diff --git a/paddle/fluid/operators/distributed/communicator.h b/paddle/fluid/operators/distributed/communicator.h index be61a0281cd42f5a0e1f0738701f4d9c30932972..eb702bec9066f2bb1446d5dba9682fbda1ebcd0e 100644 --- a/paddle/fluid/operators/distributed/communicator.h +++ b/paddle/fluid/operators/distributed/communicator.h @@ -107,21 +107,21 @@ template using EigenVector = framework::EigenVector; +template inline void MergeVars(const std::string& var_name, const std::vector>& vars, - Scope* scope) { + Scope* scope, bool merge_add = true) { PADDLE_ENFORCE(!vars.empty(), "should have value to merge!"); auto cpu_place = platform::CPUPlace(); auto& var0 = vars[0]; auto* out_var = scope->Var(var_name); if (var0->IsType()) { auto dims = var0->Get().dims(); - VLOG(3) << "merge " << var_name << " LoDTensor dims " << dims; - + VLOG(3) << "merge " << var_name << " LoDTensor dims " << dims + << "; merge add: " << merge_add; // init output tensor auto* out_t = out_var->GetMutable(); - out_t->mutable_data(dims, cpu_place); - + out_t->mutable_data(dims, cpu_place); // check the input dims for (auto& var : vars) { auto& var_t = var->Get(); @@ -130,44 +130,41 @@ inline void MergeVars(const std::string& var_name, // set output tensor to 0. auto cpu_ctx = paddle::platform::CPUDeviceContext(); - math::SetConstant - constant_functor; - constant_functor(cpu_ctx, out_t, static_cast(0)); - + math::SetConstant constant_functor; + constant_functor(cpu_ctx, out_t, static_cast(0)); // sum all vars to out - auto result = EigenVector::Flatten(*out_t); + auto result = EigenVector::Flatten(*out_t); for (auto& var : vars) { auto& in_t = var->Get(); - auto in = EigenVector::Flatten(in_t); + auto in = EigenVector::Flatten(in_t); result.device(*cpu_ctx.eigen_device()) = result + in; } - if (!FLAGS_communicator_is_sgd_optimizer) { + if (!merge_add) { result.device(*cpu_ctx.eigen_device()) = - result / static_cast(vars.size()); + result / static_cast(vars.size()); } } else if (var0->IsType()) { auto& slr0 = var0->Get(); auto* out_slr = out_var->GetMutable(); out_slr->mutable_rows()->clear(); - out_slr->mutable_value()->mutable_data({{}}, cpu_place); + out_slr->mutable_value()->mutable_data({{}}, cpu_place); std::vector inputs; inputs.reserve(vars.size()); for (auto& var : vars) { inputs.push_back(&var->Get()); } auto dev_ctx = paddle::platform::CPUDeviceContext(); - if (FLAGS_communicator_is_sgd_optimizer) { - math::scatter::MergeAdd - merge_add; + if (merge_add) { + math::scatter::MergeAdd merge_add; merge_add(dev_ctx, inputs, out_slr); } else { - math::scatter::MergeAverage + math::scatter::MergeAverage merge_average; merge_average(dev_ctx, inputs, out_slr); } VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height() - << " dims: " << slr0.value().dims(); + << " dims: " << slr0.value().dims() << "; merge add: " << merge_add; } else { PADDLE_THROW("unsupported var type!"); } diff --git a/paddle/fluid/operators/distributed/communicator_test.cc b/paddle/fluid/operators/distributed/communicator_test.cc index 5294ac33d15611a003eeb7971891e8ca85ec6a73..6ffd362e332bddf8b8c09498fc6cc86f56460daf 100644 --- a/paddle/fluid/operators/distributed/communicator_test.cc +++ b/paddle/fluid/operators/distributed/communicator_test.cc @@ -47,7 +47,7 @@ TEST(communicator, merge_lod_tensors) { scope.reset(new framework::Scope()); scope->Var(out_name); for (auto i = 0; i < 10; ++i) { - MergeVars(out_name, in_vars, scope.get()); + MergeVars(out_name, in_vars, scope.get()); } auto &out_tensor = scope->FindVar(out_name)->Get(); auto *out_data = out_tensor.data(); @@ -86,7 +86,7 @@ TEST(communicator, merge_selected_rows) { scope.reset(new framework::Scope()); scope->Var(out_name); for (auto i = 0; i < 10; ++i) { - MergeVars(out_name, in_vars, scope.get()); + MergeVars(out_name, in_vars, scope.get()); } auto &out_slr = scope->FindVar(out_name)->Get(); auto &out_t = out_slr.value(); diff --git a/paddle/fluid/operators/distributed/grpc/grpc_client.cc b/paddle/fluid/operators/distributed/grpc/grpc_client.cc index 32b6c0428cc5b63a047ac4e6038c23bb8ed17f1e..ad901273ad4a50fbeee65f9e1e77b4d49199401d 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc/grpc_client.cc @@ -438,26 +438,40 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep, return h; } -VarHandlePtr GRPCClient::AsyncDistributeNotify(const std::string& ep, - const std::string& type, - int64_t time_out) { - const auto ch = GetChannel(ep); - - DistributeNotifyProcessor* s = new DistributeNotifyProcessor(ch); - +VarHandlePtr GRPCClient::AsyncDistributeNotify( + const std::string& ep, const platform::DeviceContext& ctx, + const framework::Scope& scope, const std::string& var_name, + int64_t time_out) { + const platform::DeviceContext* p_ctx = &ctx; + const std::string ep_val = ep; + const std::string var_name_val = var_name; + const framework::Scope* p_scope = &scope; + const auto ch = GetChannel(ep_val); const std::string method = kRequestNotify; - VarHandlePtr h( - new VarHandle(ep, method, LEARNING_RATE_DECAY_MESSAGE, nullptr, nullptr)); + SendProcessor* s = new SendProcessor(ch); + VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope)); s->Prepare(h, time_out); - sendrecv::VariableMessage req; - req.set_varname(type); + framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] { + auto* var = p_scope->FindVar(var_name_val); - platform::RecordRPCEvent record_event(method); + ::grpc::ByteBuffer req; + SerializeToByteBuffer(var_name_val, var, *p_ctx, &req, "", trainer_id_); - auto rpc = s->stub_->AsyncDistributeNotify(s->context_.get(), req, &cq_); - rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); + VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; + + // stub context + s->response_call_back_ = nullptr; + + platform::RecordRPCEvent record_event(method); + + auto call = s->stub_g_.PrepareUnaryCall( + s->context_.get(), "/sendrecv.SendRecvService/DistributeNotify", req, + &cq_); + call->StartCall(); + call->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); + }); req_count_++; if (UNLIKELY(platform::IsProfileEnabled())) { diff --git a/paddle/fluid/operators/distributed/grpc/grpc_client.h b/paddle/fluid/operators/distributed/grpc/grpc_client.h index 0f1ba6b1e4fb5266eb274a5446c33fad112d242c..2e0599d885103b7cadaf0e93ef7828f1594dcc3e 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_client.h +++ b/paddle/fluid/operators/distributed/grpc/grpc_client.h @@ -173,20 +173,6 @@ class CheckpointNotifyProcessor : public BaseProcessor { std::unique_ptr stub_; }; -class DistributeNotifyProcessor : public BaseProcessor { - public: - explicit DistributeNotifyProcessor(std::shared_ptr ch) - : BaseProcessor() { - stub_ = sendrecv::SendRecvService::NewStub(ch); - } - - virtual ~DistributeNotifyProcessor() {} - - void ProcessImpl() override {} - sendrecv::VoidMessage reply_; - std::unique_ptr stub_; -}; - class GRPCClient : public RPCClient { public: GRPCClient() : ok_(true), completed_(false), stopped_(false) {} @@ -240,7 +226,8 @@ class GRPCClient : public RPCClient { int64_t time_out = FLAGS_rpc_deadline) override; VarHandlePtr AsyncDistributeNotify( - const std::string& ep, const std::string& type, + const std::string& ep, const platform::DeviceContext& ctx, + const framework::Scope& scope, const std::string& var_name, int64_t time_out = FLAGS_rpc_deadline) override; VarHandlePtr AsyncSendComplete( diff --git a/paddle/fluid/operators/distributed/grpc/grpc_server.cc b/paddle/fluid/operators/distributed/grpc/grpc_server.cc index a4ef70aab6647d4ab81fda187e656c05b87b53e8..e66ce3aa22c877f6e64ecc154fb1fa17d2e4317a 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_server.cc +++ b/paddle/fluid/operators/distributed/grpc/grpc_server.cc @@ -400,33 +400,31 @@ class RequestNotify final : public RequestBase { RequestHandler* request_handler, int req_id) : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { request_.reset(new GRPCVariableResponse(request_handler->scope(), - request_handler->dev_ctx())); + request_handler->dev_ctx(), + !request_handler->sync_mode())); int method_id = static_cast(distributed::GrpcMethod::kRequestNotify); service_->RequestAsyncUnary( method_id, &ctx_, request_.get(), &responder_, cq_, cq_, reinterpret_cast(static_cast(req_id))); } - virtual ~RequestNotify() {} - std::string GetReqName() override { return request_->Varname(); } void Process() override { - auto scope = request_->GetMutableLocalScope(); + std::string varname = GetReqName(); + VLOG(4) << "RequestNotify var_name:" << varname; - std::string varname = request_->Varname(); + auto scope = request_->GetMutableLocalScope(); + auto invar = request_->GetVar(); int trainer_id = request_->GetTrainerId(); - - VLOG(4) << "RequestNotify notify: " << varname - << ", trainer id: " << trainer_id; - - request_handler_->Handle(varname, scope, nullptr, nullptr, trainer_id); + framework::Variable* outvar = nullptr; + request_handler_->Handle(varname, scope, invar, &outvar, trainer_id); Finish(reply_, &responder_); } protected: - std::shared_ptr request_; sendrecv::VoidMessage reply_; + std::shared_ptr request_; ServerAsyncResponseWriter responder_; }; diff --git a/paddle/fluid/operators/distributed/parameter_send.cc b/paddle/fluid/operators/distributed/parameter_send.cc index 56362391a25d2e09b366399b496507776f60e67d..4fe88867a89264911093d71bc77449d76e3f5ed8 100644 --- a/paddle/fluid/operators/distributed/parameter_send.cc +++ b/paddle/fluid/operators/distributed/parameter_send.cc @@ -116,24 +116,44 @@ void ParameterSend::operator()(const RpcContext &rpc_ctx, row_offset += outs_dims[i][0]; } } - - for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) { - auto &send_var_name = rpc_ctx.splited_var_names[i]; - VLOG(4) << "send var name: " << send_var_name; - auto &endpoint = rpc_ctx.epmap[i]; - VLOG(4) << "send var endpoint: " << endpoint; - VLOG(4) << "need send: " << NeedSend(*local_scope.get(), send_var_name); - if (NeedSend(*local_scope.get(), send_var_name)) { - VLOG(3) << "sending " << send_var_name << " to " << endpoint; - rets.push_back(rpc_client->AsyncSendVar( - endpoint, cpu_ctx, *local_scope.get(), send_var_name)); - VLOG(4) << "send var " << send_var_name << " async handle done"; - } else { - VLOG(3) << "don't send non-initialized variable: " - << rpc_ctx.splited_var_names[i]; + if (rpc_ctx.use_send_handler) { + for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) { + auto &send_var_name = rpc_ctx.splited_var_names[i]; + VLOG(4) << "send var name: " << send_var_name; + auto &endpoint = rpc_ctx.epmap[i]; + VLOG(4) << "send var endpoint: " << endpoint; + VLOG(4) << "need send: " << NeedSend(*local_scope.get(), send_var_name); + if (NeedSend(*local_scope.get(), send_var_name)) { + VLOG(3) << "sending " << send_var_name << " to " << endpoint; + rets.push_back(rpc_client->AsyncSendVar( + endpoint, cpu_ctx, *local_scope.get(), send_var_name)); + VLOG(4) << "send var " << send_var_name << " async handle done"; + } else { + VLOG(3) << "don't send non-initialized variable: " + << rpc_ctx.splited_var_names[i]; + } + } + } else { + for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) { + for (size_t j = 0; j < rpc_ctx.epmap.size(); j++) { + auto &send_var_name = rpc_ctx.splited_var_names[i]; + VLOG(4) << "send var name: " << send_var_name; + auto &endpoint = rpc_ctx.epmap[j]; + VLOG(4) << "send var endpoint: " << endpoint; + VLOG(4) << "need send: " + << NeedSend(*local_scope.get(), send_var_name); + if (NeedSend(*local_scope.get(), send_var_name)) { + VLOG(3) << "sending " << send_var_name << " to " << endpoint; + rets.push_back(rpc_client->AsyncDistributeNotify( + endpoint, cpu_ctx, *local_scope.get(), send_var_name)); + VLOG(4) << "send var " << send_var_name << " async handle done"; + } else { + VLOG(3) << "don't send non-initialized variable: " + << rpc_ctx.splited_var_names[i]; + } + } } } - } else if (send_var->IsType()) { auto &send_slr = send_var->Get(); auto abs_sections = ToAbsoluteSection(rpc_ctx.height_sections); diff --git a/paddle/fluid/operators/distributed/request_handler.h b/paddle/fluid/operators/distributed/request_handler.h index 8c0bf16497fb98ffad660e04615c5fcac8153c72..d2cb50d444a1a88f43b265121c6761392800c30a 100644 --- a/paddle/fluid/operators/distributed/request_handler.h +++ b/paddle/fluid/operators/distributed/request_handler.h @@ -63,7 +63,7 @@ constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC"; #define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV" #define COMPLETE_MESSAGE "COMPLETE@RECV" #define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV" -#define LEARNING_RATE_DECAY_MESSAGE "LRDECAY@RECV" +#define LEARNING_RATE_DECAY_COUNTER "@LR_DECAY_COUNTER@" #define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY" #define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY" diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc index 9a7da5f8f924ba37e3bf77b673a71ebf23519275..d0d4f49f49ff50c16c8eae68e6a68225ac837421 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -262,11 +262,25 @@ bool RequestNotifyHandler::Handle(const std::string& varname, const int trainer_id, const std::string& out_var_name, const std::string& table_name) { - VLOG(4) << "RequestNotifyHandler" << varname; - if (varname == LEARNING_RATE_DECAY_MESSAGE) { + VLOG(4) << "RequestNotifyHandler: " << varname; + VLOG(3) << "async process var: " << varname << ", trainer_id: " << trainer_id; + + string::Piece decay_piece(LEARNING_RATE_DECAY_COUNTER); + string::Piece var_name_piece = string::Piece(varname); + if (string::Contains(var_name_piece, decay_piece)) { + VLOG(3) << "LearningRate Decay Counter Update"; PADDLE_ENFORCE_NE( lr_decay_block_id, -1, "when lr_decay_block_id = -1, there should be no RPC invoke."); + auto* origin_var = scope_->FindVar(varname); + auto origin_var_tensor = origin_var->Get(); + auto* send_var = scope->FindVar(varname); + auto send_var_tensor = send_var->Get(); + int64_t* origin_value = + origin_var_tensor.mutable_data(origin_var_tensor.place()); + int64_t* send_value = + send_var_tensor.mutable_data(send_var_tensor.place()); + origin_value[0] += send_value[0]; executor_->RunPreparedContext(lr_decay_prepared_ctx_.get(), scope_); } return true; diff --git a/paddle/fluid/operators/distributed/rpc_client.h b/paddle/fluid/operators/distributed/rpc_client.h index 777829557424ba5f3dc0b0f538ee769432da51e7..2071afcfd029ba208b7869de9748278f1dd4128d 100644 --- a/paddle/fluid/operators/distributed/rpc_client.h +++ b/paddle/fluid/operators/distributed/rpc_client.h @@ -81,7 +81,8 @@ class RPCClient { int64_t time_out = FLAGS_rpc_deadline) = 0; virtual VarHandlePtr AsyncDistributeNotify( - const std::string& ep, const std::string& type, + const std::string& ep, const platform::DeviceContext& ctx, + const framework::Scope& scope, const std::string& var_name, int64_t time_out = FLAGS_rpc_deadline) = 0; virtual VarHandlePtr AsyncSendComplete( diff --git a/paddle/fluid/operators/distributed/rpc_common.h b/paddle/fluid/operators/distributed/rpc_common.h index eb127bf4ad5a5c9a28210e2fbcdb69b07543f4b9..2f0cc61f2d855690b9228313fd471258d859244a 100644 --- a/paddle/fluid/operators/distributed/rpc_common.h +++ b/paddle/fluid/operators/distributed/rpc_common.h @@ -27,12 +27,15 @@ struct RpcContext { RpcContext(const std::string &name, const std::vector &names, const std::vector &emap, - const std::vector §ions, int id) + const std::vector §ions, int id, + bool merge_add_ = true, bool use_send_handler_ = true) : var_name(name), splited_var_names(names), epmap(emap), height_sections(sections), - trainer_id(id) {} + trainer_id(id), + merge_add(merge_add_), + use_send_handler(use_send_handler_) {} RpcContext(const RpcContext &ctx) { var_name = ctx.var_name; @@ -40,6 +43,8 @@ struct RpcContext { epmap = ctx.epmap; height_sections = ctx.height_sections; trainer_id = ctx.trainer_id; + merge_add = ctx.merge_add; + use_send_handler = ctx.use_send_handler; } std::string var_name; @@ -47,6 +52,8 @@ struct RpcContext { std::vector epmap; std::vector height_sections; int trainer_id; + bool merge_add; + bool use_send_handler; }; inline std::ostream &operator<<(std::ostream &os, const RpcContext &rpc_ctx) { @@ -70,6 +77,9 @@ inline std::ostream &operator<<(std::ostream &os, const RpcContext &rpc_ctx) { os << section << ", "; } os << "]\n"; + + os << "merge add: " << rpc_ctx.merge_add; + os << "; send handler: " << rpc_ctx.use_send_handler << "\n"; os << "}"; return os; } diff --git a/paddle/fluid/operators/distributed_ops/distributed_notify_op.cc b/paddle/fluid/operators/distributed_ops/distributed_notify_op.cc deleted file mode 100644 index 5e15b11655d346b360472d6f206bd1a46d709197..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/distributed_ops/distributed_notify_op.cc +++ /dev/null @@ -1,84 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include // NOLINT -#include - -#include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/distributed/distributed.h" -#include "paddle/fluid/operators/distributed_ops/send_recv_util.h" -#include "paddle/fluid/string/printf.h" - -namespace paddle { -namespace operators { - -class DistributedNotifyOp : public framework::OperatorBase { - public: - DistributedNotifyOp(const std::string& type, - const framework::VariableNameMap& inputs, - const framework::VariableNameMap& outputs, - const framework::AttributeMap& attrs) - : OperatorBase(type, inputs, outputs, attrs) {} - - void RunImpl(const framework::Scope& scope, - const platform::Place& place) const override { - std::vector epmap = Attr>("epmap"); - std::string type = Attr("type"); - int trainer_id = Attr("trainer_id"); - - distributed::RPCClient* rpc_client = - distributed::RPCClient::GetInstance(trainer_id); - for (size_t i = 0; i < epmap.size(); i++) { - rpc_client->AsyncDistributeNotify(epmap[i], type); - VLOG(4) << "distribute notify sending : " << type << " to " << epmap[i]; - } - PADDLE_ENFORCE_EQ(rpc_client->Wait(), true, "internal error in RPCClient"); - } -}; - -class DistributedNotifyOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() { - AddAttr>("epmap", - "(string vector, default 127.0.0.1:6164)" - "Parameter Server endpoints in the order") - .SetDefault({"127.0.0.1:6164"}); - AddAttr("type", - "(string, default '') indicate the action type"); - AddAttr("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0); - AddComment(R"DOC( -DistributeNotify operator - -This operator will send a signal to listen_and_serve op at -the parameter server. -)DOC"); - } -}; - -class DistributedNotifyOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext* ctx) const override {} -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OPERATOR(distributed_notify, ops::DistributedNotifyOp, - paddle::framework::EmptyGradOpMaker, - ops::DistributedNotifyOpMaker, - ops::DistributedNotifyOpShapeInference); diff --git a/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc index bd49fc0d8a59074965d9517e9bcce34250ad1698..a7476b9b3a38848b812ce132278c3f85fc931a26 100644 --- a/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc +++ b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc @@ -383,7 +383,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, rpc_service_->RegisterRPC(distributed::kRequestGetNoBarrier, request_get_no_barrier_handler_.get()); rpc_service_->RegisterRPC(distributed::kRequestNotify, - request_notify_handler_.get(), 1); + request_notify_handler_.get(), + FLAGS_rpc_send_thread_num); auto optimize_blocks = Attr>(kOptimizeBlocks); diff --git a/paddle/fluid/operators/distributed_ops/send_op.cc b/paddle/fluid/operators/distributed_ops/send_op.cc index 6fff3317f2293f3cb9208ea920175fa6ed82f8c8..de7af02afcad132d0675cbf0d680e9cfbfdfffd3 100644 --- a/paddle/fluid/operators/distributed_ops/send_op.cc +++ b/paddle/fluid/operators/distributed_ops/send_op.cc @@ -45,6 +45,7 @@ class SendOp : public framework::OperatorBase { auto send_varnames = Attr>("send_varnames"); auto height_sections = Attr>("sections"); + auto use_send_handler = Attr("use_send_handler"); if (send_varnames.size() > 0) { if (ins.size() > 1) { @@ -62,13 +63,27 @@ class SendOp : public framework::OperatorBase { distributed::RPCClient::GetInstance(trainer_id); std::vector rets; - for (size_t i = 0; i < ins.size(); i++) { - if (NeedSend(scope, ins[i])) { - VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; - rets.push_back( - rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i])); - } else { - VLOG(3) << "don't send no-initialied variable: " << ins[i]; + if (use_send_handler) { + for (size_t i = 0; i < ins.size(); i++) { + if (NeedSend(scope, ins[i])) { + VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; + rets.push_back( + rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i])); + } else { + VLOG(3) << "don't send no-initialied variable: " << ins[i]; + } + } + } else { + for (size_t i = 0; i < ins.size(); i++) { + for (size_t j = 0; j < epmap.size(); j++) { + if (NeedSend(scope, ins[i])) { + VLOG(3) << "sending " << ins[i] << " to " << epmap[j]; + rets.push_back(rpc_client->AsyncDistributeNotify(epmap[j], ctx, + scope, ins[i])); + } else { + VLOG(3) << "don't send no-initialied variable: " << ins[i]; + } + } } } for (size_t i = 0; i < rets.size(); i++) { @@ -113,6 +128,15 @@ This operator will send variables to listen_and_serve op at the parameter server "Number of sub-tensors. This must evenly divide " "Input.dims()[axis]") .SetDefault(0); + AddAttr("merge_add", + "(bool, default 0)" + "merge method, true represent add, false represent average") + .SetDefault(false); + AddAttr( + "use_send_handler", + "(bool, default 1)" + "if it's true, use send handler, other wise, use notify handler") + .SetDefault(true); } }; diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler_async_decay.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler_async_decay.py new file mode 100644 index 0000000000000000000000000000000000000000..761d57408b9a8f9e52419331bfb0bca5b0135c30 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler_async_decay.py @@ -0,0 +1,143 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import gc +import paddle.fluid as fluid + + +class TranspilerAsyncLRDecayTest(unittest.TestCase): + def setUp(self): + self.trainer_id = 0 + self.trainers = 2 + self.pservers = 2 + # NOTE: we do not actually bind this port + self.pserver_eps = "127.0.0.1:6174,127.0.0.1:6175" + self.pserver1_ep = "127.0.0.1:6174" + self.pserver2_ep = "127.0.0.1:6175" + self.sync_mode = False + self.transpiler = None + + def net_conf(self): + x = fluid.layers.data(name='x', shape=[1000], dtype='float32') + y_predict = fluid.layers.fc(input=x, + size=1000, + act=None, + param_attr=fluid.ParamAttr(name='fc_w'), + bias_attr=fluid.ParamAttr(name='fc_b')) + y = fluid.layers.data(name='y', shape=[1], dtype='float32') + cost = fluid.layers.square_error_cost(input=y_predict, label=y) + avg_cost = fluid.layers.mean(cost) + sgd_optimizer = fluid.optimizer.SGD( + learning_rate=fluid.layers.exponential_decay( + learning_rate=0.1, + decay_steps=100, + decay_rate=0.99, + staircase=True)) + sgd_optimizer.minimize(avg_cost) + + def get_main_program(self): + main = fluid.Program() + main.random_seed = 1 + with fluid.program_guard(main): + self.net_conf() + self.origin_prog = main.clone() + return main + + def get_trainer(self, config=None): + src = fluid.default_startup_program().clone() + + t = self._transpiler_instance(config) + + trainer_main = t.get_trainer_program(wait_port=False) + trainer_startup = fluid.default_startup_program() + + assert (src.num_blocks == 1) + assert (trainer_startup.num_blocks == src.num_blocks) + + return trainer_main, trainer_startup + + def get_pserver(self, ep, config=None, sync_mode=True): + t = self._transpiler_instance(config, sync_mode) + pserver = t.get_pserver_program(ep) + startup = t.get_startup_program(ep, pserver) + return pserver, startup + + def _transpiler_instance(self, config=None, sync_mode=True): + if not self.transpiler: + main = self.get_main_program() + self.transpiler = fluid.DistributeTranspiler(config=config) + self.transpiler.transpile( + self.trainer_id, + program=main, + pservers=self.pserver_eps, + trainers=self.trainers, + sync_mode=sync_mode) + + return self.transpiler + + def transpiler_test_impl(self): + pserver, startup = self.get_pserver(self.pserver1_ep, sync_mode=False) + pserver2, startup2 = self.get_pserver(self.pserver2_ep, sync_mode=False) + + trainer, trainer_startup = self.get_trainer() + + src = [op.type for op in trainer_startup.global_block().ops] + dst = ['fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', \ + 'uniform_random', 'recv', 'recv', 'fetch_barrier', 'concat'] + self.assertEqual(src, dst) + + self.assertEqual([op.type for op in trainer.global_block().ops], [ + 'mul', 'elementwise_add', 'elementwise_sub', 'square', 'mean', + 'fill_constant', 'mean_grad', 'square_grad', 'elementwise_sub_grad', + 'elementwise_add_grad', 'send', 'mul_grad', 'split_byref', 'send', + 'send', 'recv', 'recv', 'concat' + ]) + + self.assertEqual(len(pserver.blocks), 4) + # block0: listen_and_serv + self.assertEqual([op.type for op in pserver.blocks[0].ops], + ["listen_and_serv"]) + # block1: sum,cast,scale,floor,fill_constant,elementwise_pow,scale + self.assertEqual([op.type for op in pserver.blocks[1].ops], [ + "sum", "cast", "scale", "floor", "fill_constant", "elementwise_pow", + "scale" + ]) + + # block1~2: optimize pass + self.assertEqual([op.type for op in pserver.blocks[2].ops], ["sgd"]) + # confirm startup program + self.assertEqual([op.type for op in startup.global_block().ops], [ + "fill_constant", "fill_constant", "fill_constant", "fill_constant", + "uniform_random" + ]) + + def test_transpiler(self): + main = fluid.Program() + startup = fluid.Program() + with fluid.unique_name.guard(): + with fluid.program_guard(main, startup): + self.transpiler_test_impl() + # NOTE: run gc.collect to eliminate pybind side objects to + # prevent random double-deallocate when inherited in python. + del self.transpiler + del main + del startup + gc.collect() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index b608543b3e82d28826ba3cbb6ca863ffd74c2291..ddf09d8979b2a4632a7ff90663aa1d51391a4984 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -41,7 +41,7 @@ import logging import numpy as np from .ps_dispatcher import RoundRobin, PSDispatcher -from .. import core, framework, unique_name +from .. import core, framework, unique_name, initializer from ..framework import Program, default_main_program, \ default_startup_program, Block, Parameter, grad_var_name from .details import wait_server_ready, UnionFind, VarStruct, VarsDistributed @@ -304,6 +304,7 @@ class DistributeTranspiler(object): PRINT_LOG = True assert (self.config.min_block_size >= 8192) assert (self.config.split_method.__bases__[0] == PSDispatcher) + self.counter_var = None def _transpile_nccl2(self, trainer_id, @@ -631,6 +632,7 @@ class DistributeTranspiler(object): np.random.shuffle(grad_var_mapping_items) self.grad_name_to_send_dummy_out = dict() + for grad_varname, splited_vars in grad_var_mapping_items: eplist = ps_dispatcher.dispatch(splited_vars) @@ -720,6 +722,31 @@ class DistributeTranspiler(object): RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE }) fetch_barrier_input.append(send_barrier_out) + else: + lr_ops = self._get_lr_ops() + if len(lr_ops) > 0 and self.counter_var: + decay_dummy_output = program.global_block().create_var( + name=framework.generate_control_dev_var_name()) + if self.config.runtime_split_send_recv: + ## async mode, using communicator to merge and send + send_varnames = [self.counter_var.name] + else: + send_varnames = [] + sections = [] + program.global_block().append_op( + type="send", + inputs={"X": self.counter_var}, + outputs={"Out": decay_dummy_output}, + attrs={ + "epmap": pserver_endpoints, + "sections": sections, + "send_varnames": send_varnames, + "merge_add": True, + "use_send_handler": False, + RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE, + OP_ROLE_VAR_ATTR_NAME: + [self.counter_var.name, self.counter_var.name] + }) # step 3: insert recv op to receive parameters from parameter server recv_vars = [] @@ -821,19 +848,6 @@ class DistributeTranspiler(object): RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE }) - if not self.sync_mode: - lr_ops = self._get_lr_ops() - if len(lr_ops) > 0: - program.global_block().append_op( - type="distributed_notify", - inputs={}, - outputs={}, - attrs={ - "epmap": pserver_endpoints, - "trainer_id": self.trainer_id, - "type": "LRDECAY@RECV" - }) - self._get_trainer_startup_program(recv_vars=recv_vars, eplist=eplist) if self.has_distributed_lookup_table: @@ -2380,11 +2394,57 @@ class DistributeTranspiler(object): def _get_lr_ops(self): lr_ops = [] block = self.origin_program.global_block() - for op in block.ops: + for index, op in enumerate(block.ops): role_id = int(op.attr(RPC_OP_ROLE_ATTR_NAME)) if role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) or \ role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) | \ int(OPT_OP_ROLE_ATTR_VALUE): + if self.sync_mode == False and op.type == 'increment': + inputs = self._get_input_map_from_op( + self.origin_program.global_block().vars, op) + outputs = self._get_output_map_from_op( + self.origin_program.global_block().vars, op) + for key in outputs: + counter_var = outputs[key] + all_trainer_counter_inputs = [ + self.origin_program.global_block().create_var( + name="%s.trainer_%d" % (counter_var.name, id_), + type=counter_var.type, + shape=counter_var.shape, + dtype=counter_var.dtype, + persistable=counter_var.persistable) + for id_ in range(self.trainer_num) + ] + for i, op in enumerate(self.startup_program.global_block() + .ops): + if op.type == 'fill_constant': + for key in op.output_names: + if len(op.output(key)) == 1 and op.output(key)[ + 0] == counter_var.name: + self.startup_program.global_block().ops[ + i]._set_attr( + 'value', + float(0.0 - self.trainer_num)) + for var in all_trainer_counter_inputs: + if var.name == "%s.trainer_%d" % (counter_var.name, + self.trainer_id): + self.counter_var = var + self.startup_program.global_block().create_var( + name=var.name, + type=var.type, + dtype=var.dtype, + shape=var.shape, + persistable=var.persistable, + initializer=initializer.Constant(1)) + op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName( + ) + block._remove_op(index) + op = block._insert_op( + index, + type='sum', + inputs={'X': all_trainer_counter_inputs}, + outputs=outputs, + attrs={op_role_attr_name: LR_SCHED_OP_ROLE_ATTR_VALUE}) lr_ops.append(op) log("append lr op: ", op.type) return lr_ops