From 925e2324b3b7920d66217ec2c870010f24c0967a Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 18 Jun 2018 10:50:12 +0800 Subject: [PATCH] add RequestCheckpointNotify in grpc --- paddle/fluid/operators/detail/grpc_server.cc | 33 ++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 2d34f8583..3e5625fa2 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -185,6 +185,37 @@ class RequestPrefetch final : public RequestBase { framework::Scope* local_scope_; }; +class RequestCheckpointNotify final : public RequestBase { + public: + explicit RequestCheckpointNotify(GrpcService::AsyncService* service, + ::grpc::ServerCompletionQueue* cq, + RequestHandler* request_handler, int req_id) + : RequestBase(service, cq, request_handler, req_id), + responder_(&ctx_), + local_scope_(nullptr) { + request_.reset(new VariableResponse(request_handler->scope(), + request_handler->dev_ctx(), true)); + int method_id = static_cast(detail::GrpcMethod::kPrefetchVariable); + service_->RequestAsyncUnary( + method_id, &ctx_, request_.get(), &responder_, cq_, cq_, + reinterpret_cast(static_cast(req_id))); + } + + virtual ~RequestCheckpointNotify() {} + + std::string GetReqName() override { return request_->Varname(); } + + void Process() override { + auto scope = request_->GetMutableLocalScope(); + std::string nullptr_str = nullptr; + framework::Variable* invar = nullptr; + framework::Variable* outvar = nullptr; + + request_handler_->Handle(nullptr_str, scope, invar, &outvar, nullptr_str); + Finish(reply_, &responder_); + } +} + void AsyncGRPCServer::WaitServerReady() { VLOG(3) << "AsyncGRPCServer is wait server ready"; std::unique_lock lock(this->mutex_ready_); @@ -288,6 +319,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name, b = new RequestGet(&service_, cq.get(), handler, req_id); } else if (rpc_name == kRequestPrefetch) { b = new RequestPrefetch(&service_, cq.get(), handler, req_id); + } else if (rpc_name == kRequestCheckpoint) { + b = new RequestCheckpoin } else { PADDLE_ENFORCE(false, "not supported rpc"); } -- GitLab