提交 b089b809 编写于 作者: T tangwei12

update rpc to add checkpoint notify

上级 12de20f5
......@@ -229,6 +229,22 @@ void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) {
req_count_++;
}
void GRPCClient::AsyncCheckpointNotify(const std::string& ep,
const std::string& dir,
int64_t time_out) {
const auto ch = GetChannel(ep);
CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);
s.prepare(time_out);
sendrecv::CheckpointMessage req;
req.set_notify_type(CHECKPOINT_SAVE_MESSAGE);
req.set_checkpoint_dir(dir);
auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
}
void GRPCClient::Wait() {
std::unique_lock<std::mutex> lk(sync_mutex_);
sync_cond_.wait(lk, [this] { return req_count_ == 0; });
......
......@@ -165,6 +165,20 @@ class FetchBarrierProcessor : public BaseProcessor {
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
};
// Per-call state for one asynchronous CheckpointNotify RPC.
//
// Builds a SendRecvService stub on the given channel and holds the
// VoidMessage reply that the RPC writes into.
class CheckpointNotifyProcessor : public BaseProcessor {
 public:
  explicit CheckpointNotifyProcessor(std::shared_ptr<grpc::Channel> ch)
      : BaseProcessor(ch) {
    stub_ = sendrecv::SendRecvService::NewStub(ch);
  }

  virtual ~CheckpointNotifyProcessor() {}

  // The reply is an empty VoidMessage, so there is nothing to process.
  virtual void Process() {}

  sendrecv::VoidMessage reply_;
  std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
};  // The terminating ';' was missing, which broke the next class definition.
class GRPCClient : public RPCClient {
public:
GRPCClient() {}
......@@ -193,6 +207,10 @@ class GRPCClient : public RPCClient {
const std::string& ep,
int64_t time_out = RPCClient::rpc_time_out) override;
// Asynchronously notifies the server at `ep` about a checkpoint located in
// `dir`. `time_out` defaults to RPCClient::rpc_time_out.
void AsyncCheckpointNotify(
const std::string& ep, const std::string& dir,
int64_t time_out = RPCClient::rpc_time_out) override;
void Wait() override;
void SendComplete() override;
......
......@@ -79,6 +79,7 @@ enum class GrpcMethod {
kSendVariable,
kGetVariable,
kPrefetchVariable,
kCheckpointNotify,
};
static const int kGrpcNumMethods =
......@@ -92,6 +93,8 @@ inline const char* GrpcMethodName(GrpcMethod id) {
return "/sendrecv.SendRecvService/GetVariable";
case GrpcMethod::kPrefetchVariable:
return "/sendrecv.SendRecvService/PrefetchVariable";
case GrpcMethod::kCheckpointNotify:
return "/sendrecv.SendRecvService/CheckpointNotify";
}
// Shouldn't be reached.
......
......@@ -43,6 +43,9 @@ constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"
// Notify-type strings for checkpoint RPCs. The gRPC client sends
// CHECKPOINT_SAVE_MESSAGE as the notify_type of a CheckpointMessage.
// NOTE(review): CHECKPOINT_LOAD_MESSAGE is presumably the counterpart for
// loading a checkpoint; no use of it is visible here -- confirm.
#define CHECKPOINT_SAVE_MESSAGE "SAVE"
#define CHECKPOINT_LOAD_MESSAGE "LOAD"
class RPCServer;
class RequestHandler {
......@@ -70,6 +73,11 @@ class RequestHandler {
prefetch_var_name_to_prepared_ctx_ = g;
}
// Stores the prepared executor context used for checkpoint notify.
// The shared_ptr is copied, so the handler shares ownership with the caller.
void SetCheckpointNotifyPreparedCtx(
    std::shared_ptr<framework::ExecutorPrepareContext> g) {
  this->checkpoint_prepared_ctx_ = g;
}
// Used for async.
void SetGradToPreparedCtx(
std::unordered_map<
......@@ -116,6 +124,8 @@ class RequestHandler {
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>*
prefetch_var_name_to_prepared_ctx_;
// Prepared executor context for checkpoint notify; assigned by
// SetCheckpointNotifyPreparedCtx().
std::shared_ptr<framework::ExecutorPrepareContext> checkpoint_prepared_ctx_;
// Used for async.
std::unordered_map<std::string,
......
......@@ -119,6 +119,12 @@ bool RequestPrefetchHandler::Handle(const std::string& varname,
return true;
}
// Handles a checkpoint-notify request.
//
// Not implemented yet: none of the arguments are used. The function returns
// true, matching the sibling request handlers, because flowing off the end
// of a non-void function is undefined behaviour.
//
// TODO: perform the actual checkpoint work and report real failures.
bool RequestCheckpointHandler::Handle(const std::string& varname,
                                      framework::Scope* scope,
                                      framework::Variable* invar,
                                      framework::Variable** outvar,
                                      const std::string& out_var_name) {
  return true;
}
} // namespace detail
} // namespace operators
} // namespace paddle
......@@ -53,6 +53,10 @@ class RPCClient {
virtual void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = rpc_time_out) = 0;
// Asynchronously notifies the server at `ep` about a checkpoint located in
// `dir`. Pure virtual: each transport-specific client supplies the
// implementation. `time_out` defaults to rpc_time_out.
virtual void AsyncCheckpointNotify(const std::string& ep,
const std::string& dir,
int64_t time_out = rpc_time_out) = 0;
// SendComplete tells all the servers that the current trainer has no more
// data to train, so that the pserver can reduce its barrier count and
// continue to train with the other trainers.
......
......@@ -25,6 +25,8 @@ service SendRecvService {
rpc GetVariable(VariableMessage) returns (VariableMessage) {}
// pre-fetch variable by given variable name and Ids
rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {}
// Checkpoint notification: takes a CheckpointMessage and returns an empty
// VoidMessage.
rpc CheckpointNotify(CheckpointMessage) returns (VoidMessage) {}
}
// VariableMessage is serialized paddle variable message.
......@@ -79,3 +81,8 @@ message VariableMessage {
}
message VoidMessage {}
// CheckpointMessage is the request payload of the CheckpointNotify rpc.
message CheckpointMessage {
// Kind of notification; the gRPC client sets it to CHECKPOINT_SAVE_MESSAGE.
string notify_type = 1;
// Checkpoint directory; the gRPC client fills it from its `dir` argument.
string checkpoint_dir = 2;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册