diff --git a/paddle/fluid/operators/distributed/grpc/grpc_client.cc b/paddle/fluid/operators/distributed/grpc/grpc_client.cc index de61400fdf635135dc6df65524632f0f74c817f9..32b6c0428cc5b63a047ac4e6038c23bb8ed17f1e 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc/grpc_client.cc @@ -438,6 +438,35 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep, return h; } +VarHandlePtr GRPCClient::AsyncDistributeNotify(const std::string& ep, + const std::string& type, + int64_t time_out) { + const auto ch = GetChannel(ep); + + DistributeNotifyProcessor* s = new DistributeNotifyProcessor(ch); + + const std::string method = kRequestNotify; + + VarHandlePtr h( + new VarHandle(ep, method, LEARNING_RATE_DECAY_MESSAGE, nullptr, nullptr)); + s->Prepare(h, time_out); + + sendrecv::VariableMessage req; + req.set_varname(type); + + platform::RecordRPCEvent record_event(method); + + auto rpc = s->stub_->AsyncDistributeNotify(s->context_.get(), req, &cq_); + rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); + req_count_++; + + if (UNLIKELY(platform::IsProfileEnabled())) { + h->Wait(); + } + + return h; +} + bool GRPCClient::Wait() { std::unique_lock lk(sync_mutex_); sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); }); diff --git a/paddle/fluid/operators/distributed/grpc/grpc_client.h b/paddle/fluid/operators/distributed/grpc/grpc_client.h index ad2f04a6d1dda34e35b67b21dce8ac612ff697a0..0f1ba6b1e4fb5266eb274a5446c33fad112d242c 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_client.h +++ b/paddle/fluid/operators/distributed/grpc/grpc_client.h @@ -173,6 +173,20 @@ class CheckpointNotifyProcessor : public BaseProcessor { std::unique_ptr stub_; }; +class DistributeNotifyProcessor : public BaseProcessor { + public: + explicit DistributeNotifyProcessor(std::shared_ptr ch) + : BaseProcessor() { + stub_ = sendrecv::SendRecvService::NewStub(ch); + } + + virtual ~DistributeNotifyProcessor() {} + + void ProcessImpl() override {} + sendrecv::VoidMessage reply_; + std::unique_ptr stub_; +}; + class GRPCClient : public RPCClient { public: GRPCClient() : ok_(true), completed_(false), stopped_(false) {} @@ -225,6 +239,10 @@ class GRPCClient : public RPCClient { const std::string& ep, const std::string& dir, int64_t time_out = FLAGS_rpc_deadline) override; + VarHandlePtr AsyncDistributeNotify( + const std::string& ep, const std::string& type, + int64_t time_out = FLAGS_rpc_deadline) override; + VarHandlePtr AsyncSendComplete( const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override; diff --git a/paddle/fluid/operators/distributed/grpc/grpc_server.cc b/paddle/fluid/operators/distributed/grpc/grpc_server.cc index 75526bed0f0eadada65279ec05757da7a469f984..a4ef70aab6647d4ab81fda187e656c05b87b53e8 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_server.cc +++ b/paddle/fluid/operators/distributed/grpc/grpc_server.cc @@ -393,6 +393,43 @@ class RequestCheckpointNotify final : public RequestBase { ServerAsyncResponseWriter responder_; }; +class RequestNotify final : public RequestBase { + public: + explicit RequestNotify(GrpcService::AsyncService* service, + ::grpc::ServerCompletionQueue* cq, + RequestHandler* request_handler, int req_id) + : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { + request_.reset(new GRPCVariableResponse(request_handler->scope(), + request_handler->dev_ctx())); + int method_id = static_cast(distributed::GrpcMethod::kRequestNotify); + service_->RequestAsyncUnary( + method_id, &ctx_, request_.get(), &responder_, cq_, cq_, + reinterpret_cast(static_cast(req_id))); + } + + virtual ~RequestNotify() {} + + std::string GetReqName() override { return request_->Varname(); } + + void Process() override { + auto scope = request_->GetMutableLocalScope(); + + std::string varname = request_->Varname(); + int trainer_id = request_->GetTrainerId(); + + VLOG(4) << "RequestNotify notify: " << varname + << ", trainer id: " << trainer_id; + + request_handler_->Handle(varname, scope, nullptr, nullptr, trainer_id); + Finish(reply_, &responder_); + } + + protected: + std::shared_ptr request_; + sendrecv::VoidMessage reply_; + ServerAsyncResponseWriter responder_; +}; + void AsyncGRPCServer::WaitServerReady() { VLOG(4) << "AsyncGRPCServer is waiting server ready"; std::unique_lock lock(this->mutex_ready_); @@ -526,6 +563,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name, b = new RequestPrefetch(&service_, cq.get(), handler, req_id); } else if (rpc_name == kRequestCheckpoint) { b = new RequestCheckpointNotify(&service_, cq.get(), handler, req_id); + } else if (rpc_name == kRequestNotify) { + b = new RequestNotify(&service_, cq.get(), handler, req_id); } else { PADDLE_ENFORCE(false, "not supported rpc"); } diff --git a/paddle/fluid/operators/distributed/grpc/grpc_service.h b/paddle/fluid/operators/distributed/grpc/grpc_service.h index 2965fe4490bedd0253682f0aef44e096232fc2fc..45152293896e86806fe87324416c2588796558ba 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_service.h +++ b/paddle/fluid/operators/distributed/grpc/grpc_service.h @@ -84,10 +84,11 @@ enum class GrpcMethod { kGetVariableNoBarrier, kGetMonomerVariable, kGetMonomerBarrier, + kRequestNotify, }; static const int kGrpcNumMethods = - static_cast(GrpcMethod::kGetMonomerBarrier) + 1; + static_cast(GrpcMethod::kRequestNotify) + 1; inline const char* GrpcMethodName(GrpcMethod id) { switch (id) { @@ -105,6 +106,8 @@ inline const char* GrpcMethodName(GrpcMethod id) { return "/sendrecv.SendRecvService/PrefetchVariable"; case GrpcMethod::kCheckpointNotify: return "/sendrecv.SendRecvService/CheckpointNotify"; + case GrpcMethod::kRequestNotify: + return "/sendrecv.SendRecvService/DistributeNotify"; } // Shouldn't be reached. diff --git a/paddle/fluid/operators/distributed/request_handler.h b/paddle/fluid/operators/distributed/request_handler.h index 22083d92ed42f0e4f13768b0fa4d3254171c0d4d..8c0bf16497fb98ffad660e04615c5fcac8153c72 100644 --- a/paddle/fluid/operators/distributed/request_handler.h +++ b/paddle/fluid/operators/distributed/request_handler.h @@ -45,6 +45,7 @@ constexpr char kRequestPrefetch[] = "RequestPrefetch"; constexpr char kRequestCheckpoint[] = "RequestCheckpoint"; constexpr char kRequestPassBarrier[] = "RequestPassBarrier"; constexpr char kRequestGetNoBarrier[] = "GetVariableNoBarrier"; +constexpr char kRequestNotify[] = "RequestNotify"; constexpr char kSendRPC[] = "SendRPC"; constexpr char kGetRPC[] = "GetRPC"; @@ -62,6 +63,7 @@ constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC"; #define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV" #define COMPLETE_MESSAGE "COMPLETE@RECV" #define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV" +#define LEARNING_RATE_DECAY_MESSAGE "LRDECAY@RECV" #define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY" #define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY" @@ -188,6 +190,11 @@ class RequestHandler { sparse_grad_to_param_ = g; } + void SetLrDecayPreparedCtx( + std::shared_ptr g) { + lr_decay_prepared_ctx_ = g; + } + void SetRPCServer(RPCServer* rpc_server) { rpc_server_ = rpc_server; } // Get attributes. @@ -238,6 +245,8 @@ class RequestHandler { grad_to_prepared_ctx_; std::unordered_map* sparse_grad_to_param_; + // used for lr decay + std::shared_ptr lr_decay_prepared_ctx_; RPCServer* rpc_server_; }; diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc index ca150f70c74de3bb8ae5a1e8a09673b8c43558ad..96098a4e22632ca57c2efbee5374f0c8806bfbe7 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -251,6 +251,23 @@ bool RequestCheckpointHandler::Handle(const std::string& varname, return true; } +bool RequestNotifyHandler::Handle(const std::string& varname, + framework::Scope* scope, + framework::Variable* invar, + framework::Variable** outvar, + const int trainer_id, + const std::string& out_var_name, + const std::string& table_name) { + VLOG(4) << "RequestNotifyHandler" << varname; + if (varname == LEARNING_RATE_DECAY_MESSAGE) { + PADDLE_ENFORCE_NE( + lr_decay_block_id, -1, + "when lr_decay_block_id = -1, there should be no RPC invoke."); + executor_->RunPreparedContext(lr_decay_prepared_ctx_.get(), scope_); + } + return true; +} + } // namespace distributed } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/distributed/request_handler_impl.h b/paddle/fluid/operators/distributed/request_handler_impl.h index f3c1b24526b8b28033c0c979f74d44a3d7a94201..b13f0269ce6304de1e58c778a0800b38b462c73a 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.h +++ b/paddle/fluid/operators/distributed/request_handler_impl.h @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -126,6 +127,22 @@ class RequestCheckpointHandler final : public RequestHandler { int checkpoint_notify_id; }; +class RequestNotifyHandler final : public RequestHandler { + public: + explicit RequestNotifyHandler(bool sync_mode, int lr_decay_block_id) + : RequestHandler(sync_mode) { + this->lr_decay_block_id = lr_decay_block_id; + } + virtual ~RequestNotifyHandler() {} + bool Handle(const std::string& varname, framework::Scope* scope, + framework::Variable* var, framework::Variable** outvar, + const int trainer_id, const std::string& out_var_name = "", + const std::string& table_name = "") override; + + private: + int lr_decay_block_id; +}; + } // namespace distributed } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/distributed/rpc_client.h b/paddle/fluid/operators/distributed/rpc_client.h index d0b971e0cb1bde477fed9264b5ecee7b249a2c09..777829557424ba5f3dc0b0f538ee769432da51e7 100644 --- a/paddle/fluid/operators/distributed/rpc_client.h +++ b/paddle/fluid/operators/distributed/rpc_client.h @@ -80,6 +80,10 @@ class RPCClient { const std::string& ep, const std::string& dir, int64_t time_out = FLAGS_rpc_deadline) = 0; + virtual VarHandlePtr AsyncDistributeNotify( + const std::string& ep, const std::string& type, + int64_t time_out = FLAGS_rpc_deadline) = 0; + virtual VarHandlePtr AsyncSendComplete( const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0; diff --git a/paddle/fluid/operators/distributed/send_recv.proto.in b/paddle/fluid/operators/distributed/send_recv.proto.in index 6303667884361be050ac62c604274c87caa72444..0337b72181cf9f612fe56ae24bad39775bfcde28 100644 --- a/paddle/fluid/operators/distributed/send_recv.proto.in +++ b/paddle/fluid/operators/distributed/send_recv.proto.in @@ -28,6 +28,7 @@ service SendRecvService { rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {} rpc CheckpointNotify(VariableMessage) returns (VoidMessage) {} + rpc DistributeNotify(VariableMessage) returns (VoidMessage) {} rpc GetMonomerVariable(VariableMessage) returns (VariableMessage) {} rpc GetMonomerBarrier(VariableMessage) returns (VoidMessage) {} diff --git a/paddle/fluid/operators/distributed_ops/distributed_notify_op.cc b/paddle/fluid/operators/distributed_ops/distributed_notify_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..5e15b11655d346b360472d6f206bd1a46d709197 --- /dev/null +++ b/paddle/fluid/operators/distributed_ops/distributed_notify_op.cc @@ -0,0 +1,84 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include // NOLINT +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/distributed/distributed.h" +#include "paddle/fluid/operators/distributed_ops/send_recv_util.h" +#include "paddle/fluid/string/printf.h" + +namespace paddle { +namespace operators { + +class DistributedNotifyOp : public framework::OperatorBase { + public: + DistributedNotifyOp(const std::string& type, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + void RunImpl(const framework::Scope& scope, + const platform::Place& place) const override { + std::vector epmap = Attr>("epmap"); + std::string type = Attr("type"); + int trainer_id = Attr("trainer_id"); + + distributed::RPCClient* rpc_client = + distributed::RPCClient::GetInstance(trainer_id); + for (size_t i = 0; i < epmap.size(); i++) { + rpc_client->AsyncDistributeNotify(epmap[i], type); + VLOG(4) << "distribute notify sending : " << type << " to " << epmap[i]; + } + PADDLE_ENFORCE_EQ(rpc_client->Wait(), true, "internal error in RPCClient"); + } +}; + +class DistributedNotifyOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() { + AddAttr>("epmap", + "(string vector, default 127.0.0.1:6164)" + "Parameter Server endpoints in the order") + .SetDefault({"127.0.0.1:6164"}); + AddAttr("type", + "(string, default '') indicate the action type"); + AddAttr("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0); + AddComment(R"DOC( +DistributeNotify operator + +This operator will send a signal to listen_and_serve op at +the parameter server. +)DOC"); + } +}; + +class DistributedNotifyOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override {} +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(distributed_notify, ops::DistributedNotifyOp, + paddle::framework::EmptyGradOpMaker, + ops::DistributedNotifyOpMaker, + ops::DistributedNotifyOpShapeInference); diff --git a/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc index ba55d5c2f3df04334a0702919099b938c607cfcf..bd49fc0d8a59074965d9517e9bcce34250ad1698 100644 --- a/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc +++ b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc @@ -298,6 +298,7 @@ static void FillRequestCtx( std::unordered_map *sparse_grad_name_to_param_name, std::shared_ptr checkpoint_ctx, + std::shared_ptr lr_decay_ctx, distributed::RPCServer *rpc_server) { h->SetScope(scope); h->SetDevCtx(dev_ctx); @@ -307,6 +308,7 @@ static void FillRequestCtx( h->SetSparseGradToParam(sparse_grad_name_to_param_name); h->SetRPCServer(rpc_server); h->SetCheckpointNotifyPreparedCtx(checkpoint_ctx); + h->SetLrDecayPreparedCtx(lr_decay_ctx); } void ListenAndServOp::CacheVarsType(const std::vector &varnames, @@ -345,10 +347,12 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, PADDLE_ENFORCE(!rpc_service_); std::string endpoint = Attr("endpoint"); int checkpoint_block_id = Attr(kCheckpointBlockId); + int lr_decay_block_id = Attr(kLRDecayBlockId); VLOG(4) << "pserver_id: " << pserver_id << ", sync_mode:" << sync_mode << ", fan_in:" << fan_in << ", end_point:" << endpoint - << ", checkpoint_block_id: " << checkpoint_block_id; + << ", checkpoint_block_id: " << checkpoint_block_id + << ", lr_decay_block_id: " << lr_decay_block_id; rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in)); @@ -362,6 +366,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, sync_mode, checkpoint_block_id)); request_get_no_barrier_handler_.reset( new distributed::RequestGetNoBarrierHandler()); + request_notify_handler_.reset( + new distributed::RequestNotifyHandler(sync_mode, lr_decay_block_id)); rpc_service_->RegisterRPC(distributed::kRequestSend, request_send_handler_.get(), @@ -376,6 +382,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, request_checkpoint_handler_.get()); rpc_service_->RegisterRPC(distributed::kRequestGetNoBarrier, request_get_no_barrier_handler_.get()); + rpc_service_->RegisterRPC(distributed::kRequestNotify, + request_notify_handler_.get(), 1); auto optimize_blocks = Attr>(kOptimizeBlocks); @@ -391,6 +399,13 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ckpt_pre_context = std::move(ctx); } + std::shared_ptr lr_decay_context = nullptr; + if (lr_decay_block_id != -1) { + auto ctx = executor.Prepare(*program, lr_decay_block_id); + // see: https://stackoverflow.com/a/14856553 + lr_decay_context = std::move(ctx); + } + // prepare for prefetch std::vector prefetch_block_id_list; std::unordered_map block_id_to_prefetch_var_name; @@ -435,16 +450,18 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, sparse_grad_name_to_param_name[pieces[0]] = pieces[1]; } - auto f = std::bind( - FillRequestCtx, std::placeholders::_1, &recv_scope, &dev_ctx, &executor, - program, &prefetch_var_name_to_prepared_ctx, - &sparse_grad_name_to_param_name, ckpt_pre_context, rpc_service_.get()); + auto f = + std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope, &dev_ctx, + &executor, program, &prefetch_var_name_to_prepared_ctx, + &sparse_grad_name_to_param_name, ckpt_pre_context, + lr_decay_context, rpc_service_.get()); f(request_send_handler_.get()); f(request_get_handler_.get()); f(request_prefetch_handler_.get()); f(request_checkpoint_handler_.get()); f(request_get_no_barrier_handler_.get()); + f(request_notify_handler_.get()); // start the server listening after all member initialized. server_thread_.reset(new std::thread(RunServer, rpc_service_)); @@ -522,6 +539,8 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr(kCheckpointBlockId, "BolckID to run save checkpoint on pserer.") .SetDefault(-1); + AddAttr(kLRDecayBlockId, "BolckID to run lr decay on pserer.") + .SetDefault(-1); } }; diff --git a/paddle/fluid/operators/distributed_ops/listen_and_serv_op.h b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.h index 1cf2130d7a593077d1145b4f3be379c32557dd53..369743dfb2392c029bc3b671e519aefbbdd2b6b7 100644 --- a/paddle/fluid/operators/distributed_ops/listen_and_serv_op.h +++ b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.h @@ -37,6 +37,7 @@ namespace operators { constexpr char kOptimizeBlocks[] = "optimize_blocks"; constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id"; constexpr char kCheckpointBlockId[] = "checkpint_block_id"; +constexpr char kLRDecayBlockId[] = "lr_decay_block_id"; constexpr char kSparseGradToParam[] = "sparse_grad_to_param"; void RunServer(std::shared_ptr service); @@ -97,6 +98,7 @@ class ListenAndServOp : public framework::OperatorBase { request_prefetch_handler_; mutable std::shared_ptr request_checkpoint_handler_; + mutable std::shared_ptr request_notify_handler_; mutable std::shared_ptr server_thread_; mutable std::vector sparse_vars_; diff --git a/python/paddle/fluid/tests/unittests/dist_ctr.py b/python/paddle/fluid/tests/unittests/dist_ctr.py index fd09d47258fdfbf6d4a285df7d53c81f7489f39e..c5aae1eef180e51255f5b57a6c680155f3902cda 100644 --- a/python/paddle/fluid/tests/unittests/dist_ctr.py +++ b/python/paddle/fluid/tests/unittests/dist_ctr.py @@ -103,8 +103,16 @@ class TestDistCTR2x2(TestDistRunnerBase): if use_l2_decay: regularization = fluid.regularizer.L2DecayRegularizer( regularization_coeff=1e-1) - - sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.0001, + use_lr_decay = bool(os.getenv('LR_DECAY', 0)) + lr = 0.0001 + if use_lr_decay: + lr = fluid.layers.exponential_decay( + learning_rate=0.0001, + decay_steps=10000, + decay_rate=0.999, + staircase=True) + + sgd_optimizer = fluid.optimizer.SGD(learning_rate=lr, regularization=regularization) sgd_optimizer.minimize(avg_cost) diff --git a/python/paddle/fluid/tests/unittests/test_dist_ctr.py b/python/paddle/fluid/tests/unittests/test_dist_ctr.py index 3bdf28bf9acce740351d137fd95cc0d5902a9eb0..91947ded35330c7b35eb4560daa98c53653a13f4 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_ctr.py +++ b/python/paddle/fluid/tests/unittests/test_dist_ctr.py @@ -80,6 +80,28 @@ class TestDistCTR2x2_ASYNC(TestDistBase): log_name=flag_name) +class TestDistCTR2x2_ASYNCWithLRDecay2x2(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._hogwild_mode = True + self._enforce_place = "CPU" + + def test_dist_ctr(self): + need_envs = { + "FLAGS_communicator_send_queue_size": "2", + "FLAGS_communicator_max_merge_var_num": "2", + "FLAGS_communicator_max_send_grad_num_before_recv": "2", + "LR_DECAY": "1" + } + + self.check_with_place( + "dist_ctr.py", + delta=100, + check_error_log=True, + need_envs=need_envs, + log_name=flag_name) + + class TestDistCTR2x2_ASYNC2(TestDistBase): def _setup_config(self): self._sync_mode = False diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 73e0316c79af41fd4b35a664813acd9f7223fe7c..2f809c9f1c79fb52d6c7bdc82bef37d2a0398afc 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -818,6 +818,18 @@ class DistributeTranspiler(object): self._update_remote_sparse_update_op(program, need_sparse_update_params) + if not self.sync_mode: + lr_ops = self._get_lr_ops() + if len(lr_ops) > 0: + program.global_block().append_op( + type="distributed_notify", + inputs={}, + outputs={}, + attrs={ + "epmap": pserver_endpoints, + "trainer_id": self.trainer_id, + "type": "LRDECAY@RECV" + }) self._get_trainer_startup_program(recv_vars=recv_vars, eplist=eplist) @@ -1125,6 +1137,8 @@ class DistributeTranspiler(object): lr_ops = self._get_lr_ops() # record optimize blocks and we can run them on pserver parallel optimize_blocks = [] + + lr_decay_block_id = -1 if len(lr_ops) > 0: lr_decay_block = pserver_program._create_block( pserver_program.num_blocks - 1) @@ -1134,6 +1148,7 @@ class DistributeTranspiler(object): # append sub blocks to pserver_program in lr_decay_op __clone_lr_op_sub_block__(cloned_op, pserver_program, lr_decay_block) + lr_decay_block_id = lr_decay_block.idx # append op to the current block grad_to_block_id = [] @@ -1211,6 +1226,7 @@ class DistributeTranspiler(object): "sync_mode": self.sync_mode, "grad_to_block_id": grad_to_block_id, "sparse_grad_to_param": sparse_grad_to_param, + "lr_decay_block_id": lr_decay_block_id, } if self.has_distributed_lookup_table: