/*Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cassert>
#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>
#include <vector>

#include "paddle/fluid/operators/distributed/grpc_serde.h"
#include "paddle/fluid/operators/distributed/grpc_server.h"

using ::grpc::ServerAsyncResponseWriter;

DECLARE_bool(rpc_disable_reuse_port);

namespace paddle {
namespace operators {
namespace distributed {
// Lifecycle of a request object driven by HandleRequest():
//   PROCESS - initial state; the next completion event runs Process().
//   FINISH  - set by RequestBase::Finish() once the reply has been handed to
//             gRPC; the next completion event re-arms the slot and deletes
//             the object.
enum CallStatus {
  PROCESS = 0,
  FINISH
};

// reference:
// https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server
class RequestBase {
 public:
34
  explicit RequestBase(GrpcService::AsyncService* service,
35 36
                       ::grpc::ServerCompletionQueue* cq,
                       RequestHandler* request_handler, int req_id)
Q
qiaolongfei 已提交
37 38 39
      : service_(service),
        cq_(cq),
        status_(PROCESS),
40 41
        request_handler_(request_handler),
        req_id_(req_id) {
G
gongweibao 已提交
42 43
    PADDLE_ENFORCE(cq_);
  }
G
gongweibao 已提交
44
  virtual ~RequestBase() {}
45
  virtual void Process() = 0;
G
gongweibao 已提交
46

G
gongweibao 已提交
47 48 49 50 51 52 53 54 55 56 57 58 59
  std::string Status2String(const std::string& method) {
    std::string status = "Process";
    if (status_ == FINISH) {
      status = "Finish";
    }

    std::ostringstream s;
    s << method << " name:[" << GetReqName() << "]"
      << ", ep:[" << ctx_.peer() << "]"
      << " " << status << " using req_id:" << req_id_;
    return s.str();
  }

X
Xin Pan 已提交
60 61 62 63 64 65 66 67 68 69 70 71
  CallStatus Status() const {
    std::lock_guard<std::mutex> l(status_mu_);
    return status_;
  }

  template <typename T>
  void Finish(const T& reply, ServerAsyncResponseWriter<T>* responder) {
    std::lock_guard<std::mutex> l(status_mu_);
    status_ = FINISH;
    responder->Finish(reply, ::grpc::Status::OK,
                      reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
  }
72
  virtual std::string GetReqName() = 0;
G
gongweibao 已提交
73 74

 protected:
X
Xin Pan 已提交
75
  mutable std::mutex status_mu_;
76 77 78
  ::grpc::ServerContext ctx_;
  GrpcService::AsyncService* service_;
  ::grpc::ServerCompletionQueue* cq_;
G
gongweibao 已提交
79
  CallStatus status_;
80 81
  RequestHandler* request_handler_;
  int req_id_;
G
gongweibao 已提交
82 83 84 85
};

class RequestSend final : public RequestBase {
 public:
86
  explicit RequestSend(GrpcService::AsyncService* service,
87 88 89
                       ::grpc::ServerCompletionQueue* cq,
                       RequestHandler* request_handler, int req_id)
      : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
90 91 92
    request_.reset(new GRPCVariableResponse(request_handler->scope(),
                                            request_handler->dev_ctx(),
                                            !request_handler->sync_mode()));
93
    int method_id = static_cast<int>(distributed::GrpcMethod::kSendVariable);
X
Xin Pan 已提交
94 95
    service_->RequestAsyncUnary(
        method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
X
Xin Pan 已提交
96
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
G
gongweibao 已提交
97 98
  }
  virtual ~RequestSend() {}
99 100 101 102
  std::string GetReqName() override { return request_->Varname(); }

  void Process() override {
    std::string varname = GetReqName();
103
    VLOG(40) << "RequestSend var_name:" << varname;
G
gongweibao 已提交
104

105 106
    auto scope = request_->GetMutableLocalScope();
    auto invar = request_->GetVar();
W
Wu Yi 已提交
107
    int trainer_id = request_->GetTrainerId();
108 109
    framework::Variable* outvar = nullptr;

W
Wu Yi 已提交
110
    request_handler_->Handle(varname, scope, invar, &outvar, trainer_id);
X
Xin Pan 已提交
111
    Finish(reply_, &responder_);
G
gongweibao 已提交
112 113 114
  }

 protected:
X
Xin Pan 已提交
115
  sendrecv::VoidMessage reply_;
116
  std::shared_ptr<GRPCVariableResponse> request_;
G
gongweibao 已提交
117 118 119 120 121
  ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
};

class RequestGet final : public RequestBase {
 public:
122
  explicit RequestGet(GrpcService::AsyncService* service,
123 124 125
                      ::grpc::ServerCompletionQueue* cq,
                      RequestHandler* request_handler, int req_id)
      : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
126
    auto method_id = static_cast<int>(distributed::GrpcMethod::kGetVariable);
X
Xin Pan 已提交
127 128
    service_->RequestAsyncUnary(
        method_id, &ctx_, &request_, &responder_, cq_, cq_,
129
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
G
gongweibao 已提交
130 131 132 133
  }

  virtual ~RequestGet() {}

134
  std::string GetReqName() override { return request_.varname(); }
G
gongweibao 已提交
135

136
  void Process() override {
G
gongweibao 已提交
137
    // proc request.
138
    std::string varname = request_.varname();
W
Wu Yi 已提交
139
    int trainer_id = request_.trainer_id();
140
    VLOG(40) << "RequestGet " << varname;
141 142 143 144

    auto scope = request_handler_->scope();
    auto invar = scope->FindVar(varname);
    framework::Variable* outvar = nullptr;
145

W
Wu Yi 已提交
146
    request_handler_->Handle(varname, scope, invar, &outvar, trainer_id);
147 148 149 150

    if (outvar) {
      SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
                            &reply_);
151
    }
X
Xin Pan 已提交
152
    Finish(reply_, &responder_);
G
gongweibao 已提交
153 154 155 156
  }

 protected:
  sendrecv::VariableMessage request_;
X
Xin Pan 已提交
157
  ::grpc::ByteBuffer reply_;
158
  ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
G
gongweibao 已提交
159 160
};

class RequestPrefetch final : public RequestBase {
 public:
  explicit RequestPrefetch(GrpcService::AsyncService* service,
164 165 166
                           ::grpc::ServerCompletionQueue* cq,
                           RequestHandler* request_handler, int req_id)
      : RequestBase(service, cq, request_handler, req_id),
167
        responder_(&ctx_),
168
        local_scope_(nullptr) {
169 170
    request_.reset(new GRPCVariableResponse(request_handler->scope(),
                                            request_handler->dev_ctx(), true));
171 172
    int method_id =
        static_cast<int>(distributed::GrpcMethod::kPrefetchVariable);
X
Xin Pan 已提交
173 174
    service_->RequestAsyncUnary(
        method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
175
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
176 177 178 179
  }

  virtual ~RequestPrefetch() {}

180
  std::string GetReqName() override { return request_->Varname(); }
181

182
  void Process() override {
183
    // prefetch process...
184 185
    std::string in_var_name = request_->Varname();
    std::string out_var_name = request_->OutVarname();
W
Wu Yi 已提交
186
    int trainer_id = request_->GetTrainerId();
187 188
    VLOG(40) << "RequestPrefetch, in_var_name: " << in_var_name
             << " out_var_name: " << out_var_name;
189 190

    auto scope = request_->GetMutableLocalScope();
191
    auto invar = scope->FindVar(in_var_name);
192
    // out var must be created in local scope!
Q
qiaolongfei 已提交
193
    framework::Variable* outvar = scope->Var(out_var_name);
194

W
Wu Yi 已提交
195 196
    request_handler_->Handle(in_var_name, scope, invar, &outvar, trainer_id,
                             out_var_name);
Y
Yancey1989 已提交
197

198
    SerializeToByteBuffer(out_var_name, outvar, *request_handler_->dev_ctx(),
199
                          &reply_);
X
Xin Pan 已提交
200
    Finish(reply_, &responder_);
201 202 203
  }

 protected:
204
  std::shared_ptr<GRPCVariableResponse> request_;
X
Xin Pan 已提交
205
  ::grpc::ByteBuffer reply_;
206
  ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
207
  framework::Scope* local_scope_;
208 209
};

class RequestCheckpointNotify final : public RequestBase {
 public:
  explicit RequestCheckpointNotify(GrpcService::AsyncService* service,
                                   ::grpc::ServerCompletionQueue* cq,
                                   RequestHandler* request_handler, int req_id)
T
tangwei12 已提交
215
      : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
216 217
    request_.reset(new GRPCVariableResponse(request_handler->scope(),
                                            request_handler->dev_ctx()));
T
tangwei12 已提交
218 219
    int method_id =
        static_cast<int>(distributed::GrpcMethod::kCheckpointNotify);
T
tangwei12 已提交
220 221 222 223 224 225 226
    service_->RequestAsyncUnary(
        method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
  }

  virtual ~RequestCheckpointNotify() {}

227
  std::string GetReqName() override { return request_->Varname(); }
T
tangwei12 已提交
228 229 230

  void Process() override {
    auto scope = request_->GetMutableLocalScope();
231 232

    std::string checkpoint_notify = request_->Varname();
T
tangwei12 已提交
233
    std::string checkpoint_dir = request_->OutVarname();
W
Wu Yi 已提交
234
    int trainer_id = request_->GetTrainerId();
235

236 237
    VLOG(40) << "RequestCheckpointNotify notify: " << checkpoint_notify
             << ", dir: " << checkpoint_dir;
T
tangwei12 已提交
238

T
tangwei12 已提交
239
    request_handler_->Handle(checkpoint_notify, scope, nullptr, nullptr,
W
Wu Yi 已提交
240
                             trainer_id, checkpoint_dir);
T
tangwei12 已提交
241 242
    Finish(reply_, &responder_);
  }
T
tangwei12 已提交
243 244

 protected:
245
  std::shared_ptr<GRPCVariableResponse> request_;
T
tangwei12 已提交
246 247
  sendrecv::VoidMessage reply_;
  ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
T
tangwei12 已提交
248
};

void AsyncGRPCServer::WaitServerReady() {
251
  VLOG(40) << "AsyncGRPCServer is wait server ready";
T
update  
typhoonzero 已提交
252
  std::unique_lock<std::mutex> lock(this->mutex_ready_);
T
done  
typhoonzero 已提交
253
  condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
254
  VLOG(40) << "AsyncGRPCServer WaitSeverReady";
T
update  
typhoonzero 已提交
255 256
}

// ServerBuilderOption that disables SO_REUSEPORT on the server's listening
// socket by setting GRPC_ARG_ALLOW_REUSEPORT to 0. It installs no plugins.
// Adapted from:
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
class NoReusePortOption : public ::grpc::ServerBuilderOption {
 public:
  void UpdateArguments(::grpc::ChannelArguments* args) override {
    constexpr int kReusePortDisabled = 0;
    args->SetInt(GRPC_ARG_ALLOW_REUSEPORT, kReusePortDisabled);
  }

  void UpdatePlugins(
      std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>>* /*plugins*/)
      override {
    // Intentionally empty: this option only adjusts channel arguments.
  }
};

