未验证 提交 b4a3b750 编写于 作者: 1 123malin 提交者: GitHub

bug fix: invalid learning rate decay in pserver async mode (#20325)

* bug fix: invalid learning rate decay in pserver async mode
上级 cadc6a97
......@@ -438,6 +438,35 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
return h;
}
VarHandlePtr GRPCClient::AsyncDistributeNotify(const std::string& ep,
const std::string& type,
int64_t time_out) {
const auto ch = GetChannel(ep);
DistributeNotifyProcessor* s = new DistributeNotifyProcessor(ch);
const std::string method = kRequestNotify;
VarHandlePtr h(
new VarHandle(ep, method, LEARNING_RATE_DECAY_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req;
req.set_varname(type);
platform::RecordRPCEvent record_event(method);
auto rpc = s->stub_->AsyncDistributeNotify(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
return h;
}
bool GRPCClient::Wait() {
std::unique_lock<std::mutex> lk(sync_mutex_);
sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); });
......
......@@ -173,6 +173,20 @@ class CheckpointNotifyProcessor : public BaseProcessor {
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
};
class DistributeNotifyProcessor : public BaseProcessor {
public:
explicit DistributeNotifyProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor() {
stub_ = sendrecv::SendRecvService::NewStub(ch);
}
virtual ~DistributeNotifyProcessor() {}
void ProcessImpl() override {}
sendrecv::VoidMessage reply_;
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
};
class GRPCClient : public RPCClient {
public:
GRPCClient() : ok_(true), completed_(false), stopped_(false) {}
......@@ -225,6 +239,10 @@ class GRPCClient : public RPCClient {
const std::string& ep, const std::string& dir,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncDistributeNotify(
const std::string& ep, const std::string& type,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncSendComplete(
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
......
......@@ -393,6 +393,43 @@ class RequestCheckpointNotify final : public RequestBase {
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
};
class RequestNotify final : public RequestBase {
public:
explicit RequestNotify(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
request_.reset(new GRPCVariableResponse(request_handler->scope(),
request_handler->dev_ctx()));
int method_id = static_cast<int>(distributed::GrpcMethod::kRequestNotify);
service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
}
virtual ~RequestNotify() {}
std::string GetReqName() override { return request_->Varname(); }
void Process() override {
auto scope = request_->GetMutableLocalScope();
std::string varname = request_->Varname();
int trainer_id = request_->GetTrainerId();
VLOG(4) << "RequestNotify notify: " << varname
<< ", trainer id: " << trainer_id;
request_handler_->Handle(varname, scope, nullptr, nullptr, trainer_id);
Finish(reply_, &responder_);
}
protected:
std::shared_ptr<GRPCVariableResponse> request_;
sendrecv::VoidMessage reply_;
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
};
void AsyncGRPCServer::WaitServerReady() {
VLOG(4) << "AsyncGRPCServer is waiting server ready";
std::unique_lock<std::mutex> lock(this->mutex_ready_);
......@@ -526,6 +563,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
b = new RequestPrefetch(&service_, cq.get(), handler, req_id);
} else if (rpc_name == kRequestCheckpoint) {
b = new RequestCheckpointNotify(&service_, cq.get(), handler, req_id);
} else if (rpc_name == kRequestNotify) {
b = new RequestNotify(&service_, cq.get(), handler, req_id);
} else {
PADDLE_ENFORCE(false, "not supported rpc");
}
......
......@@ -84,10 +84,11 @@ enum class GrpcMethod {
kGetVariableNoBarrier,
kGetMonomerVariable,
kGetMonomerBarrier,
kRequestNotify,
};
static const int kGrpcNumMethods =
static_cast<int>(GrpcMethod::kGetMonomerBarrier) + 1;
static_cast<int>(GrpcMethod::kRequestNotify) + 1;
inline const char* GrpcMethodName(GrpcMethod id) {
switch (id) {
......@@ -105,6 +106,8 @@ inline const char* GrpcMethodName(GrpcMethod id) {
return "/sendrecv.SendRecvService/PrefetchVariable";
case GrpcMethod::kCheckpointNotify:
return "/sendrecv.SendRecvService/CheckpointNotify";
case GrpcMethod::kRequestNotify:
return "/sendrecv.SendRecvService/DistributeNotify";
}
// Shouldn't be reached.
......
......@@ -45,6 +45,7 @@ constexpr char kRequestPrefetch[] = "RequestPrefetch";
constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
constexpr char kRequestGetNoBarrier[] = "GetVariableNoBarrier";
constexpr char kRequestNotify[] = "RequestNotify";
constexpr char kSendRPC[] = "SendRPC";
constexpr char kGetRPC[] = "GetRPC";
......@@ -62,6 +63,7 @@ constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC";
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"
#define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV"
#define LEARNING_RATE_DECAY_MESSAGE "LRDECAY@RECV"
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
......@@ -188,6 +190,11 @@ class RequestHandler {
sparse_grad_to_param_ = g;
}
void SetLrDecayPreparedCtx(
std::shared_ptr<framework::ExecutorPrepareContext> g) {
lr_decay_prepared_ctx_ = g;
}
void SetRPCServer(RPCServer* rpc_server) { rpc_server_ = rpc_server; }
// Get attributes.
......@@ -238,6 +245,8 @@ class RequestHandler {
grad_to_prepared_ctx_;
std::unordered_map<std::string, std::string>* sparse_grad_to_param_;
// used for lr decay
std::shared_ptr<framework::ExecutorPrepareContext> lr_decay_prepared_ctx_;
RPCServer* rpc_server_;
};
......
......@@ -251,6 +251,23 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
return true;
}
bool RequestNotifyHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name,
const std::string& table_name) {
VLOG(4) << "RequestNotifyHandler" << varname;
if (varname == LEARNING_RATE_DECAY_MESSAGE) {
PADDLE_ENFORCE_NE(
lr_decay_block_id, -1,
"when lr_decay_block_id = -1, there should be no RPC invoke.");
executor_->RunPreparedContext(lr_decay_prepared_ctx_.get(), scope_);
}
return true;
}
} // namespace distributed
} // namespace operators
} // namespace paddle
......@@ -17,6 +17,7 @@
#include <time.h>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
......@@ -126,6 +127,22 @@ class RequestCheckpointHandler final : public RequestHandler {
int checkpoint_notify_id;
};
class RequestNotifyHandler final : public RequestHandler {
public:
explicit RequestNotifyHandler(bool sync_mode, int lr_decay_block_id)
: RequestHandler(sync_mode) {
this->lr_decay_block_id = lr_decay_block_id;
}
virtual ~RequestNotifyHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id, const std::string& out_var_name = "",
const std::string& table_name = "") override;
private:
int lr_decay_block_id;
};
} // namespace distributed
} // namespace operators
} // namespace paddle
......@@ -80,6 +80,10 @@ class RPCClient {
const std::string& ep, const std::string& dir,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncDistributeNotify(
const std::string& ep, const std::string& type,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncSendComplete(
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0;
......
......@@ -28,6 +28,7 @@ service SendRecvService {
rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {}
rpc CheckpointNotify(VariableMessage) returns (VoidMessage) {}
rpc DistributeNotify(VariableMessage) returns (VoidMessage) {}
rpc GetMonomerVariable(VariableMessage) returns (VariableMessage) {}
rpc GetMonomerBarrier(VariableMessage) returns (VoidMessage) {}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/string/printf.h"
namespace paddle {
namespace operators {
class DistributedNotifyOp : public framework::OperatorBase {
public:
DistributedNotifyOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
std::string type = Attr<std::string>("type");
int trainer_id = Attr<int>("trainer_id");
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
for (size_t i = 0; i < epmap.size(); i++) {
rpc_client->AsyncDistributeNotify(epmap[i], type);
VLOG(4) << "distribute notify sending : " << type << " to " << epmap[i];
}
PADDLE_ENFORCE_EQ(rpc_client->Wait(), true, "internal error in RPCClient");
}
};
class DistributedNotifyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddAttr<std::vector<std::string>>("epmap",
"(string vector, default 127.0.0.1:6164)"
"Parameter Server endpoints in the order")
.SetDefault({"127.0.0.1:6164"});
AddAttr<std::string>("type",
"(string, default '') indicate the action type");
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddComment(R"DOC(
DistributeNotify operator
This operator will send a signal to listen_and_serve op at
the parameter server.
)DOC");
}
};
class DistributedNotifyOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(distributed_notify, ops::DistributedNotifyOp,
paddle::framework::EmptyGradOpMaker,
ops::DistributedNotifyOpMaker,
ops::DistributedNotifyOpShapeInference);
......@@ -298,6 +298,7 @@ static void FillRequestCtx(
std::unordered_map<std::string, std::string>
*sparse_grad_name_to_param_name,
std::shared_ptr<framework::ExecutorPrepareContext> checkpoint_ctx,
std::shared_ptr<framework::ExecutorPrepareContext> lr_decay_ctx,
distributed::RPCServer *rpc_server) {
h->SetScope(scope);
h->SetDevCtx(dev_ctx);
......@@ -307,6 +308,7 @@ static void FillRequestCtx(
h->SetSparseGradToParam(sparse_grad_name_to_param_name);
h->SetRPCServer(rpc_server);
h->SetCheckpointNotifyPreparedCtx(checkpoint_ctx);
h->SetLrDecayPreparedCtx(lr_decay_ctx);
}
void ListenAndServOp::CacheVarsType(const std::vector<std::string> &varnames,
......@@ -345,10 +347,12 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
PADDLE_ENFORCE(!rpc_service_);
std::string endpoint = Attr<std::string>("endpoint");
int checkpoint_block_id = Attr<int>(kCheckpointBlockId);
int lr_decay_block_id = Attr<int>(kLRDecayBlockId);
VLOG(4) << "pserver_id: " << pserver_id << ", sync_mode:" << sync_mode
<< ", fan_in:" << fan_in << ", end_point:" << endpoint
<< ", checkpoint_block_id: " << checkpoint_block_id;
<< ", checkpoint_block_id: " << checkpoint_block_id
<< ", lr_decay_block_id: " << lr_decay_block_id;
rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in));
......@@ -362,6 +366,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
sync_mode, checkpoint_block_id));
request_get_no_barrier_handler_.reset(
new distributed::RequestGetNoBarrierHandler());
request_notify_handler_.reset(
new distributed::RequestNotifyHandler(sync_mode, lr_decay_block_id));
rpc_service_->RegisterRPC(distributed::kRequestSend,
request_send_handler_.get(),
......@@ -376,6 +382,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
request_checkpoint_handler_.get());
rpc_service_->RegisterRPC(distributed::kRequestGetNoBarrier,
request_get_no_barrier_handler_.get());
rpc_service_->RegisterRPC(distributed::kRequestNotify,
request_notify_handler_.get(), 1);
auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
......@@ -391,6 +399,13 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
ckpt_pre_context = std::move(ctx);
}
std::shared_ptr<framework::ExecutorPrepareContext> lr_decay_context = nullptr;
if (lr_decay_block_id != -1) {
auto ctx = executor.Prepare(*program, lr_decay_block_id);
// see: https://stackoverflow.com/a/14856553
lr_decay_context = std::move(ctx);
}
// prepare for prefetch
std::vector<int> prefetch_block_id_list;
std::unordered_map<int, std::string> block_id_to_prefetch_var_name;
......@@ -435,16 +450,18 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
sparse_grad_name_to_param_name[pieces[0]] = pieces[1];
}
auto f = std::bind(
FillRequestCtx, std::placeholders::_1, &recv_scope, &dev_ctx, &executor,
program, &prefetch_var_name_to_prepared_ctx,
&sparse_grad_name_to_param_name, ckpt_pre_context, rpc_service_.get());
auto f =
std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope, &dev_ctx,
&executor, program, &prefetch_var_name_to_prepared_ctx,
&sparse_grad_name_to_param_name, ckpt_pre_context,
lr_decay_context, rpc_service_.get());
f(request_send_handler_.get());
f(request_get_handler_.get());
f(request_prefetch_handler_.get());
f(request_checkpoint_handler_.get());
f(request_get_no_barrier_handler_.get());
f(request_notify_handler_.get());
// start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_));
......@@ -522,6 +539,8 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>(kCheckpointBlockId,
"BolckID to run save checkpoint on pserer.")
.SetDefault(-1);
AddAttr<int>(kLRDecayBlockId, "BolckID to run lr decay on pserer.")
.SetDefault(-1);
}
};
......
......@@ -37,6 +37,7 @@ namespace operators {
constexpr char kOptimizeBlocks[] = "optimize_blocks";
constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id";
constexpr char kCheckpointBlockId[] = "checkpint_block_id";
constexpr char kLRDecayBlockId[] = "lr_decay_block_id";
constexpr char kSparseGradToParam[] = "sparse_grad_to_param";
void RunServer(std::shared_ptr<distributed::RPCServer> service);
......@@ -97,6 +98,7 @@ class ListenAndServOp : public framework::OperatorBase {
request_prefetch_handler_;
mutable std::shared_ptr<distributed::RequestHandler>
request_checkpoint_handler_;
mutable std::shared_ptr<distributed::RequestHandler> request_notify_handler_;
mutable std::shared_ptr<std::thread> server_thread_;
mutable std::vector<std::string> sparse_vars_;
......
......@@ -103,8 +103,16 @@ class TestDistCTR2x2(TestDistRunnerBase):
if use_l2_decay:
regularization = fluid.regularizer.L2DecayRegularizer(
regularization_coeff=1e-1)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.0001,
use_lr_decay = bool(os.getenv('LR_DECAY', 0))
lr = 0.0001
if use_lr_decay:
lr = fluid.layers.exponential_decay(
learning_rate=0.0001,
decay_steps=10000,
decay_rate=0.999,
staircase=True)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=lr,
regularization=regularization)
sgd_optimizer.minimize(avg_cost)
......
......@@ -80,6 +80,28 @@ class TestDistCTR2x2_ASYNC(TestDistBase):
log_name=flag_name)
class TestDistCTR2x2_ASYNCWithLRDecay2x2(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._hogwild_mode = True
self._enforce_place = "CPU"
def test_dist_ctr(self):
need_envs = {
"FLAGS_communicator_send_queue_size": "2",
"FLAGS_communicator_max_merge_var_num": "2",
"FLAGS_communicator_max_send_grad_num_before_recv": "2",
"LR_DECAY": "1"
}
self.check_with_place(
"dist_ctr.py",
delta=100,
check_error_log=True,
need_envs=need_envs,
log_name=flag_name)
class TestDistCTR2x2_ASYNC2(TestDistBase):
def _setup_config(self):
self._sync_mode = False
......
......@@ -818,6 +818,18 @@ class DistributeTranspiler(object):
self._update_remote_sparse_update_op(program,
need_sparse_update_params)
if not self.sync_mode:
lr_ops = self._get_lr_ops()
if len(lr_ops) > 0:
program.global_block().append_op(
type="distributed_notify",
inputs={},
outputs={},
attrs={
"epmap": pserver_endpoints,
"trainer_id": self.trainer_id,
"type": "LRDECAY@RECV"
})
self._get_trainer_startup_program(recv_vars=recv_vars, eplist=eplist)
......@@ -1125,6 +1137,8 @@ class DistributeTranspiler(object):
lr_ops = self._get_lr_ops()
# record optimize blocks and we can run them on pserver parallel
optimize_blocks = []
lr_decay_block_id = -1
if len(lr_ops) > 0:
lr_decay_block = pserver_program._create_block(
pserver_program.num_blocks - 1)
......@@ -1134,6 +1148,7 @@ class DistributeTranspiler(object):
# append sub blocks to pserver_program in lr_decay_op
__clone_lr_op_sub_block__(cloned_op, pserver_program,
lr_decay_block)
lr_decay_block_id = lr_decay_block.idx
# append op to the current block
grad_to_block_id = []
......@@ -1211,6 +1226,7 @@ class DistributeTranspiler(object):
"sync_mode": self.sync_mode,
"grad_to_block_id": grad_to_block_id,
"sparse_grad_to_param": sparse_grad_to_param,
"lr_decay_block_id": lr_decay_block_id,
}
if self.has_distributed_lookup_table:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册