grpc_server.cc 19.5 KB
Newer Older
1
/*Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
G
gongweibao 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15 16
#include <limits>
#include <string>
G
gongweibao 已提交
17

W
Wu Yi 已提交
18 19
#include "paddle/fluid/operators/distributed/grpc/grpc_serde.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_server.h"
G
gongweibao 已提交
20

21
using ::grpc::ServerAsyncResponseWriter;
X
Xin Pan 已提交
22

23 24
DECLARE_bool(rpc_disable_reuse_port);

G
gongweibao 已提交
25 26
namespace paddle {
namespace operators {
27
namespace distributed {
G
gongweibao 已提交
28 29 30 31 32 33
// Lifecycle state of a request object: RequestBase starts in PROCESS and
// RequestBase::Finish() switches it to FINISH once the reply has been queued.
enum CallStatus { PROCESS = 0, FINISH };

// reference:
// https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server
class RequestBase {
 public:
34
  explicit RequestBase(GrpcService::AsyncService* service,
35 36
                       ::grpc::ServerCompletionQueue* cq,
                       RequestHandler* request_handler, int req_id)
Q
qiaolongfei 已提交
37 38 39
      : service_(service),
        cq_(cq),
        status_(PROCESS),
40 41
        request_handler_(request_handler),
        req_id_(req_id) {
G
gongweibao 已提交
42 43
    PADDLE_ENFORCE(cq_);
  }
G
gongweibao 已提交
44
  virtual ~RequestBase() {}
45
  virtual void Process() = 0;
G
gongweibao 已提交
46

G
gongweibao 已提交
47 48 49 50 51 52 53 54 55 56 57 58 59
  std::string Status2String(const std::string& method) {
    std::string status = "Process";
    if (status_ == FINISH) {
      status = "Finish";
    }

    std::ostringstream s;
    s << method << " name:[" << GetReqName() << "]"
      << ", ep:[" << ctx_.peer() << "]"
      << " " << status << " using req_id:" << req_id_;
    return s.str();
  }

X
Xin Pan 已提交
60 61 62 63 64 65 66 67 68 69 70 71
  CallStatus Status() const {
    std::lock_guard<std::mutex> l(status_mu_);
    return status_;
  }

  template <typename T>
  void Finish(const T& reply, ServerAsyncResponseWriter<T>* responder) {
    std::lock_guard<std::mutex> l(status_mu_);
    status_ = FINISH;
    responder->Finish(reply, ::grpc::Status::OK,
                      reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
  }
72
  virtual std::string GetReqName() = 0;
G
gongweibao 已提交
73 74

 protected:
X
Xin Pan 已提交
75
  mutable std::mutex status_mu_;
76 77 78
  ::grpc::ServerContext ctx_;
  GrpcService::AsyncService* service_;
  ::grpc::ServerCompletionQueue* cq_;
G
gongweibao 已提交
79
  CallStatus status_;
80 81
  RequestHandler* request_handler_;
  int req_id_;
G
gongweibao 已提交
82 83 84 85
};

class RequestSend final : public RequestBase {
 public:
86
  explicit RequestSend(GrpcService::AsyncService* service,
87 88 89
                       ::grpc::ServerCompletionQueue* cq,
                       RequestHandler* request_handler, int req_id)
      : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
90 91 92
    request_.reset(new GRPCVariableResponse(request_handler->scope(),
                                            request_handler->dev_ctx(),
                                            !request_handler->sync_mode()));
93
    int method_id = static_cast<int>(distributed::GrpcMethod::kSendVariable);
X
Xin Pan 已提交
94 95
    service_->RequestAsyncUnary(
        method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
X
Xin Pan 已提交
96
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
G
gongweibao 已提交
97 98
  }
  virtual ~RequestSend() {}
99 100 101 102
  std::string GetReqName() override { return request_->Varname(); }

  void Process() override {
    std::string varname = GetReqName();
M
minqiyang 已提交
103
    VLOG(4) << "RequestSend var_name:" << varname;
G
gongweibao 已提交
104

105 106
    auto scope = request_->GetMutableLocalScope();
    auto invar = request_->GetVar();
W
Wu Yi 已提交
107
    int trainer_id = request_->GetTrainerId();
108
    framework::Variable* outvar = nullptr;
W
Wu Yi 已提交
109
    request_handler_->Handle(varname, scope, invar, &outvar, trainer_id);
X
Xin Pan 已提交
110
    Finish(reply_, &responder_);
G
gongweibao 已提交
111 112 113
  }

 protected:
X
Xin Pan 已提交
114
  sendrecv::VoidMessage reply_;
115
  std::shared_ptr<GRPCVariableResponse> request_;
G
gongweibao 已提交
116 117 118 119 120
  ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
};

class RequestGet final : public RequestBase {
 public:
121
  explicit RequestGet(GrpcService::AsyncService* service,
122 123 124
                      ::grpc::ServerCompletionQueue* cq,
                      RequestHandler* request_handler, int req_id)
      : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
125
    auto method_id = static_cast<int>(distributed::GrpcMethod::kGetVariable);
X
Xin Pan 已提交
126 127
    service_->RequestAsyncUnary(
        method_id, &ctx_, &request_, &responder_, cq_, cq_,
128
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
G
gongweibao 已提交
129 130 131 132
  }

  virtual ~RequestGet() {}

133
  std::string GetReqName() override { return request_.varname(); }
G
gongweibao 已提交
134

135
  void Process() override {
G
gongweibao 已提交
136
    // proc request.
137
    std::string varname = request_.varname();
138
    std::string out_varname = request_.out_varname();
W
Wu Yi 已提交
139
    int trainer_id = request_.trainer_id();
140 141

    VLOG(4) << "RequestGet " << out_varname << " from " << varname;
142 143

    auto scope = request_handler_->scope();
144
    framework::Variable* invar = nullptr;
145
    framework::Variable* outvar = nullptr;
146

Q
Qiao Longfei 已提交
147 148
    auto* tmp_scope = scope->NewTmpScope();
    request_handler_->Handle(varname, tmp_scope, invar, &outvar, trainer_id,
149
                             out_varname);
150 151

    if (outvar) {
152 153 154
      SerializeToByteBuffer(out_varname, outvar, *request_handler_->dev_ctx(),
                            &reply_);
    }
Q
Qiao Longfei 已提交
155
    delete tmp_scope;
156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198
    Finish(reply_, &responder_);
  }

 protected:
  sendrecv::VariableMessage request_;
  ::grpc::ByteBuffer reply_;
  ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
};

class RequestGetNoBarrier final : public RequestBase {
 public:
  explicit RequestGetNoBarrier(GrpcService::AsyncService* service,
                               ::grpc::ServerCompletionQueue* cq,
                               RequestHandler* request_handler, int req_id)
      : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
    auto method_id =
        static_cast<int>(distributed::GrpcMethod::kGetVariableNoBarrier);
    service_->RequestAsyncUnary(
        method_id, &ctx_, &request_, &responder_, cq_, cq_,
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
  }

  virtual ~RequestGetNoBarrier() {}

  std::string GetReqName() override { return request_.varname(); }

  void Process() override {
    // proc request.
    std::string varname = request_.varname();
    std::string out_varname = request_.out_varname();
    int trainer_id = request_.trainer_id();

    VLOG(4) << "RequestGetNoBarrier " << out_varname << " from " << varname;

    auto scope = request_handler_->scope();
    framework::Variable* invar = nullptr;
    framework::Variable* outvar = nullptr;

    request_handler_->Handle(varname, scope, invar, &outvar, trainer_id,
                             out_varname);

    if (outvar) {
      SerializeToByteBuffer(out_varname, outvar, *request_handler_->dev_ctx(),
199
                            &reply_);
200
    }
X
Xin Pan 已提交
201
    Finish(reply_, &responder_);
G
gongweibao 已提交
202 203 204 205
  }

 protected:
  sendrecv::VariableMessage request_;
X
Xin Pan 已提交
206
  ::grpc::ByteBuffer reply_;
207
  ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
G
gongweibao 已提交
208 209
};

210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301
class RequestGetMonomerVariable final : public RequestBase {
 public:
  explicit RequestGetMonomerVariable(GrpcService::AsyncService* service,
                                     ::grpc::ServerCompletionQueue* cq,
                                     RequestHandler* request_handler,
                                     int req_id, RPCServer* rpc_server)
      : RequestBase(service, cq, request_handler, req_id),
        responder_(&ctx_),
        rpc_server_(rpc_server) {
    auto method_id =
        static_cast<int>(distributed::GrpcMethod::kGetMonomerVariable);
    service_->RequestAsyncUnary(
        method_id, &ctx_, &request_, &responder_, cq_, cq_,
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
  }

  virtual ~RequestGetMonomerVariable() {}

  std::string GetReqName() override { return request_.varname(); }

  void Process() override {
    // proc request.
    std::string varname = request_.varname();

    rpc_server_->WaitVarCond(varname);
    MonomerHandle h = rpc_server_->GetMonomer(varname);

    auto scope = h.scope_;
    auto invar = scope->FindVar(varname);
    framework::Variable* outvar = nullptr;

    request_handler_->Handle(varname, scope, invar, &outvar,
                             request_.trainer_id());

    if (outvar) {
      SerializeToByteBuffer(varname, outvar, *h.dev_ctx_, &reply_);
    }
    Finish(reply_, &responder_);
  }

 protected:
  sendrecv::VariableMessage request_;
  ::grpc::ByteBuffer reply_;
  ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
  RPCServer* rpc_server_{nullptr};
};

class RequestGetMonomerBarrier final : public RequestBase {
 public:
  explicit RequestGetMonomerBarrier(GrpcService::AsyncService* service,
                                    ::grpc::ServerCompletionQueue* cq,
                                    RequestHandler* request_handler, int req_id,
                                    RPCServer* rpc_server)
      : RequestBase(service, cq, request_handler, req_id),
        responder_(&ctx_),
        rpc_server_(rpc_server) {
    auto method_id =
        static_cast<int>(distributed::GrpcMethod::kGetMonomerBarrier);
    service_->RequestAsyncUnary(
        method_id, &ctx_, &request_, &responder_, cq_, cq_,
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
  }

  virtual ~RequestGetMonomerBarrier() {}

  std::string GetReqName() override { return request_.varname(); }

  void Process() override {
    // proc request.
    std::string varname = request_.varname();
    VLOG(4) << "RequestGetMonomerBarrier " << varname;

    rpc_server_->WaitVarCond(varname);
    MonomerHandle h = rpc_server_->GetMonomer(varname);

    framework::Scope* scope = nullptr;
    framework::Variable* invar = nullptr;
    framework::Variable* outvar = nullptr;

    request_handler_->Handle(varname, scope, invar, &outvar,
                             request_.trainer_id());

    Finish(reply_, &responder_);
  }

 protected:
  sendrecv::VariableMessage request_;
  sendrecv::VoidMessage reply_;
  ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
  RPCServer* rpc_server_{nullptr};
};

302 303 304
class RequestPrefetch final : public RequestBase {
 public:
  explicit RequestPrefetch(GrpcService::AsyncService* service,
305 306 307
                           ::grpc::ServerCompletionQueue* cq,
                           RequestHandler* request_handler, int req_id)
      : RequestBase(service, cq, request_handler, req_id),
308
        responder_(&ctx_),
309
        local_scope_(nullptr) {
310 311
    request_.reset(new GRPCVariableResponse(request_handler->scope(),
                                            request_handler->dev_ctx(), true));
312 313
    int method_id =
        static_cast<int>(distributed::GrpcMethod::kPrefetchVariable);
X
Xin Pan 已提交
314 315
    service_->RequestAsyncUnary(
        method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
316
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
317 318 319 320
  }

  virtual ~RequestPrefetch() {}

321
  std::string GetReqName() override { return request_->Varname(); }
322

323
  void Process() override {
324
    // prefetch process...
325 326
    std::string in_var_name = request_->Varname();
    std::string out_var_name = request_->OutVarname();
Q
Qiao Longfei 已提交
327
    std::string table_name = request_->TableName();
W
Wu Yi 已提交
328
    int trainer_id = request_->GetTrainerId();
M
minqiyang 已提交
329 330
    VLOG(4) << "RequestPrefetch, in_var_name: " << in_var_name
            << " out_var_name: " << out_var_name;
331 332

    auto scope = request_->GetMutableLocalScope();
333
    auto invar = scope->FindVar(in_var_name);
334
    // out var must be created in local scope!
Q
qiaolongfei 已提交
335
    framework::Variable* outvar = scope->Var(out_var_name);
336

W
Wu Yi 已提交
337
    request_handler_->Handle(in_var_name, scope, invar, &outvar, trainer_id,
Q
can run  
Qiao Longfei 已提交
338
                             out_var_name, table_name);
Y
Yancey1989 已提交
339

340
    SerializeToByteBuffer(out_var_name, outvar, *request_handler_->dev_ctx(),
341
                          &reply_);
X
Xin Pan 已提交
342
    Finish(reply_, &responder_);
343 344 345
  }

 protected:
346
  std::shared_ptr<GRPCVariableResponse> request_;
X
Xin Pan 已提交
347
  ::grpc::ByteBuffer reply_;
348
  ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
349
  framework::Scope* local_scope_;
350 351
};

T
tangwei12 已提交
352 353 354 355 356
class RequestCheckpointNotify final : public RequestBase {
 public:
  explicit RequestCheckpointNotify(GrpcService::AsyncService* service,
                                   ::grpc::ServerCompletionQueue* cq,
                                   RequestHandler* request_handler, int req_id)
T
tangwei12 已提交
357
      : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
358 359
    request_.reset(new GRPCVariableResponse(request_handler->scope(),
                                            request_handler->dev_ctx()));
T
tangwei12 已提交
360 361
    int method_id =
        static_cast<int>(distributed::GrpcMethod::kCheckpointNotify);
T
tangwei12 已提交
362 363 364 365 366 367 368
    service_->RequestAsyncUnary(
        method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
  }

  virtual ~RequestCheckpointNotify() {}

369
  std::string GetReqName() override { return request_->Varname(); }
T
tangwei12 已提交
370 371 372

  void Process() override {
    auto scope = request_->GetMutableLocalScope();
373 374

    std::string checkpoint_notify = request_->Varname();
T
tangwei12 已提交
375
    std::string checkpoint_dir = request_->OutVarname();
W
Wu Yi 已提交
376
    int trainer_id = request_->GetTrainerId();
377

M
minqiyang 已提交
378 379
    VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify
            << ", dir: " << checkpoint_dir;
T
tangwei12 已提交
380

T
tangwei12 已提交
381
    request_handler_->Handle(checkpoint_notify, scope, nullptr, nullptr,
W
Wu Yi 已提交
382
                             trainer_id, checkpoint_dir);
T
tangwei12 已提交
383 384
    Finish(reply_, &responder_);
  }
T
tangwei12 已提交
385 386

 protected:
387
  std::shared_ptr<GRPCVariableResponse> request_;
T
tangwei12 已提交
388 389
  sendrecv::VoidMessage reply_;
  ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
T
tangwei12 已提交
390
};
T
tangwei12 已提交
391

T
done  
typhoonzero 已提交
392
void AsyncGRPCServer::WaitServerReady() {
393
  VLOG(4) << "AsyncGRPCServer is waiting server ready";
T
update  
typhoonzero 已提交
394
  std::unique_lock<std::mutex> lock(this->mutex_ready_);
T
done  
typhoonzero 已提交
395
  condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
M
minqiyang 已提交
396
  VLOG(4) << "AsyncGRPCServer WaitSeverReady";
T
update  
typhoonzero 已提交
397 398
}

399 400 401 402 403 404 405 406 407 408 409 410 411 412
// ServerBuilderOption that turns off SO_REUSEPORT on the server socket by
// setting GRPC_ARG_ALLOW_REUSEPORT to 0.
// Adapted from:
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
class NoReusePortOption : public ::grpc::ServerBuilderOption {
 public:
  void UpdateArguments(::grpc::ChannelArguments* args) override {
    const int reuse_port_disabled = 0;
    args->SetInt(GRPC_ARG_ALLOW_REUSEPORT, reuse_port_disabled);
  }

  // This option does not touch the plugin list.
  void UpdatePlugins(
      std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>>* /*plugins*/)
      override {}
};

