未验证 提交 b4a3b750 编写于 作者: 1 123malin 提交者: GitHub

bug fix: invalid learning rate decay in pserver async mode (#20325)

* bug fix: invalid learning rate decay in pserver async mode
上级 cadc6a97
...@@ -438,6 +438,35 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep, ...@@ -438,6 +438,35 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
return h; return h;
} }
VarHandlePtr GRPCClient::AsyncDistributeNotify(const std::string& ep,
const std::string& type,
int64_t time_out) {
const auto ch = GetChannel(ep);
DistributeNotifyProcessor* s = new DistributeNotifyProcessor(ch);
const std::string method = kRequestNotify;
VarHandlePtr h(
new VarHandle(ep, method, LEARNING_RATE_DECAY_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req;
req.set_varname(type);
platform::RecordRPCEvent record_event(method);
auto rpc = s->stub_->AsyncDistributeNotify(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
return h;
}
bool GRPCClient::Wait() { bool GRPCClient::Wait() {
std::unique_lock<std::mutex> lk(sync_mutex_); std::unique_lock<std::mutex> lk(sync_mutex_);
sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); }); sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); });
......
...@@ -173,6 +173,20 @@ class CheckpointNotifyProcessor : public BaseProcessor { ...@@ -173,6 +173,20 @@ class CheckpointNotifyProcessor : public BaseProcessor {
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_; std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
}; };
class DistributeNotifyProcessor : public BaseProcessor {
public:
explicit DistributeNotifyProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor() {
stub_ = sendrecv::SendRecvService::NewStub(ch);
}
virtual ~DistributeNotifyProcessor() {}
void ProcessImpl() override {}
sendrecv::VoidMessage reply_;
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
};
class GRPCClient : public RPCClient { class GRPCClient : public RPCClient {
public: public:
GRPCClient() : ok_(true), completed_(false), stopped_(false) {} GRPCClient() : ok_(true), completed_(false), stopped_(false) {}
...@@ -225,6 +239,10 @@ class GRPCClient : public RPCClient { ...@@ -225,6 +239,10 @@ class GRPCClient : public RPCClient {
const std::string& ep, const std::string& dir, const std::string& ep, const std::string& dir,
int64_t time_out = FLAGS_rpc_deadline) override; int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncDistributeNotify(
const std::string& ep, const std::string& type,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncSendComplete( VarHandlePtr AsyncSendComplete(
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override; const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
......
...@@ -393,6 +393,43 @@ class RequestCheckpointNotify final : public RequestBase { ...@@ -393,6 +393,43 @@ class RequestCheckpointNotify final : public RequestBase {
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_; ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
}; };
class RequestNotify final : public RequestBase {
public:
explicit RequestNotify(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
request_.reset(new GRPCVariableResponse(request_handler->scope(),
request_handler->dev_ctx()));
int method_id = static_cast<int>(distributed::GrpcMethod::kRequestNotify);
service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
}
virtual ~RequestNotify() {}
std::string GetReqName() override { return request_->Varname(); }
void Process() override {
auto scope = request_->GetMutableLocalScope();
std::string varname = request_->Varname();
int trainer_id = request_->GetTrainerId();
VLOG(4) << "RequestNotify notify: " << varname
<< ", trainer id: " << trainer_id;
request_handler_->Handle(varname, scope, nullptr, nullptr, trainer_id);
Finish(reply_, &responder_);
}
protected:
std::shared_ptr<GRPCVariableResponse> request_;
sendrecv::VoidMessage reply_;
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
};
void AsyncGRPCServer::WaitServerReady() { void AsyncGRPCServer::WaitServerReady() {
VLOG(4) << "AsyncGRPCServer is waiting server ready"; VLOG(4) << "AsyncGRPCServer is waiting server ready";
std::unique_lock<std::mutex> lock(this->mutex_ready_); std::unique_lock<std::mutex> lock(this->mutex_ready_);
...@@ -526,6 +563,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name, ...@@ -526,6 +563,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
b = new RequestPrefetch(&service_, cq.get(), handler, req_id); b = new RequestPrefetch(&service_, cq.get(), handler, req_id);
} else if (rpc_name == kRequestCheckpoint) { } else if (rpc_name == kRequestCheckpoint) {
b = new RequestCheckpointNotify(&service_, cq.get(), handler, req_id); b = new RequestCheckpointNotify(&service_, cq.get(), handler, req_id);
} else if (rpc_name == kRequestNotify) {
b = new RequestNotify(&service_, cq.get(), handler, req_id);
} else { } else {
PADDLE_ENFORCE(false, "not supported rpc"); PADDLE_ENFORCE(false, "not supported rpc");
} }
......
...@@ -84,10 +84,11 @@ enum class GrpcMethod { ...@@ -84,10 +84,11 @@ enum class GrpcMethod {
kGetVariableNoBarrier, kGetVariableNoBarrier,
kGetMonomerVariable, kGetMonomerVariable,
kGetMonomerBarrier, kGetMonomerBarrier,
kRequestNotify,
}; };
static const int kGrpcNumMethods = static const int kGrpcNumMethods =
static_cast<int>(GrpcMethod::kGetMonomerBarrier) + 1; static_cast<int>(GrpcMethod::kRequestNotify) + 1;
inline const char* GrpcMethodName(GrpcMethod id) { inline const char* GrpcMethodName(GrpcMethod id) {
switch (id) { switch (id) {
...@@ -105,6 +106,8 @@ inline const char* GrpcMethodName(GrpcMethod id) { ...@@ -105,6 +106,8 @@ inline const char* GrpcMethodName(GrpcMethod id) {
return "/sendrecv.SendRecvService/PrefetchVariable"; return "/sendrecv.SendRecvService/PrefetchVariable";
case GrpcMethod::kCheckpointNotify: case GrpcMethod::kCheckpointNotify:
return "/sendrecv.SendRecvService/CheckpointNotify"; return "/sendrecv.SendRecvService/CheckpointNotify";
case GrpcMethod::kRequestNotify:
return "/sendrecv.SendRecvService/DistributeNotify";
} }
// Shouldn't be reached. // Shouldn't be reached.
......
...@@ -45,6 +45,7 @@ constexpr char kRequestPrefetch[] = "RequestPrefetch"; ...@@ -45,6 +45,7 @@ constexpr char kRequestPrefetch[] = "RequestPrefetch";
constexpr char kRequestCheckpoint[] = "RequestCheckpoint"; constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
constexpr char kRequestPassBarrier[] = "RequestPassBarrier"; constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
constexpr char kRequestGetNoBarrier[] = "GetVariableNoBarrier"; constexpr char kRequestGetNoBarrier[] = "GetVariableNoBarrier";
constexpr char kRequestNotify[] = "RequestNotify";
constexpr char kSendRPC[] = "SendRPC"; constexpr char kSendRPC[] = "SendRPC";
constexpr char kGetRPC[] = "GetRPC"; constexpr char kGetRPC[] = "GetRPC";
...@@ -62,6 +63,7 @@ constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC"; ...@@ -62,6 +63,7 @@ constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC";
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV" #define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV" #define COMPLETE_MESSAGE "COMPLETE@RECV"
#define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV" #define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV"
#define LEARNING_RATE_DECAY_MESSAGE "LRDECAY@RECV"
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY" #define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY" #define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
...@@ -188,6 +190,11 @@ class RequestHandler { ...@@ -188,6 +190,11 @@ class RequestHandler {
sparse_grad_to_param_ = g; sparse_grad_to_param_ = g;
} }
void SetLrDecayPreparedCtx(
std::shared_ptr<framework::ExecutorPrepareContext> g) {
lr_decay_prepared_ctx_ = g;
}
void SetRPCServer(RPCServer* rpc_server) { rpc_server_ = rpc_server; } void SetRPCServer(RPCServer* rpc_server) { rpc_server_ = rpc_server; }
// Get attributes. // Get attributes.
...@@ -238,6 +245,8 @@ class RequestHandler { ...@@ -238,6 +245,8 @@ class RequestHandler {
grad_to_prepared_ctx_; grad_to_prepared_ctx_;
std::unordered_map<std::string, std::string>* sparse_grad_to_param_; std::unordered_map<std::string, std::string>* sparse_grad_to_param_;
// used for lr decay
std::shared_ptr<framework::ExecutorPrepareContext> lr_decay_prepared_ctx_;
RPCServer* rpc_server_; RPCServer* rpc_server_;
}; };
......
...@@ -251,6 +251,23 @@ bool RequestCheckpointHandler::Handle(const std::string& varname, ...@@ -251,6 +251,23 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
return true; return true;
} }
bool RequestNotifyHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name,
const std::string& table_name) {
VLOG(4) << "RequestNotifyHandler" << varname;
if (varname == LEARNING_RATE_DECAY_MESSAGE) {
PADDLE_ENFORCE_NE(
lr_decay_block_id, -1,
"when lr_decay_block_id = -1, there should be no RPC invoke.");
executor_->RunPreparedContext(lr_decay_prepared_ctx_.get(), scope_);
}
return true;
}
} // namespace distributed } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <time.h> #include <time.h>
#include <functional> #include <functional>
#include <memory>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
...@@ -126,6 +127,22 @@ class RequestCheckpointHandler final : public RequestHandler { ...@@ -126,6 +127,22 @@ class RequestCheckpointHandler final : public RequestHandler {
int checkpoint_notify_id; int checkpoint_notify_id;
}; };
class RequestNotifyHandler final : public RequestHandler {
public:
explicit RequestNotifyHandler(bool sync_mode, int lr_decay_block_id)
: RequestHandler(sync_mode) {
this->lr_decay_block_id = lr_decay_block_id;
}
virtual ~RequestNotifyHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id, const std::string& out_var_name = "",
const std::string& table_name = "") override;
private:
int lr_decay_block_id;
};
} // namespace distributed } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -80,6 +80,10 @@ class RPCClient { ...@@ -80,6 +80,10 @@ class RPCClient {
const std::string& ep, const std::string& dir, const std::string& ep, const std::string& dir,
int64_t time_out = FLAGS_rpc_deadline) = 0; int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncDistributeNotify(
const std::string& ep, const std::string& type,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncSendComplete( virtual VarHandlePtr AsyncSendComplete(
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0; const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0;
......
...@@ -28,6 +28,7 @@ service SendRecvService { ...@@ -28,6 +28,7 @@ service SendRecvService {
rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {} rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {}
rpc CheckpointNotify(VariableMessage) returns (VoidMessage) {} rpc CheckpointNotify(VariableMessage) returns (VoidMessage) {}
rpc DistributeNotify(VariableMessage) returns (VoidMessage) {}
rpc GetMonomerVariable(VariableMessage) returns (VariableMessage) {} rpc GetMonomerVariable(VariableMessage) returns (VariableMessage) {}
rpc GetMonomerBarrier(VariableMessage) returns (VoidMessage) {} rpc GetMonomerBarrier(VariableMessage) returns (VoidMessage) {}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/string/printf.h"
namespace paddle {
namespace operators {
class DistributedNotifyOp : public framework::OperatorBase {
public:
DistributedNotifyOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
std::string type = Attr<std::string>("type");
int trainer_id = Attr<int>("trainer_id");
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
for (size_t i = 0; i < epmap.size(); i++) {
rpc_client->AsyncDistributeNotify(epmap[i], type);
VLOG(4) << "distribute notify sending : " << type << " to " << epmap[i];
}
PADDLE_ENFORCE_EQ(rpc_client->Wait(), true, "internal error in RPCClient");
}
};
class DistributedNotifyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddAttr<std::vector<std::string>>("epmap",
"(string vector, default 127.0.0.1:6164)"
"Parameter Server endpoints in the order")
.SetDefault({"127.0.0.1:6164"});
AddAttr<std::string>("type",
"(string, default '') indicate the action type");
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddComment(R"DOC(
DistributeNotify operator
This operator will send a signal to listen_and_serve op at
the parameter server.
)DOC");
}
};
class DistributedNotifyOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(distributed_notify, ops::DistributedNotifyOp,
paddle::framework::EmptyGradOpMaker,
ops::DistributedNotifyOpMaker,
ops::DistributedNotifyOpShapeInference);
...@@ -298,6 +298,7 @@ static void FillRequestCtx( ...@@ -298,6 +298,7 @@ static void FillRequestCtx(
std::unordered_map<std::string, std::string> std::unordered_map<std::string, std::string>
*sparse_grad_name_to_param_name, *sparse_grad_name_to_param_name,
std::shared_ptr<framework::ExecutorPrepareContext> checkpoint_ctx, std::shared_ptr<framework::ExecutorPrepareContext> checkpoint_ctx,
std::shared_ptr<framework::ExecutorPrepareContext> lr_decay_ctx,
distributed::RPCServer *rpc_server) { distributed::RPCServer *rpc_server) {
h->SetScope(scope); h->SetScope(scope);
h->SetDevCtx(dev_ctx); h->SetDevCtx(dev_ctx);
...@@ -307,6 +308,7 @@ static void FillRequestCtx( ...@@ -307,6 +308,7 @@ static void FillRequestCtx(
h->SetSparseGradToParam(sparse_grad_name_to_param_name); h->SetSparseGradToParam(sparse_grad_name_to_param_name);
h->SetRPCServer(rpc_server); h->SetRPCServer(rpc_server);
h->SetCheckpointNotifyPreparedCtx(checkpoint_ctx); h->SetCheckpointNotifyPreparedCtx(checkpoint_ctx);
h->SetLrDecayPreparedCtx(lr_decay_ctx);
} }
void ListenAndServOp::CacheVarsType(const std::vector<std::string> &varnames, void ListenAndServOp::CacheVarsType(const std::vector<std::string> &varnames,
...@@ -345,10 +347,12 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -345,10 +347,12 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
PADDLE_ENFORCE(!rpc_service_); PADDLE_ENFORCE(!rpc_service_);
std::string endpoint = Attr<std::string>("endpoint"); std::string endpoint = Attr<std::string>("endpoint");
int checkpoint_block_id = Attr<int>(kCheckpointBlockId); int checkpoint_block_id = Attr<int>(kCheckpointBlockId);
int lr_decay_block_id = Attr<int>(kLRDecayBlockId);
VLOG(4) << "pserver_id: " << pserver_id << ", sync_mode:" << sync_mode VLOG(4) << "pserver_id: " << pserver_id << ", sync_mode:" << sync_mode
<< ", fan_in:" << fan_in << ", end_point:" << endpoint << ", fan_in:" << fan_in << ", end_point:" << endpoint
<< ", checkpoint_block_id: " << checkpoint_block_id; << ", checkpoint_block_id: " << checkpoint_block_id
<< ", lr_decay_block_id: " << lr_decay_block_id;
rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in)); rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in));
...@@ -362,6 +366,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -362,6 +366,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
sync_mode, checkpoint_block_id)); sync_mode, checkpoint_block_id));
request_get_no_barrier_handler_.reset( request_get_no_barrier_handler_.reset(
new distributed::RequestGetNoBarrierHandler()); new distributed::RequestGetNoBarrierHandler());
request_notify_handler_.reset(
new distributed::RequestNotifyHandler(sync_mode, lr_decay_block_id));
rpc_service_->RegisterRPC(distributed::kRequestSend, rpc_service_->RegisterRPC(distributed::kRequestSend,
request_send_handler_.get(), request_send_handler_.get(),
...@@ -376,6 +382,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -376,6 +382,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
request_checkpoint_handler_.get()); request_checkpoint_handler_.get());
rpc_service_->RegisterRPC(distributed::kRequestGetNoBarrier, rpc_service_->RegisterRPC(distributed::kRequestGetNoBarrier,
request_get_no_barrier_handler_.get()); request_get_no_barrier_handler_.get());
rpc_service_->RegisterRPC(distributed::kRequestNotify,
request_notify_handler_.get(), 1);
auto optimize_blocks = auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks); Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
...@@ -391,6 +399,13 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -391,6 +399,13 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
ckpt_pre_context = std::move(ctx); ckpt_pre_context = std::move(ctx);
} }
std::shared_ptr<framework::ExecutorPrepareContext> lr_decay_context = nullptr;
if (lr_decay_block_id != -1) {
auto ctx = executor.Prepare(*program, lr_decay_block_id);
// see: https://stackoverflow.com/a/14856553
lr_decay_context = std::move(ctx);
}
// prepare for prefetch // prepare for prefetch
std::vector<int> prefetch_block_id_list; std::vector<int> prefetch_block_id_list;
std::unordered_map<int, std::string> block_id_to_prefetch_var_name; std::unordered_map<int, std::string> block_id_to_prefetch_var_name;
...@@ -435,16 +450,18 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -435,16 +450,18 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
sparse_grad_name_to_param_name[pieces[0]] = pieces[1]; sparse_grad_name_to_param_name[pieces[0]] = pieces[1];
} }
auto f = std::bind( auto f =
FillRequestCtx, std::placeholders::_1, &recv_scope, &dev_ctx, &executor, std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope, &dev_ctx,
program, &prefetch_var_name_to_prepared_ctx, &executor, program, &prefetch_var_name_to_prepared_ctx,
&sparse_grad_name_to_param_name, ckpt_pre_context, rpc_service_.get()); &sparse_grad_name_to_param_name, ckpt_pre_context,
lr_decay_context, rpc_service_.get());
f(request_send_handler_.get()); f(request_send_handler_.get());
f(request_get_handler_.get()); f(request_get_handler_.get());
f(request_prefetch_handler_.get()); f(request_prefetch_handler_.get());
f(request_checkpoint_handler_.get()); f(request_checkpoint_handler_.get());
f(request_get_no_barrier_handler_.get()); f(request_get_no_barrier_handler_.get());
f(request_notify_handler_.get());
// start the server listening after all member initialized. // start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_)); server_thread_.reset(new std::thread(RunServer, rpc_service_));
...@@ -522,6 +539,8 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -522,6 +539,8 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>(kCheckpointBlockId, AddAttr<int>(kCheckpointBlockId,
"BolckID to run save checkpoint on pserer.") "BolckID to run save checkpoint on pserer.")
.SetDefault(-1); .SetDefault(-1);
AddAttr<int>(kLRDecayBlockId, "BolckID to run lr decay on pserer.")
.SetDefault(-1);
} }
}; };
......
...@@ -37,6 +37,7 @@ namespace operators { ...@@ -37,6 +37,7 @@ namespace operators {
constexpr char kOptimizeBlocks[] = "optimize_blocks"; constexpr char kOptimizeBlocks[] = "optimize_blocks";
constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id"; constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id";
constexpr char kCheckpointBlockId[] = "checkpint_block_id"; constexpr char kCheckpointBlockId[] = "checkpint_block_id";
constexpr char kLRDecayBlockId[] = "lr_decay_block_id";
constexpr char kSparseGradToParam[] = "sparse_grad_to_param"; constexpr char kSparseGradToParam[] = "sparse_grad_to_param";
void RunServer(std::shared_ptr<distributed::RPCServer> service); void RunServer(std::shared_ptr<distributed::RPCServer> service);
...@@ -97,6 +98,7 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -97,6 +98,7 @@ class ListenAndServOp : public framework::OperatorBase {
request_prefetch_handler_; request_prefetch_handler_;
mutable std::shared_ptr<distributed::RequestHandler> mutable std::shared_ptr<distributed::RequestHandler>
request_checkpoint_handler_; request_checkpoint_handler_;
mutable std::shared_ptr<distributed::RequestHandler> request_notify_handler_;
mutable std::shared_ptr<std::thread> server_thread_; mutable std::shared_ptr<std::thread> server_thread_;
mutable std::vector<std::string> sparse_vars_; mutable std::vector<std::string> sparse_vars_;
......
...@@ -103,8 +103,16 @@ class TestDistCTR2x2(TestDistRunnerBase): ...@@ -103,8 +103,16 @@ class TestDistCTR2x2(TestDistRunnerBase):
if use_l2_decay: if use_l2_decay:
regularization = fluid.regularizer.L2DecayRegularizer( regularization = fluid.regularizer.L2DecayRegularizer(
regularization_coeff=1e-1) regularization_coeff=1e-1)
use_lr_decay = bool(os.getenv('LR_DECAY', 0))
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.0001, lr = 0.0001
if use_lr_decay:
lr = fluid.layers.exponential_decay(
learning_rate=0.0001,
decay_steps=10000,
decay_rate=0.999,
staircase=True)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=lr,
regularization=regularization) regularization=regularization)
sgd_optimizer.minimize(avg_cost) sgd_optimizer.minimize(avg_cost)
......
...@@ -80,6 +80,28 @@ class TestDistCTR2x2_ASYNC(TestDistBase): ...@@ -80,6 +80,28 @@ class TestDistCTR2x2_ASYNC(TestDistBase):
log_name=flag_name) log_name=flag_name)
class TestDistCTR2x2_ASYNCWithLRDecay2x2(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._hogwild_mode = True
self._enforce_place = "CPU"
def test_dist_ctr(self):
need_envs = {
"FLAGS_communicator_send_queue_size": "2",
"FLAGS_communicator_max_merge_var_num": "2",
"FLAGS_communicator_max_send_grad_num_before_recv": "2",
"LR_DECAY": "1"
}
self.check_with_place(
"dist_ctr.py",
delta=100,
check_error_log=True,
need_envs=need_envs,
log_name=flag_name)
class TestDistCTR2x2_ASYNC2(TestDistBase): class TestDistCTR2x2_ASYNC2(TestDistBase):
def _setup_config(self): def _setup_config(self):
self._sync_mode = False self._sync_mode = False
......
...@@ -818,6 +818,18 @@ class DistributeTranspiler(object): ...@@ -818,6 +818,18 @@ class DistributeTranspiler(object):
self._update_remote_sparse_update_op(program, self._update_remote_sparse_update_op(program,
need_sparse_update_params) need_sparse_update_params)
if not self.sync_mode:
lr_ops = self._get_lr_ops()
if len(lr_ops) > 0:
program.global_block().append_op(
type="distributed_notify",
inputs={},
outputs={},
attrs={
"epmap": pserver_endpoints,
"trainer_id": self.trainer_id,
"type": "LRDECAY@RECV"
})
self._get_trainer_startup_program(recv_vars=recv_vars, eplist=eplist) self._get_trainer_startup_program(recv_vars=recv_vars, eplist=eplist)
...@@ -1125,6 +1137,8 @@ class DistributeTranspiler(object): ...@@ -1125,6 +1137,8 @@ class DistributeTranspiler(object):
lr_ops = self._get_lr_ops() lr_ops = self._get_lr_ops()
# record optimize blocks and we can run them on pserver parallel # record optimize blocks and we can run them on pserver parallel
optimize_blocks = [] optimize_blocks = []
lr_decay_block_id = -1
if len(lr_ops) > 0: if len(lr_ops) > 0:
lr_decay_block = pserver_program._create_block( lr_decay_block = pserver_program._create_block(
pserver_program.num_blocks - 1) pserver_program.num_blocks - 1)
...@@ -1134,6 +1148,7 @@ class DistributeTranspiler(object): ...@@ -1134,6 +1148,7 @@ class DistributeTranspiler(object):
# append sub blocks to pserver_program in lr_decay_op # append sub blocks to pserver_program in lr_decay_op
__clone_lr_op_sub_block__(cloned_op, pserver_program, __clone_lr_op_sub_block__(cloned_op, pserver_program,
lr_decay_block) lr_decay_block)
lr_decay_block_id = lr_decay_block.idx
# append op to the current block # append op to the current block
grad_to_block_id = [] grad_to_block_id = []
...@@ -1211,6 +1226,7 @@ class DistributeTranspiler(object): ...@@ -1211,6 +1226,7 @@ class DistributeTranspiler(object):
"sync_mode": self.sync_mode, "sync_mode": self.sync_mode,
"grad_to_block_id": grad_to_block_id, "grad_to_block_id": grad_to_block_id,
"sparse_grad_to_param": sparse_grad_to_param, "sparse_grad_to_param": sparse_grad_to_param,
"lr_decay_block_id": lr_decay_block_id,
} }
if self.has_distributed_lookup_table: if self.has_distributed_lookup_table:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册