grpc_server.cc 14.3 KB
Newer Older
1
/*Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
G
gongweibao 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15 16
#include <limits>
#include <string>
G
gongweibao 已提交
17

18
#include "paddle/fluid/operators/distributed/grpc_serde.h"
19
#include "paddle/fluid/operators/distributed/grpc_server.h"
G
gongweibao 已提交
20

21
using ::grpc::ServerAsyncResponseWriter;
X
Xin Pan 已提交
22

23 24
DECLARE_bool(rpc_disable_reuse_port);

G
gongweibao 已提交
25 26
namespace paddle {
namespace operators {
27
namespace distributed {
G
gongweibao 已提交
28 29 30 31 32 33
// State of an RPC request object: it is constructed as PROCESS and switched
// to FINISH by RequestBase::Finish() when the reply is handed to gRPC.
enum CallStatus { PROCESS = 0, FINISH };

// reference:
// https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server
class RequestBase {
 public:
34
  explicit RequestBase(GrpcService::AsyncService* service,
35 36
                       ::grpc::ServerCompletionQueue* cq,
                       RequestHandler* request_handler, int req_id)
Q
qiaolongfei 已提交
37 38 39
      : service_(service),
        cq_(cq),
        status_(PROCESS),
40 41
        request_handler_(request_handler),
        req_id_(req_id) {
G
gongweibao 已提交
42 43
    PADDLE_ENFORCE(cq_);
  }
G
gongweibao 已提交
44
  virtual ~RequestBase() {}
45
  virtual void Process() = 0;
G
gongweibao 已提交
46

G
gongweibao 已提交
47 48 49 50 51 52 53 54 55 56 57 58 59
  std::string Status2String(const std::string& method) {
    std::string status = "Process";
    if (status_ == FINISH) {
      status = "Finish";
    }

    std::ostringstream s;
    s << method << " name:[" << GetReqName() << "]"
      << ", ep:[" << ctx_.peer() << "]"
      << " " << status << " using req_id:" << req_id_;
    return s.str();
  }

X
Xin Pan 已提交
60 61 62 63 64 65 66 67 68 69 70 71
  CallStatus Status() const {
    std::lock_guard<std::mutex> l(status_mu_);
    return status_;
  }

  template <typename T>
  void Finish(const T& reply, ServerAsyncResponseWriter<T>* responder) {
    std::lock_guard<std::mutex> l(status_mu_);
    status_ = FINISH;
    responder->Finish(reply, ::grpc::Status::OK,
                      reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
  }
72
  virtual std::string GetReqName() = 0;
G
gongweibao 已提交
73 74

 protected:
X
Xin Pan 已提交
75
  mutable std::mutex status_mu_;
76 77 78
  ::grpc::ServerContext ctx_;
  GrpcService::AsyncService* service_;
  ::grpc::ServerCompletionQueue* cq_;
G
gongweibao 已提交
79
  CallStatus status_;
80 81
  RequestHandler* request_handler_;
  int req_id_;
G
gongweibao 已提交
82 83 84 85
};

class RequestSend final : public RequestBase {
 public:
86
  explicit RequestSend(GrpcService::AsyncService* service,
87 88 89
                       ::grpc::ServerCompletionQueue* cq,
                       RequestHandler* request_handler, int req_id)
      : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
90 91 92
    request_.reset(new GRPCVariableResponse(request_handler->scope(),
                                            request_handler->dev_ctx(),
                                            !request_handler->sync_mode()));
93
    int method_id = static_cast<int>(distributed::GrpcMethod::kSendVariable);
X
Xin Pan 已提交
94 95
    service_->RequestAsyncUnary(
        method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
X
Xin Pan 已提交
96
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
G
gongweibao 已提交
97 98
  }
  virtual ~RequestSend() {}
99 100 101 102
  std::string GetReqName() override { return request_->Varname(); }

  void Process() override {
    std::string varname = GetReqName();
M
minqiyang 已提交
103
    VLOG(4) << "RequestSend var_name:" << varname;
G
gongweibao 已提交
104

105 106
    auto scope = request_->GetMutableLocalScope();
    auto invar = request_->GetVar();
W
Wu Yi 已提交
107
    int trainer_id = request_->GetTrainerId();
108 109
    framework::Variable* outvar = nullptr;

W
Wu Yi 已提交
110
    request_handler_->Handle(varname, scope, invar, &outvar, trainer_id);
X
Xin Pan 已提交
111
    Finish(reply_, &responder_);
G
gongweibao 已提交
112 113 114
  }

 protected:
X
Xin Pan 已提交
115
  sendrecv::VoidMessage reply_;
116
  std::shared_ptr<GRPCVariableResponse> request_;
G
gongweibao 已提交
117 118 119 120 121
  ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
};

class RequestGet final : public RequestBase {
 public:
122
  explicit RequestGet(GrpcService::AsyncService* service,
123 124 125
                      ::grpc::ServerCompletionQueue* cq,
                      RequestHandler* request_handler, int req_id)
      : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
126
    auto method_id = static_cast<int>(distributed::GrpcMethod::kGetVariable);
X
Xin Pan 已提交
127 128
    service_->RequestAsyncUnary(
        method_id, &ctx_, &request_, &responder_, cq_, cq_,
129
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
G
gongweibao 已提交
130 131 132 133
  }

  virtual ~RequestGet() {}

134
  std::string GetReqName() override { return request_.varname(); }
G
gongweibao 已提交
135

136
  void Process() override {
G
gongweibao 已提交
137
    // proc request.
138
    std::string varname = request_.varname();
W
Wu Yi 已提交
139
    int trainer_id = request_.trainer_id();
M
minqiyang 已提交
140
    VLOG(4) << "RequestGet " << varname;
141 142 143 144

    auto scope = request_handler_->scope();
    auto invar = scope->FindVar(varname);
    framework::Variable* outvar = nullptr;
145

W
Wu Yi 已提交
146
    request_handler_->Handle(varname, scope, invar, &outvar, trainer_id);
147 148 149 150

    if (outvar) {
      SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
                            &reply_);
151
    }
X
Xin Pan 已提交
152
    Finish(reply_, &responder_);
G
gongweibao 已提交
153 154 155 156
  }

 protected:
  sendrecv::VariableMessage request_;
X
Xin Pan 已提交
157
  ::grpc::ByteBuffer reply_;
158
  ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
G
gongweibao 已提交
159 160
};

161 162 163
class RequestPrefetch final : public RequestBase {
 public:
  explicit RequestPrefetch(GrpcService::AsyncService* service,
164 165 166
                           ::grpc::ServerCompletionQueue* cq,
                           RequestHandler* request_handler, int req_id)
      : RequestBase(service, cq, request_handler, req_id),
167
        responder_(&ctx_),
168
        local_scope_(nullptr) {
169 170
    request_.reset(new GRPCVariableResponse(request_handler->scope(),
                                            request_handler->dev_ctx(), true));
171 172
    int method_id =
        static_cast<int>(distributed::GrpcMethod::kPrefetchVariable);
X
Xin Pan 已提交
173 174
    service_->RequestAsyncUnary(
        method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
175
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
176 177 178 179
  }

  virtual ~RequestPrefetch() {}

180
  std::string GetReqName() override { return request_->Varname(); }
181

182
  void Process() override {
183
    // prefetch process...
184 185
    std::string in_var_name = request_->Varname();
    std::string out_var_name = request_->OutVarname();
W
Wu Yi 已提交
186
    int trainer_id = request_->GetTrainerId();
M
minqiyang 已提交
187 188
    VLOG(4) << "RequestPrefetch, in_var_name: " << in_var_name
            << " out_var_name: " << out_var_name;
189 190

    auto scope = request_->GetMutableLocalScope();
191
    auto invar = scope->FindVar(in_var_name);
192
    // out var must be created in local scope!
Q
qiaolongfei 已提交
193
    framework::Variable* outvar = scope->Var(out_var_name);
194

W
Wu Yi 已提交
195 196
    request_handler_->Handle(in_var_name, scope, invar, &outvar, trainer_id,
                             out_var_name);
Y
Yancey1989 已提交
197

198
    SerializeToByteBuffer(out_var_name, outvar, *request_handler_->dev_ctx(),
199
                          &reply_);
X
Xin Pan 已提交
200
    Finish(reply_, &responder_);
201 202 203
  }

 protected:
204
  std::shared_ptr<GRPCVariableResponse> request_;
X
Xin Pan 已提交
205
  ::grpc::ByteBuffer reply_;
206
  ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
207
  framework::Scope* local_scope_;
208 209
};

T
tangwei12 已提交
210 211 212 213 214
class RequestCheckpointNotify final : public RequestBase {
 public:
  explicit RequestCheckpointNotify(GrpcService::AsyncService* service,
                                   ::grpc::ServerCompletionQueue* cq,
                                   RequestHandler* request_handler, int req_id)
T
tangwei12 已提交
215
      : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
216 217
    request_.reset(new GRPCVariableResponse(request_handler->scope(),
                                            request_handler->dev_ctx()));
T
tangwei12 已提交
218 219
    int method_id =
        static_cast<int>(distributed::GrpcMethod::kCheckpointNotify);
T
tangwei12 已提交
220 221 222 223 224 225 226
    service_->RequestAsyncUnary(
        method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
  }

  virtual ~RequestCheckpointNotify() {}

227
  std::string GetReqName() override { return request_->Varname(); }
T
tangwei12 已提交
228 229 230

  void Process() override {
    auto scope = request_->GetMutableLocalScope();
231 232

    std::string checkpoint_notify = request_->Varname();
T
tangwei12 已提交
233
    std::string checkpoint_dir = request_->OutVarname();
W
Wu Yi 已提交
234
    int trainer_id = request_->GetTrainerId();
235

M
minqiyang 已提交
236 237
    VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify
            << ", dir: " << checkpoint_dir;
T
tangwei12 已提交
238

T
tangwei12 已提交
239
    request_handler_->Handle(checkpoint_notify, scope, nullptr, nullptr,
W
Wu Yi 已提交
240
                             trainer_id, checkpoint_dir);
T
tangwei12 已提交
241 242
    Finish(reply_, &responder_);
  }
T
tangwei12 已提交
243 244

 protected:
245
  std::shared_ptr<GRPCVariableResponse> request_;
T
tangwei12 已提交
246 247
  sendrecv::VoidMessage reply_;
  ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
T
tangwei12 已提交
248
};
T
tangwei12 已提交
249

T
done  
typhoonzero 已提交
250
void AsyncGRPCServer::WaitServerReady() {
M
minqiyang 已提交
251
  VLOG(4) << "AsyncGRPCServer is wait server ready";
T
update  
typhoonzero 已提交
252
  std::unique_lock<std::mutex> lock(this->mutex_ready_);
T
done  
typhoonzero 已提交
253
  condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
M
minqiyang 已提交
254
  VLOG(4) << "AsyncGRPCServer WaitSeverReady";
T
update  
typhoonzero 已提交
255 256
}

257 258 259 260 261 262 263 264 265 266 267 268 269 270
// Server builder option that turns off SO_REUSEPORT on the server socket.
// Adapted from:
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
class NoReusePortOption : public ::grpc::ServerBuilderOption {
 public:
  // Sets GRPC_ARG_ALLOW_REUSEPORT to 0 in the channel arguments.
  void UpdateArguments(::grpc::ChannelArguments* args) override {
    constexpr int kReusePortDisabled = 0;
    args->SetInt(GRPC_ARG_ALLOW_REUSEPORT, kReusePortDisabled);
  }

  // This option does not touch the plugin list.
  void UpdatePlugins(
      std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>>* /*plugins*/)
      override {}
};