413
void AsyncGRPCServer::StartServer() {
414
  ::grpc::ServerBuilder builder;
415
  builder.AddListeningPort(bind_address_, ::grpc::InsecureServerCredentials(),
T
typhoonzero 已提交
416
                           &selected_port_);
417

G
gongweibao 已提交
418 419
  builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
420 421 422 423
  if (FLAGS_rpc_disable_reuse_port) {
    builder.SetOption(
        std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption));
  }
G
gongweibao 已提交
424 425
  builder.RegisterService(&service_);

426 427 428
  for (auto t : rpc_call_map_) {
    rpc_cq_[t.first].reset(builder.AddCompletionQueue().release());
  }
Y
Yancey 已提交
429

G
gongweibao 已提交
430
  server_ = builder.BuildAndStart();
431
  LOG(INFO) << "Server listening on " << bind_address_
T
typhoonzero 已提交
432
            << " selected port: " << selected_port_;
G
gongweibao 已提交
433

434 435 436
  std::function<void(const std::string&, int)> f =
      std::bind(&AsyncGRPCServer::TryToRegisterNewOne, this,
                std::placeholders::_1, std::placeholders::_2);
X
Xin Pan 已提交
437

438 439 440 441 442
  for (auto& t : rpc_call_map_) {
    auto& rpc_name = t.first;
    auto& cq = rpc_cq_[rpc_name];
    auto threadnum = rpc_thread_num_[rpc_name];
    auto& reqs = rpc_reqs_[rpc_name];
X
Xin Pan 已提交
443

444 445 446
    reqs.reserve(kRequestBufSize);

    for (int i = 0; i < kRequestBufSize; i++) {
M
minqiyang 已提交
447
      VLOG(6) << "TryToRegisterNewOne on RPC NAME: " << rpc_name << " I: " << i;
448 449 450 451 452 453
      TryToRegisterNewOne(rpc_name, i);
    }

    for (int i = 0; i < threadnum; i++) {
      rpc_threads_[rpc_name].emplace_back(new std::thread(std::bind(
          &AsyncGRPCServer::HandleRequest, this, cq.get(), rpc_name, f)));
M
minqiyang 已提交
454
      VLOG(4) << t.first << " creates threads!";
455
    }
X
Xin Pan 已提交
456
  }
