提交 b089b809 编写于 作者: T tangwei12

update rpc to add checkpoint notify

上级 12de20f5
...@@ -229,6 +229,22 @@ void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) { ...@@ -229,6 +229,22 @@ void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) {
req_count_++; req_count_++;
} }
void GRPCClient::AsyncCheckpointNotify(const std::string& ep,
const std::string& dir,
int64_t time_out) {
const auto ch = GetChannel(ep);
CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);
s.prepare(time_out);
sendrecv::CheckpointMessage req;
req.set_notify_type(CHECKPOINT_SAVE_MESSAGE);
req.set_checkpoint_dir(dir);
auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
}
void GRPCClient::Wait() { void GRPCClient::Wait() {
std::unique_lock<std::mutex> lk(sync_mutex_); std::unique_lock<std::mutex> lk(sync_mutex_);
sync_cond_.wait(lk, [this] { return req_count_ == 0; }); sync_cond_.wait(lk, [this] { return req_count_ == 0; });
......
...@@ -165,6 +165,20 @@ class FetchBarrierProcessor : public BaseProcessor { ...@@ -165,6 +165,20 @@ class FetchBarrierProcessor : public BaseProcessor {
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_; std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
}; };
// Per-call state for a CheckpointNotify RPC.
//
// Holds the service stub created from the channel and the VoidMessage reply
// object that the RPC result is written into.
class CheckpointNotifyProcessor : public BaseProcessor {
 public:
  // Builds a SendRecvService stub on `ch`; the channel is also handed to
  // BaseProcessor.
  explicit CheckpointNotifyProcessor(std::shared_ptr<grpc::Channel> ch)
      : BaseProcessor(ch) {
    stub_ = sendrecv::SendRecvService::NewStub(ch);
  }

  virtual ~CheckpointNotifyProcessor() {}

  // Intentionally empty: the reply is a VoidMessage, so there is no payload
  // to handle.
  virtual void Process() {}

  sendrecv::VoidMessage reply_;
  std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
};  // The terminating ';' is required after a class definition.
class GRPCClient : public RPCClient { class GRPCClient : public RPCClient {
public: public:
GRPCClient() {} GRPCClient() {}
...@@ -193,6 +207,10 @@ class GRPCClient : public RPCClient { ...@@ -193,6 +207,10 @@ class GRPCClient : public RPCClient {
const std::string& ep, const std::string& ep,
int64_t time_out = RPCClient::rpc_time_out) override; int64_t time_out = RPCClient::rpc_time_out) override;
// gRPC override of RPCClient::AsyncCheckpointNotify.
// ep: target endpoint; dir: checkpoint directory sent in the request;
// time_out: defaults to RPCClient::rpc_time_out.
void AsyncCheckpointNotify(
const std::string& ep, const std::string& dir,
int64_t time_out = RPCClient::rpc_time_out) override;
void Wait() override; void Wait() override;
void SendComplete() override; void SendComplete() override;
......
...@@ -79,6 +79,7 @@ enum class GrpcMethod { ...@@ -79,6 +79,7 @@ enum class GrpcMethod {
kSendVariable, kSendVariable,
kGetVariable, kGetVariable,
kPrefetchVariable, kPrefetchVariable,
kCheckpointNotify,
}; };
static const int kGrpcNumMethods = static const int kGrpcNumMethods =
...@@ -92,6 +93,8 @@ inline const char* GrpcMethodName(GrpcMethod id) { ...@@ -92,6 +93,8 @@ inline const char* GrpcMethodName(GrpcMethod id) {
return "/sendrecv.SendRecvService/GetVariable"; return "/sendrecv.SendRecvService/GetVariable";
case GrpcMethod::kPrefetchVariable: case GrpcMethod::kPrefetchVariable:
return "/sendrecv.SendRecvService/PrefetchVariable"; return "/sendrecv.SendRecvService/PrefetchVariable";
case GrpcMethod::kCheckpointNotify:
return "/sendrecv.SendRecvService/CheckpointNotify";
} }
// Shouldn't be reached. // Shouldn't be reached.
......
...@@ -43,6 +43,9 @@ constexpr char kRequestCheckpoint[] = "RequestCheckpoint"; ...@@ -43,6 +43,9 @@ constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV" #define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV" #define COMPLETE_MESSAGE "COMPLETE@RECV"
// String tags for checkpoint notifications. CHECKPOINT_SAVE_MESSAGE is the
// value GRPCClient::AsyncCheckpointNotify writes into the request's
// notify_type field.
// NOTE(review): CHECKPOINT_LOAD_MESSAGE has no user in the visible code.
#define CHECKPOINT_SAVE_MESSAGE "SAVE"
#define CHECKPOINT_LOAD_MESSAGE "LOAD"
class RPCServer; class RPCServer;
class RequestHandler { class RequestHandler {
...@@ -70,6 +73,11 @@ class RequestHandler { ...@@ -70,6 +73,11 @@ class RequestHandler {
prefetch_var_name_to_prepared_ctx_ = g; prefetch_var_name_to_prepared_ctx_ = g;
} }
// Stores the prepared executor context used for checkpoint notify.
// The handler keeps a shared reference to `g` in checkpoint_prepared_ctx_.
void SetCheckpointNotifyPreparedCtx(
    std::shared_ptr<framework::ExecutorPrepareContext> g) {
  this->checkpoint_prepared_ctx_ = g;
}
// Used for async. // Used for async.
void SetGradToPreparedCtx( void SetGradToPreparedCtx(
std::unordered_map< std::unordered_map<
...@@ -116,6 +124,8 @@ class RequestHandler { ...@@ -116,6 +124,8 @@ class RequestHandler {
std::unordered_map<std::string, std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>* std::shared_ptr<framework::ExecutorPrepareContext>>*
prefetch_var_name_to_prepared_ctx_; prefetch_var_name_to_prepared_ctx_;
// used for checkpoint notify
std::shared_ptr<framework::ExecutorPrepareContext> checkpoint_prepared_ctx_;
// Used for async. // Used for async.
std::unordered_map<std::string, std::unordered_map<std::string,
......
...@@ -119,6 +119,12 @@ bool RequestPrefetchHandler::Handle(const std::string& varname, ...@@ -119,6 +119,12 @@ bool RequestPrefetchHandler::Handle(const std::string& varname,
return true; return true;
} }
// Handles a checkpoint-notify request.
//
// Currently a placeholder: none of the arguments are used and no work is
// performed. Returns true, matching the value RequestPrefetchHandler::Handle
// returns on completion.
bool RequestCheckpointHandler::Handle(const std::string& varname,
                                      framework::Scope* scope,
                                      framework::Variable* invar,
                                      framework::Variable** outvar,
                                      const std::string& out_var_name) {
  // TODO: implement the checkpoint handling logic.
  // An explicit return is required: flowing off the end of a function that
  // returns bool is undefined behavior.
  return true;
}
} // namespace detail } // namespace detail
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -53,6 +53,10 @@ class RPCClient { ...@@ -53,6 +53,10 @@ class RPCClient {
virtual void AsyncSendFetchBarrier(const std::string& ep, virtual void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = rpc_time_out) = 0; int64_t time_out = rpc_time_out) = 0;
// Asynchronously sends a checkpoint notification to endpoint `ep`, passing
// the checkpoint directory `dir`. time_out defaults to rpc_time_out.
// Pure virtual: each concrete RPC client supplies the transport.
virtual void AsyncCheckpointNotify(const std::string& ep,
const std::string& dir,
int64_t time_out = rpc_time_out) = 0;
// SendComplete tells all the server that current trainer have no more data // SendComplete tells all the server that current trainer have no more data
// to train, so that the pserver can reduce it's barrier count, and continue // to train, so that the pserver can reduce it's barrier count, and continue
// to train with other trainers. // to train with other trainers.
......
...@@ -25,6 +25,8 @@ service SendRecvService { ...@@ -25,6 +25,8 @@ service SendRecvService {
rpc GetVariable(VariableMessage) returns (VariableMessage) {} rpc GetVariable(VariableMessage) returns (VariableMessage) {}
// pre-fetch variable by given variable name and Ids // pre-fetch variable by given variable name and Ids
rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {} rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {}
// checkpoint notification: takes a CheckpointMessage, replies with an empty
// VoidMessage
rpc CheckpointNotify(CheckpointMessage) returns (VoidMessage) {}
} }
// VariableMessage is serialized paddle variable message. // VariableMessage is serialized paddle variable message.
...@@ -79,3 +81,8 @@ message VariableMessage { ...@@ -79,3 +81,8 @@ message VariableMessage {
} }
message VoidMessage {} message VoidMessage {}
// CheckpointMessage is the request payload of the CheckpointNotify rpc.
message CheckpointMessage {
// Kind of notification; the gRPC client in this change sets it to
// CHECKPOINT_SAVE_MESSAGE ("SAVE").
string notify_type = 1;
// Checkpoint directory, filled from the `dir` argument of
// AsyncCheckpointNotify.
string checkpoint_dir = 2;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册