271
void AsyncGRPCServer::StartServer() {
272
  ::grpc::ServerBuilder builder;
273
  builder.AddListeningPort(bind_address_, ::grpc::InsecureServerCredentials(),
T
typhoonzero 已提交
274
                           &selected_port_);
275

G
gongweibao 已提交
276 277
  builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
278 279 280 281
  if (FLAGS_rpc_disable_reuse_port) {
    builder.SetOption(
        std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption));
  }
G
gongweibao 已提交
282 283
  builder.RegisterService(&service_);

284 285 286
  for (auto t : rpc_call_map_) {
    rpc_cq_[t.first].reset(builder.AddCompletionQueue().release());
  }
Y
Yancey 已提交
287

G
gongweibao 已提交
288
  server_ = builder.BuildAndStart();
289
  LOG(INFO) << "Server listening on " << bind_address_
T
typhoonzero 已提交
290
            << " selected port: " << selected_port_;
G
gongweibao 已提交
291

292 293 294
  std::function<void(const std::string&, int)> f =
      std::bind(&AsyncGRPCServer::TryToRegisterNewOne, this,
                std::placeholders::_1, std::placeholders::_2);
X
Xin Pan 已提交
295

296 297 298 299 300
  for (auto& t : rpc_call_map_) {
    auto& rpc_name = t.first;
    auto& cq = rpc_cq_[rpc_name];
    auto threadnum = rpc_thread_num_[rpc_name];
    auto& reqs = rpc_reqs_[rpc_name];
X
Xin Pan 已提交
301

302 303 304
    reqs.reserve(kRequestBufSize);

    for (int i = 0; i < kRequestBufSize; i++) {
M
minqiyang 已提交
305
      VLOG(6) << "TryToRegisterNewOne on RPC NAME: " << rpc_name << " I: " << i;
306 307 308 309 310 311
      TryToRegisterNewOne(rpc_name, i);
    }

    for (int i = 0; i < threadnum; i++) {
      rpc_threads_[rpc_name].emplace_back(new std::thread(std::bind(
          &AsyncGRPCServer::HandleRequest, this, cq.get(), rpc_name, f)));
M
minqiyang 已提交
312
      VLOG(4) << t.first << " creates threads!";
313
    }
X
Xin Pan 已提交
314
  }