457

T
wip  
typhoonzero 已提交
458 459 460 461 462
  {
    std::lock_guard<std::mutex> lock(this->mutex_ready_);
    ready_ = 1;
  }
  condition_ready_.notify_all();
463

G
gongweibao 已提交
464 465
  // wait server
  server_->Wait();
466 467 468 469 470

  for (auto& t : rpc_threads_) {
    auto& threads = t.second;
    for (size_t i = 0; i < threads.size(); ++i) {
      threads[i]->join();
M
minqiyang 已提交
471
      VLOG(4) << t.first << " threads ends!";
472
    }
X
Xin Pan 已提交
473
  }
G
gongweibao 已提交
474 475 476
}

void AsyncGRPCServer::ShutdownQueue() {
477 478
  for (auto& t : rpc_cq_) {
    t.second->Shutdown();
M
minqiyang 已提交
479
    VLOG(4) << t.first << " queue shutdown!";
480
  }
G
gongweibao 已提交
481 482
}

483 484
void AsyncGRPCServer::ShutDownImpl() {
  std::unique_lock<std::mutex> lock(cq_mutex_);
T
typhoonzero 已提交
485
  is_shut_down_ = true;
G
gongweibao 已提交
486
  ShutdownQueue();
487

M
minqiyang 已提交
488
  VLOG(4) << "server_ shutdown!";
T
typhoonzero 已提交
489
  server_->Shutdown();
G
gongweibao 已提交
490 491
}

