提交 925e2324 编写于 作者: T tangwei12

add RequestCheckpointNotify in grpc

上级 985026ce
...@@ -185,6 +185,37 @@ class RequestPrefetch final : public RequestBase { ...@@ -185,6 +185,37 @@ class RequestPrefetch final : public RequestBase {
framework::Scope* local_scope_; framework::Scope* local_scope_;
}; };
class RequestCheckpointNotify final : public RequestBase {
public:
explicit RequestCheckpointNotify(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id),
responder_(&ctx_),
local_scope_(nullptr) {
request_.reset(new VariableResponse(request_handler->scope(),
request_handler->dev_ctx(), true));
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
}
virtual ~RequestCheckpointNotify() {}
std::string GetReqName() override { return request_->Varname(); }
void Process() override {
auto scope = request_->GetMutableLocalScope();
std::string nullptr_str = nullptr;
framework::Variable* invar = nullptr;
framework::Variable* outvar = nullptr;
request_handler_->Handle(nullptr_str, scope, invar, &outvar, nullptr_str);
Finish(reply_, &responder_);
}
}
void AsyncGRPCServer::WaitServerReady() { void AsyncGRPCServer::WaitServerReady() {
VLOG(3) << "AsyncGRPCServer is wait server ready"; VLOG(3) << "AsyncGRPCServer is wait server ready";
std::unique_lock<std::mutex> lock(this->mutex_ready_); std::unique_lock<std::mutex> lock(this->mutex_ready_);
...@@ -288,6 +319,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name, ...@@ -288,6 +319,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
b = new RequestGet(&service_, cq.get(), handler, req_id); b = new RequestGet(&service_, cq.get(), handler, req_id);
} else if (rpc_name == kRequestPrefetch) { } else if (rpc_name == kRequestPrefetch) {
b = new RequestPrefetch(&service_, cq.get(), handler, req_id); b = new RequestPrefetch(&service_, cq.get(), handler, req_id);
} else if (rpc_name == kRequestCheckpoint) {
b = new RequestCheckpoin
} else { } else {
PADDLE_ENFORCE(false, "not supported rpc"); PADDLE_ENFORCE(false, "not supported rpc");
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册