315

T
wip  
typhoonzero 已提交
316 317 318 319 320
  {
    std::lock_guard<std::mutex> lock(this->mutex_ready_);
    ready_ = 1;
  }
  condition_ready_.notify_all();
321

G
gongweibao 已提交
322 323
  // wait server
  server_->Wait();
324 325 326 327 328

  for (auto& t : rpc_threads_) {
    auto& threads = t.second;
    for (size_t i = 0; i < threads.size(); ++i) {
      threads[i]->join();
M
minqiyang 已提交
329
      VLOG(4) << t.first << " threads ends!";
330
    }
X
Xin Pan 已提交
331
  }
G
gongweibao 已提交
332 333 334
}

void AsyncGRPCServer::ShutdownQueue() {
335 336
  for (auto& t : rpc_cq_) {
    t.second->Shutdown();
M
minqiyang 已提交
337
    VLOG(4) << t.first << " queue shutdown!";
338
  }
G
gongweibao 已提交
339 340
}

341 342
void AsyncGRPCServer::ShutDownImpl() {
  std::unique_lock<std::mutex> lock(cq_mutex_);
T
typhoonzero 已提交
343
  is_shut_down_ = true;
G
gongweibao 已提交
344
  ShutdownQueue();
345

M
minqiyang 已提交
346
  VLOG(4) << "server_ shutdown!";
T
typhoonzero 已提交
347
  server_->Shutdown();
G
gongweibao 已提交
348 349
}

