diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index 02ffe3651e1deefcf6981c3d304d64b9a01661bf..8898438675687e06fc4389ddcd634dc04e8583bd 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -229,6 +229,22 @@ void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) { req_count_++; } +void GRPCClient::AsyncCheckpointNotify(const std::string& ep, + const std::string& dir, + int64_t time_out) { + const auto ch = GetChannel(ep); + CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch); + s.prepare(time_out); + + sendrecv::CheckpointMessage req; + req.set_notify_type(CHECKPOINT_SAVE_MESSAGE); + req.set_checkpoint_dir(dir); + + auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq); + rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); + req_count_++; +} + void GRPCClient::Wait() { std::unique_lock lk(sync_mutex_); sync_cond_.wait(lk, [this] { return req_count_ == 0; }); diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h index 44000c028b499d9ad1a0e0dd40a5e287cd61d143..bc3deff47cec1499056c8b13de5d5e7db9ef2175 100644 --- a/paddle/fluid/operators/detail/grpc_client.h +++ b/paddle/fluid/operators/detail/grpc_client.h @@ -165,6 +165,20 @@ class FetchBarrierProcessor : public BaseProcessor { std::unique_ptr stub_; }; +class CheckpointNotifyProcessor : public BaseProcessor { + public: + explicit CheckpointNotifyProcessor(std::shared_ptr ch) + : BaseProcessor(ch) { + stub_ = sendrecv::SendRecvService::NewStub(ch); + } + + virtual ~CheckpointNotifyProcessor() {} + + virtual void Process() {} + sendrecv::VoidMessage reply_; + std::unique_ptr stub_; +} + class GRPCClient : public RPCClient { public: GRPCClient() {} @@ -193,6 +207,10 @@ class GRPCClient : public RPCClient { const std::string& ep, int64_t time_out = RPCClient::rpc_time_out) override; + void AsyncCheckpointNotify( + const std::string& ep, const std::string& dir, + int64_t time_out = RPCClient::rpc_time_out) override; + void Wait() override; void SendComplete() override; diff --git a/paddle/fluid/operators/detail/grpc_service.h b/paddle/fluid/operators/detail/grpc_service.h index e0505c2b9d0903837713d7e0032b01ab091c2e04..69200a01d3c8c9c02ac0126a8fc8341719a94535 100644 --- a/paddle/fluid/operators/detail/grpc_service.h +++ b/paddle/fluid/operators/detail/grpc_service.h @@ -79,6 +79,7 @@ enum class GrpcMethod { kSendVariable, kGetVariable, kPrefetchVariable, + kCheckpointNotify, }; static const int kGrpcNumMethods = @@ -92,6 +93,8 @@ inline const char* GrpcMethodName(GrpcMethod id) { return "/sendrecv.SendRecvService/GetVariable"; case GrpcMethod::kPrefetchVariable: return "/sendrecv.SendRecvService/PrefetchVariable"; + case GrpcMethod::kCheckpointNotify: + return "/sendrecv.SendRecvService/CheckpointNotify"; } // Shouldn't be reached. diff --git a/paddle/fluid/operators/detail/request_handler.h b/paddle/fluid/operators/detail/request_handler.h index cb480accb4ea234cd4ec2edd1ddcc0861984a248..fd33521fd14884032c082cbf054dfa2e3d7352e1 100644 --- a/paddle/fluid/operators/detail/request_handler.h +++ b/paddle/fluid/operators/detail/request_handler.h @@ -43,6 +43,9 @@ constexpr char kRequestCheckpoint[] = "RequestCheckpoint"; #define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV" #define COMPLETE_MESSAGE "COMPLETE@RECV" +#define CHECKPOINT_SAVE_MESSAGE "SAVE" +#define CHECKPOINT_LOAD_MESSAGE "LOAD" + class RPCServer; class RequestHandler { @@ -70,6 +73,11 @@ class RequestHandler { prefetch_var_name_to_prepared_ctx_ = g; } + void SetCheckpointNotifyPreparedCtx( + std::shared_ptr g) { + checkpoint_prepared_ctx_ = g; + } + // Used for async. void SetGradToPreparedCtx( std::unordered_map< @@ -116,6 +124,8 @@ class RequestHandler { std::unordered_map>* prefetch_var_name_to_prepared_ctx_; + // used for checkpoint notify + std::shared_ptr checkpoint_prepared_ctx_; // Used for async. std::unordered_map