void AsyncGRPCServer::StartServer() {
272
  ::grpc::ServerBuilder builder;
273
  builder.AddListeningPort(bind_address_, ::grpc::InsecureServerCredentials(),
T
typhoonzero 已提交
274
                           &selected_port_);
275

G
gongweibao 已提交
276 277
  builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
278 279 280 281
  if (FLAGS_rpc_disable_reuse_port) {
    builder.SetOption(
        std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption));
  }
G
gongweibao 已提交
282 283
  builder.RegisterService(&service_);

284 285 286
  for (auto t : rpc_call_map_) {
    rpc_cq_[t.first].reset(builder.AddCompletionQueue().release());
  }
Y
Yancey 已提交
287

G
gongweibao 已提交
288
  server_ = builder.BuildAndStart();
289
  LOG(INFO) << "Server listening on " << bind_address_
T
typhoonzero 已提交
290
            << " selected port: " << selected_port_;
G
gongweibao 已提交
291

292 293 294
  std::function<void(const std::string&, int)> f =
      std::bind(&AsyncGRPCServer::TryToRegisterNewOne, this,
                std::placeholders::_1, std::placeholders::_2);
X
Xin Pan 已提交
295

296 297 298 299 300
  for (auto& t : rpc_call_map_) {
    auto& rpc_name = t.first;
    auto& cq = rpc_cq_[rpc_name];
    auto threadnum = rpc_thread_num_[rpc_name];
    auto& reqs = rpc_reqs_[rpc_name];
X
Xin Pan 已提交
301

302 303 304
    reqs.reserve(kRequestBufSize);

    for (int i = 0; i < kRequestBufSize; i++) {
305 306
      VLOG(60) << "TryToRegisterNewOne on RPC NAME: " << rpc_name
               << " I: " << i;
307 308 309 310 311 312
      TryToRegisterNewOne(rpc_name, i);
    }

    for (int i = 0; i < threadnum; i++) {
      rpc_threads_[rpc_name].emplace_back(new std::thread(std::bind(
          &AsyncGRPCServer::HandleRequest, this, cq.get(), rpc_name, f)));
313
      VLOG(40) << t.first << " creates threads!";
314
    }
X
Xin Pan 已提交
315
  }