350 351
void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
                                          int req_id) {
G
gongweibao 已提交
352 353
  std::unique_lock<std::mutex> lock(cq_mutex_);
  if (is_shut_down_) {
M
minqiyang 已提交
354
    VLOG(4) << "shutdown, do not TryToRegisterNewSendOne";
G
gongweibao 已提交
355 356 357
    return;
  }

M
minqiyang 已提交
358 359
  VLOG(4) << "TryToRegisterNewOne on RPC NAME: " << rpc_name
          << " REQ ID: " << req_id;
T
tangwei12 已提交
360

361 362 363 364 365 366 367 368 369 370 371
  auto& reqs = rpc_reqs_[rpc_name];
  auto& handler = rpc_call_map_[rpc_name];
  auto& cq = rpc_cq_[rpc_name];

  RequestBase* b = nullptr;
  if (rpc_name == kRequestSend) {
    b = new RequestSend(&service_, cq.get(), handler, req_id);
  } else if (rpc_name == kRequestGet) {
    b = new RequestGet(&service_, cq.get(), handler, req_id);
  } else if (rpc_name == kRequestPrefetch) {
    b = new RequestPrefetch(&service_, cq.get(), handler, req_id);
T
tangwei12 已提交
372
  } else if (rpc_name == kRequestCheckpoint) {
T
tangwei12 已提交
373
    b = new RequestCheckpointNotify(&service_, cq.get(), handler, req_id);
374
  } else {
Q
qiaolongfei 已提交
375
    PADDLE_ENFORCE(false, "not supported rpc");
G
gongweibao 已提交
376 377
  }

378
  reqs[req_id] = b;
379

M
minqiyang 已提交
380
  VLOG(4) << "Create RequestSend status:" << b->Status();
381 382
}