492 493
void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
                                          int req_id) {
G
gongweibao 已提交
494 495
  std::unique_lock<std::mutex> lock(cq_mutex_);
  if (is_shut_down_) {
M
minqiyang 已提交
496
    VLOG(4) << "shutdown, do not TryToRegisterNewSendOne";
G
gongweibao 已提交
497 498 499
    return;
  }

M
minqiyang 已提交
500 501
  VLOG(4) << "TryToRegisterNewOne on RPC NAME: " << rpc_name
          << " REQ ID: " << req_id;
T
tangwei12 已提交
502

503 504 505 506 507 508 509 510 511
  auto& reqs = rpc_reqs_[rpc_name];
  auto& handler = rpc_call_map_[rpc_name];
  auto& cq = rpc_cq_[rpc_name];

  RequestBase* b = nullptr;
  if (rpc_name == kRequestSend) {
    b = new RequestSend(&service_, cq.get(), handler, req_id);
  } else if (rpc_name == kRequestGet) {
    b = new RequestGet(&service_, cq.get(), handler, req_id);
512 513 514

  } else if (rpc_name == kRequestGetNoBarrier) {
    b = new RequestGetNoBarrier(&service_, cq.get(), handler, req_id);
515 516 517 518 519 520
  } else if (rpc_name == kRequestGetMonomerVariable) {
    b = new RequestGetMonomerVariable(&service_, cq.get(), handler, req_id,
                                      this);
  } else if (rpc_name == kRequestGetMonomerBarrier) {
    b = new RequestGetMonomerBarrier(&service_, cq.get(), handler, req_id,
                                     this);
521 522
  } else if (rpc_name == kRequestPrefetch) {
    b = new RequestPrefetch(&service_, cq.get(), handler, req_id);
T
tangwei12 已提交
523
  } else if (rpc_name == kRequestCheckpoint) {
T
tangwei12 已提交
524
    b = new RequestCheckpointNotify(&service_, cq.get(), handler, req_id);
525
  } else {
Q
qiaolongfei 已提交
526
    PADDLE_ENFORCE(false, "not supported rpc");
G
gongweibao 已提交
527 528
  }

529
  reqs[req_id] = b;
530

531
  VLOG(4) << "TryToRegisterNewOne status:" << b->Status();
532 533
}