316

T
wip  
typhoonzero 已提交
317 318 319 320 321
  {
    std::lock_guard<std::mutex> lock(this->mutex_ready_);
    ready_ = 1;
  }
  condition_ready_.notify_all();
322

G
gongweibao 已提交
323 324
  // wait server
  server_->Wait();
325 326 327 328 329

  for (auto& t : rpc_threads_) {
    auto& threads = t.second;
    for (size_t i = 0; i < threads.size(); ++i) {
      threads[i]->join();
330
      VLOG(40) << t.first << " threads ends!";
331
    }
X
Xin Pan 已提交
332
  }
G
gongweibao 已提交
333 334 335
}

void AsyncGRPCServer::ShutdownQueue() {
336 337
  for (auto& t : rpc_cq_) {
    t.second->Shutdown();
338
    VLOG(40) << t.first << " queue shutdown!";
339
  }
G
gongweibao 已提交
340 341
}

void AsyncGRPCServer::ShutDownImpl() {
  std::unique_lock<std::mutex> lock(cq_mutex_);
T
typhoonzero 已提交
344
  is_shut_down_ = true;
G
gongweibao 已提交
345
  ShutdownQueue();
346

347
  VLOG(40) << "server_ shutdown!";
T
typhoonzero 已提交
348
  server_->Shutdown();
G
gongweibao 已提交
349 350
}

