diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index 9a25ec8fdb4599ea9128bed151a525c13d4127a0..17476ab513b55226f0b6db622841fda80850bda8 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -233,12 +233,20 @@ void GRPCClient::AsyncCheckpointNotify(const std::string& ep, const std::string& dir, int64_t time_out) { const auto ch = GetChannel(ep); + CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch); s->Prepare(time_out); + s->response_call_back_ = nullptr; - sendrecv::CheckpointMessage req; - req.set_notify_type(CHECKPOINT_SAVE_MESSAGE); - req.set_checkpoint_dir(dir); + sendrecv::VariableMessage req; + req.set_varname(CHECKPOINT_SAVE_MESSAGE); + req.out_varname(dir); + + auto call = s->stub_g_.PrepareUnaryCall( + s->context_.get(), "/sendrecv.SendRecvService/CheckpointNotify", req, + &cq_); + call->StartCall(); + call->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); diff --git a/paddle/fluid/operators/detail/send_recv.proto b/paddle/fluid/operators/detail/send_recv.proto index cc6529cea7451e1bdd334d41b40d63f5eed6aad4..f5800cdb7f7e124426bbb970d00b429894a110b4 100644 --- a/paddle/fluid/operators/detail/send_recv.proto +++ b/paddle/fluid/operators/detail/send_recv.proto @@ -26,7 +26,7 @@ service SendRecvService { // pre-fetch variable by given variable name and Ids rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {} - rpc CheckpointNotify(CheckpointMessage) returns (VoidMessage) {} + rpc CheckpointNotify(VariableMessage) returns (VoidMessage) {} } // VariableMessage is serialized paddle variable message. @@ -83,6 +83,7 @@ message VariableMessage { message VoidMessage {} message CheckpointMessage { - string notify_type = 1; - string checkpoint_dir = 2; + string varname = 1; + string notify_type = 2; + string checkpoint_dir = 3; }