X
Xin Pan 已提交
534
void AsyncGRPCServer::HandleRequest(
535 536
    ::grpc::ServerCompletionQueue* cq, const std::string& rpc_name,
    std::function<void(const std::string&, int)> TryToRegisterNewOne) {
G
gongweibao 已提交
537 538
  void* tag = NULL;
  bool ok = false;
539

G
gongweibao 已提交
540
  while (true) {
M
minqiyang 已提交
541
    VLOG(4) << "HandleRequest " << rpc_name << " wait next";
G
gongweibao 已提交
542
    if (!cq->Next(&tag, &ok)) {
G
gongweibao 已提交
543
      LOG(WARNING) << "CompletionQueue " << rpc_name << " shutdown!";
G
gongweibao 已提交
544 545
      break;
    }
Q
qiaolongfei 已提交
546

547
    int req_id = static_cast<int>(reinterpret_cast<intptr_t>(tag));
M
minqiyang 已提交
548 549
    VLOG(4) << "HandleRequest " << rpc_name << ", req_id:" << req_id
            << " get next";
G
gongweibao 已提交
550

551
    auto& reqs = rpc_reqs_[rpc_name];
X
Xin Pan 已提交
552 553
    RequestBase* base = nullptr;
    {
554 555 556
      PADDLE_ENFORCE(req_id >= 0 && req_id < kRequestBufSize);
      std::unique_lock<std::mutex> lock(cq_mutex_);
      base = reqs[req_id];
X
Xin Pan 已提交
557
    }
558

M
minqiyang 已提交
559
    VLOG(3) << base->Status2String(rpc_name);
G
gongweibao 已提交
560

G
gongweibao 已提交
561 562 563 564
    // reference:
    // https://github.com/tensorflow/tensorflow/issues/5596
    // https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
    // https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
G
gongweibao 已提交
565
    if (!ok) {
G
gongweibao 已提交
566 567
      VLOG(4) << "completion queue:" << rpc_name << " recv no regular event"
              << " context:" << base->Status2String(rpc_name);
568
      TryToRegisterNewOne(rpc_name, req_id);
G
gongweibao 已提交
569 570 571 572 573 574 575 576 577 578
      delete base;
      continue;
    }

    switch (base->Status()) {
      case PROCESS: {
        base->Process();
        break;
      }
      case FINISH: {
579
        TryToRegisterNewOne(rpc_name, req_id);
G
gongweibao 已提交
580 581 582 583 584 585 586 587
        delete base;
        break;
      }
      default: { assert(false); }
    }
  }
}

588
}  // namespace distributed
G
gongweibao 已提交
589 590
}  // namespace operators
}  // namespace paddle