void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
                                          int req_id) {
G
gongweibao 已提交
353 354
  std::unique_lock<std::mutex> lock(cq_mutex_);
  if (is_shut_down_) {
355
    VLOG(40) << "shutdown, do not TryToRegisterNewSendOne";
G
gongweibao 已提交
356 357 358
    return;
  }

359 360
  VLOG(40) << "TryToRegisterNewOne on RPC NAME: " << rpc_name
           << " REQ ID: " << req_id;
T
tangwei12 已提交
361

362 363 364 365 366 367 368 369 370 371 372
  auto& reqs = rpc_reqs_[rpc_name];
  auto& handler = rpc_call_map_[rpc_name];
  auto& cq = rpc_cq_[rpc_name];

  RequestBase* b = nullptr;
  if (rpc_name == kRequestSend) {
    b = new RequestSend(&service_, cq.get(), handler, req_id);
  } else if (rpc_name == kRequestGet) {
    b = new RequestGet(&service_, cq.get(), handler, req_id);
  } else if (rpc_name == kRequestPrefetch) {
    b = new RequestPrefetch(&service_, cq.get(), handler, req_id);
T
tangwei12 已提交
373
  } else if (rpc_name == kRequestCheckpoint) {
T
tangwei12 已提交
374
    b = new RequestCheckpointNotify(&service_, cq.get(), handler, req_id);
375
  } else {
Q
qiaolongfei 已提交
376
    PADDLE_ENFORCE(false, "not supported rpc");
G
gongweibao 已提交
377 378
  }

379
  reqs[req_id] = b;
380

381
  VLOG(40) << "Create RequestSend status:" << b->Status();
382 383
}

void AsyncGRPCServer::HandleRequest(
385 386
    ::grpc::ServerCompletionQueue* cq, const std::string& rpc_name,
    std::function<void(const std::string&, int)> TryToRegisterNewOne) {
G
gongweibao 已提交
387 388
  void* tag = NULL;
  bool ok = false;
389

G
gongweibao 已提交
390
  while (true) {
391
    VLOG(40) << "HandleRequest " << rpc_name << " wait next";
G
gongweibao 已提交
392
    if (!cq->Next(&tag, &ok)) {
393
      VLOG(30) << "CompletionQueue " << rpc_name << " shutdown!";
G
gongweibao 已提交
394 395
      break;
    }
Q
qiaolongfei 已提交
396

397
    int req_id = static_cast<int>(reinterpret_cast<intptr_t>(tag));
398 399
    VLOG(40) << "HandleRequest " << rpc_name << ", req_id:" << req_id
             << " get next";
G
gongweibao 已提交
400

401
    auto& reqs = rpc_reqs_[rpc_name];
X
Xin Pan 已提交
402 403
    RequestBase* base = nullptr;
    {
404 405 406
      PADDLE_ENFORCE(req_id >= 0 && req_id < kRequestBufSize);
      std::unique_lock<std::mutex> lock(cq_mutex_);
      base = reqs[req_id];
X
Xin Pan 已提交
407
    }
408

409
    VLOG(30) << base->Status2String(rpc_name);
G
gongweibao 已提交
410

G
gongweibao 已提交
411 412 413 414
    // reference:
    // https://github.com/tensorflow/tensorflow/issues/5596
    // https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
    // https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
G
gongweibao 已提交
415
    if (!ok) {
416
      LOG(WARNING) << "completion queue:" << rpc_name
G
gongweibao 已提交
417 418
                   << " recv no regular event"
                   << " context:" << base->Status2String(rpc_name);
419
      TryToRegisterNewOne(rpc_name, req_id);
G
gongweibao 已提交
420 421 422 423 424 425 426 427 428 429
      delete base;
      continue;
    }

    switch (base->Status()) {
      case PROCESS: {
        base->Process();
        break;
      }
      case FINISH: {
430
        TryToRegisterNewOne(rpc_name, req_id);
G
gongweibao 已提交
431 432 433 434 435 436 437 438
        delete base;
        break;
      }
      default: { assert(false); }
    }
  }
}

}  // namespace distributed
}  // namespace operators
}  // namespace paddle