X
Xin Pan 已提交
383
void AsyncGRPCServer::HandleRequest(
384 385
    ::grpc::ServerCompletionQueue* cq, const std::string& rpc_name,
    std::function<void(const std::string&, int)> TryToRegisterNewOne) {
G
gongweibao 已提交
386 387
  void* tag = NULL;
  bool ok = false;
388

G
gongweibao 已提交
389
  while (true) {
M
minqiyang 已提交
390
    VLOG(4) << "HandleRequest " << rpc_name << " wait next";
G
gongweibao 已提交
391
    if (!cq->Next(&tag, &ok)) {
M
minqiyang 已提交
392
      VLOG(3) << "CompletionQueue " << rpc_name << " shutdown!";
G
gongweibao 已提交
393 394
      break;
    }
Q
qiaolongfei 已提交
395

396
    int req_id = static_cast<int>(reinterpret_cast<intptr_t>(tag));
M
minqiyang 已提交
397 398
    VLOG(4) << "HandleRequest " << rpc_name << ", req_id:" << req_id
            << " get next";
G
gongweibao 已提交
399

400
    auto& reqs = rpc_reqs_[rpc_name];
X
Xin Pan 已提交
401 402
    RequestBase* base = nullptr;
    {
403 404 405
      PADDLE_ENFORCE(req_id >= 0 && req_id < kRequestBufSize);
      std::unique_lock<std::mutex> lock(cq_mutex_);
      base = reqs[req_id];
X
Xin Pan 已提交
406
    }
407

M
minqiyang 已提交
408
    VLOG(3) << base->Status2String(rpc_name);
G
gongweibao 已提交
409

G
gongweibao 已提交
410 411 412 413
    // reference:
    // https://github.com/tensorflow/tensorflow/issues/5596
    // https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
    // https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
G
gongweibao 已提交
414
    if (!ok) {
415
      LOG(WARNING) << "completion queue:" << rpc_name
G
gongweibao 已提交
416 417
                   << " recv no regular event"
                   << " context:" << base->Status2String(rpc_name);
418
      TryToRegisterNewOne(rpc_name, req_id);
G
gongweibao 已提交
419 420 421 422 423 424 425 426 427 428
      delete base;
      continue;
    }

    switch (base->Status()) {
      case PROCESS: {
        base->Process();
        break;
      }
      case FINISH: {
429
        TryToRegisterNewOne(rpc_name, req_id);
G
gongweibao 已提交
430 431 432 433 434 435 436 437
        delete base;
        break;
      }
      default: { assert(false); }
    }
  }
}

438
}  // namespace distributed
G
gongweibao 已提交
439 440
}  // namespace operators
}  // namespace paddle