/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <limits>
#include <string>
#include "paddle/fluid/operators/distributed/grpc_serde.h"
#include "paddle/fluid/operators/distributed/grpc_server.h"
using ::grpc::ServerAsyncResponseWriter;
DECLARE_bool(rpc_disable_reuse_port);

namespace paddle {
namespace operators {
namespace distributed {
// Per-request state: PROCESS until Process() has run and called Finish(),
// FINISH once the reply has been handed to the response writer.
enum CallStatus { PROCESS = 0, FINISH = 1 };

// reference:
// https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server
class RequestBase {
 public:
34
  explicit RequestBase(GrpcService::AsyncService* service,
35 36
                       ::grpc::ServerCompletionQueue* cq,
                       RequestHandler* request_handler, int req_id)
Q
qiaolongfei 已提交
37 38 39
      : service_(service),
        cq_(cq),
        status_(PROCESS),
40 41
        request_handler_(request_handler),
        req_id_(req_id) {
G
gongweibao 已提交
42 43
    PADDLE_ENFORCE(cq_);
  }
G
gongweibao 已提交
44
  virtual ~RequestBase() {}
45
  virtual void Process() = 0;
G
gongweibao 已提交
46

G
gongweibao 已提交
47 48 49 50 51 52 53 54 55 56 57 58 59
  std::string Status2String(const std::string& method) {
    std::string status = "Process";
    if (status_ == FINISH) {
      status = "Finish";
    }

    std::ostringstream s;
    s << method << " name:[" << GetReqName() << "]"
      << ", ep:[" << ctx_.peer() << "]"
      << " " << status << " using req_id:" << req_id_;
    return s.str();
  }

X
Xin Pan 已提交
60 61 62 63 64 65 66 67 68 69 70 71
  CallStatus Status() const {
    std::lock_guard<std::mutex> l(status_mu_);
    return status_;
  }

  template <typename T>
  void Finish(const T& reply, ServerAsyncResponseWriter<T>* responder) {
    std::lock_guard<std::mutex> l(status_mu_);
    status_ = FINISH;
    responder->Finish(reply, ::grpc::Status::OK,
                      reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
  }
72
  virtual std::string GetReqName() = 0;
G
gongweibao 已提交
73 74

 protected:
X
Xin Pan 已提交
75
  mutable std::mutex status_mu_;
76 77 78
  ::grpc::ServerContext ctx_;
  GrpcService::AsyncService* service_;
  ::grpc::ServerCompletionQueue* cq_;
G
gongweibao 已提交
79
  CallStatus status_;
80 81
  RequestHandler* request_handler_;
  int req_id_;
G
gongweibao 已提交
82 83 84 85
};

class RequestSend final : public RequestBase {
 public:
86
  explicit RequestSend(GrpcService::AsyncService* service,
87 88 89
                       ::grpc::ServerCompletionQueue* cq,
                       RequestHandler* request_handler, int req_id)
      : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
90 91 92
    request_.reset(new GRPCVariableResponse(request_handler->scope(),
                                            request_handler->dev_ctx(),
                                            !request_handler->sync_mode()));
93
    int method_id = static_cast<int>(distributed::GrpcMethod::kSendVariable);
X
Xin Pan 已提交
94 95
    service_->RequestAsyncUnary(
        method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
X
Xin Pan 已提交
96
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
G
gongweibao 已提交
97 98
  }
  virtual ~RequestSend() {}
99 100 101 102
  std::string GetReqName() override { return request_->Varname(); }

  void Process() override {
    std::string varname = GetReqName();
M
minqiyang 已提交
103
    VLOG(4) << "RequestSend var_name:" << varname;
G
gongweibao 已提交
104

105 106
    auto scope = request_->GetMutableLocalScope();
    auto invar = request_->GetVar();
W
Wu Yi 已提交
107
    int trainer_id = request_->GetTrainerId();
108 109
    framework::Variable* outvar = nullptr;

W
Wu Yi 已提交
110
    request_handler_->Handle(varname, scope, invar, &outvar, trainer_id);
X
Xin Pan 已提交
111
    Finish(reply_, &responder_);
G
gongweibao 已提交
112 113 114
  }

 protected:
X
Xin Pan 已提交
115
  sendrecv::VoidMessage reply_;
116
  std::shared_ptr<GRPCVariableResponse> request_;
G
gongweibao 已提交
117 118 119 120 121
  ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
};

class RequestGet final : public RequestBase {
 public:
122
  explicit RequestGet(GrpcService::AsyncService* service,
123 124 125
                      ::grpc::ServerCompletionQueue* cq,
                      RequestHandler* request_handler, int req_id)
      : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
126
    auto method_id = static_cast<int>(distributed::GrpcMethod::kGetVariable);
X
Xin Pan 已提交
127 128
    service_->RequestAsyncUnary(
        method_id, &ctx_, &request_, &responder_, cq_, cq_,
129
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
G
gongweibao 已提交
130 131 132 133
  }

  virtual ~RequestGet() {}

134
  std::string GetReqName() override { return request_.varname(); }
G
gongweibao 已提交
135

136
  void Process() override {
G
gongweibao 已提交
137
    // proc request.
138
    std::string varname = request_.varname();
W
Wu Yi 已提交
139
    int trainer_id = request_.trainer_id();
M
minqiyang 已提交
140
    VLOG(4) << "RequestGet " << varname;
141 142 143 144

    auto scope = request_handler_->scope();
    auto invar = scope->FindVar(varname);
    framework::Variable* outvar = nullptr;
145

W
Wu Yi 已提交
146
    request_handler_->Handle(varname, scope, invar, &outvar, trainer_id);
147 148 149 150

    if (outvar) {
      SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
                            &reply_);
151
    }
X
Xin Pan 已提交
152
    Finish(reply_, &responder_);
G
gongweibao 已提交
153 154 155 156
  }

 protected:
  sendrecv::VariableMessage request_;
X
Xin Pan 已提交
157
  ::grpc::ByteBuffer reply_;
158
  ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
G
gongweibao 已提交
159 160
};

161 162 163
class RequestPrefetch final : public RequestBase {
 public:
  explicit RequestPrefetch(GrpcService::AsyncService* service,
164 165 166
                           ::grpc::ServerCompletionQueue* cq,
                           RequestHandler* request_handler, int req_id)
      : RequestBase(service, cq, request_handler, req_id),
167
        responder_(&ctx_),
168
        local_scope_(nullptr) {
169 170
    request_.reset(new GRPCVariableResponse(request_handler->scope(),
                                            request_handler->dev_ctx(), true));
171 172
    int method_id =
        static_cast<int>(distributed::GrpcMethod::kPrefetchVariable);
X
Xin Pan 已提交
173 174
    service_->RequestAsyncUnary(
        method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
175
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
176 177 178 179
  }

  virtual ~RequestPrefetch() {}

180
  std::string GetReqName() override { return request_->Varname(); }
181

182
  void Process() override {
183
    // prefetch process...
184 185
    std::string in_var_name = request_->Varname();
    std::string out_var_name = request_->OutVarname();
186
    std::string table_name = request_->TableName();
W
Wu Yi 已提交
187
    int trainer_id = request_->GetTrainerId();
M
minqiyang 已提交
188 189
    VLOG(4) << "RequestPrefetch, in_var_name: " << in_var_name
            << " out_var_name: " << out_var_name;
190 191

    auto scope = request_->GetMutableLocalScope();
192
    auto invar = scope->FindVar(in_var_name);
193
    // out var must be created in local scope!
Q
qiaolongfei 已提交
194
    framework::Variable* outvar = scope->Var(out_var_name);
195

W
Wu Yi 已提交
196
    request_handler_->Handle(in_var_name, scope, invar, &outvar, trainer_id,
197
                             out_var_name, table_name);
Y
Yancey1989 已提交
198

199
    SerializeToByteBuffer(out_var_name, outvar, *request_handler_->dev_ctx(),
200
                          &reply_);
X
Xin Pan 已提交
201
    Finish(reply_, &responder_);
202 203 204
  }

 protected:
205
  std::shared_ptr<GRPCVariableResponse> request_;
X
Xin Pan 已提交
206
  ::grpc::ByteBuffer reply_;
207
  ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
208
  framework::Scope* local_scope_;
209 210
};

T
tangwei12 已提交
211 212 213 214 215
class RequestCheckpointNotify final : public RequestBase {
 public:
  explicit RequestCheckpointNotify(GrpcService::AsyncService* service,
                                   ::grpc::ServerCompletionQueue* cq,
                                   RequestHandler* request_handler, int req_id)
T
tangwei12 已提交
216
      : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
217 218
    request_.reset(new GRPCVariableResponse(request_handler->scope(),
                                            request_handler->dev_ctx()));
T
tangwei12 已提交
219 220
    int method_id =
        static_cast<int>(distributed::GrpcMethod::kCheckpointNotify);
T
tangwei12 已提交
221 222 223 224 225 226 227
    service_->RequestAsyncUnary(
        method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
  }

  virtual ~RequestCheckpointNotify() {}

228
  std::string GetReqName() override { return request_->Varname(); }
T
tangwei12 已提交
229 230 231

  void Process() override {
    auto scope = request_->GetMutableLocalScope();
232 233

    std::string checkpoint_notify = request_->Varname();
T
tangwei12 已提交
234
    std::string checkpoint_dir = request_->OutVarname();
W
Wu Yi 已提交
235
    int trainer_id = request_->GetTrainerId();
236

M
minqiyang 已提交
237 238
    VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify
            << ", dir: " << checkpoint_dir;
T
tangwei12 已提交
239

T
tangwei12 已提交
240
    request_handler_->Handle(checkpoint_notify, scope, nullptr, nullptr,
W
Wu Yi 已提交
241
                             trainer_id, checkpoint_dir);
T
tangwei12 已提交
242 243
    Finish(reply_, &responder_);
  }
T
tangwei12 已提交
244 245

 protected:
246
  std::shared_ptr<GRPCVariableResponse> request_;
T
tangwei12 已提交
247 248
  sendrecv::VoidMessage reply_;
  ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
T
tangwei12 已提交
249
};
T
tangwei12 已提交
250

T
done  
typhoonzero 已提交
251
void AsyncGRPCServer::WaitServerReady() {
M
minqiyang 已提交
252
  VLOG(4) << "AsyncGRPCServer is wait server ready";
T
update  
typhoonzero 已提交
253
  std::unique_lock<std::mutex> lock(this->mutex_ready_);
T
done  
typhoonzero 已提交
254
  condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
M
minqiyang 已提交
255
  VLOG(4) << "AsyncGRPCServer WaitSeverReady";
T
update  
typhoonzero 已提交
256 257
}

258 259 260 261 262 263 264 265 266 267 268 269 270 271
// Define an option subclass in order to disable SO_REUSEPORT for the
// server socket.
// Come from:
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
class NoReusePortOption : public ::grpc::ServerBuilderOption {
 public:
  void UpdateArguments(::grpc::ChannelArguments* args) override {
    args->SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0);
  }

  void UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>>*
                         plugins) override {}
};

272
void AsyncGRPCServer::StartServer() {
273
  ::grpc::ServerBuilder builder;
274
  builder.AddListeningPort(bind_address_, ::grpc::InsecureServerCredentials(),
T
typhoonzero 已提交
275
                           &selected_port_);
276

G
gongweibao 已提交
277 278
  builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
279 280 281 282
  if (FLAGS_rpc_disable_reuse_port) {
    builder.SetOption(
        std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption));
  }
G
gongweibao 已提交
283 284
  builder.RegisterService(&service_);

285 286 287
  for (auto t : rpc_call_map_) {
    rpc_cq_[t.first].reset(builder.AddCompletionQueue().release());
  }
Y
Yancey 已提交
288

G
gongweibao 已提交
289
  server_ = builder.BuildAndStart();
290
  LOG(INFO) << "Server listening on " << bind_address_
T
typhoonzero 已提交
291
            << " selected port: " << selected_port_;
G
gongweibao 已提交
292

293 294 295
  std::function<void(const std::string&, int)> f =
      std::bind(&AsyncGRPCServer::TryToRegisterNewOne, this,
                std::placeholders::_1, std::placeholders::_2);
X
Xin Pan 已提交
296

297 298 299 300 301
  for (auto& t : rpc_call_map_) {
    auto& rpc_name = t.first;
    auto& cq = rpc_cq_[rpc_name];
    auto threadnum = rpc_thread_num_[rpc_name];
    auto& reqs = rpc_reqs_[rpc_name];
X
Xin Pan 已提交
302

303 304 305
    reqs.reserve(kRequestBufSize);

    for (int i = 0; i < kRequestBufSize; i++) {
M
minqiyang 已提交
306
      VLOG(6) << "TryToRegisterNewOne on RPC NAME: " << rpc_name << " I: " << i;
307 308 309 310 311 312
      TryToRegisterNewOne(rpc_name, i);
    }

    for (int i = 0; i < threadnum; i++) {
      rpc_threads_[rpc_name].emplace_back(new std::thread(std::bind(
          &AsyncGRPCServer::HandleRequest, this, cq.get(), rpc_name, f)));
M
minqiyang 已提交
313
      VLOG(4) << t.first << " creates threads!";
314
    }
X
Xin Pan 已提交
315
  }
316

T
wip  
typhoonzero 已提交
317 318 319 320 321
  {
    std::lock_guard<std::mutex> lock(this->mutex_ready_);
    ready_ = 1;
  }
  condition_ready_.notify_all();
322

G
gongweibao 已提交
323 324
  // wait server
  server_->Wait();
325 326 327 328 329

  for (auto& t : rpc_threads_) {
    auto& threads = t.second;
    for (size_t i = 0; i < threads.size(); ++i) {
      threads[i]->join();
M
minqiyang 已提交
330
      VLOG(4) << t.first << " threads ends!";
331
    }
X
Xin Pan 已提交
332
  }
G
gongweibao 已提交
333 334 335
}

void AsyncGRPCServer::ShutdownQueue() {
336 337
  for (auto& t : rpc_cq_) {
    t.second->Shutdown();
M
minqiyang 已提交
338
    VLOG(4) << t.first << " queue shutdown!";
339
  }
G
gongweibao 已提交
340 341
}

342 343
void AsyncGRPCServer::ShutDownImpl() {
  std::unique_lock<std::mutex> lock(cq_mutex_);
T
typhoonzero 已提交
344
  is_shut_down_ = true;
G
gongweibao 已提交
345
  ShutdownQueue();
346

M
minqiyang 已提交
347
  VLOG(4) << "server_ shutdown!";
T
typhoonzero 已提交
348
  server_->Shutdown();
G
gongweibao 已提交
349 350
}

351 352
void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
                                          int req_id) {
G
gongweibao 已提交
353 354
  std::unique_lock<std::mutex> lock(cq_mutex_);
  if (is_shut_down_) {
M
minqiyang 已提交
355
    VLOG(4) << "shutdown, do not TryToRegisterNewSendOne";
G
gongweibao 已提交
356 357 358
    return;
  }

M
minqiyang 已提交
359 360
  VLOG(4) << "TryToRegisterNewOne on RPC NAME: " << rpc_name
          << " REQ ID: " << req_id;
T
tangwei12 已提交
361

362 363 364 365 366 367 368 369 370 371 372
  auto& reqs = rpc_reqs_[rpc_name];
  auto& handler = rpc_call_map_[rpc_name];
  auto& cq = rpc_cq_[rpc_name];

  RequestBase* b = nullptr;
  if (rpc_name == kRequestSend) {
    b = new RequestSend(&service_, cq.get(), handler, req_id);
  } else if (rpc_name == kRequestGet) {
    b = new RequestGet(&service_, cq.get(), handler, req_id);
  } else if (rpc_name == kRequestPrefetch) {
    b = new RequestPrefetch(&service_, cq.get(), handler, req_id);
T
tangwei12 已提交
373
  } else if (rpc_name == kRequestCheckpoint) {
T
tangwei12 已提交
374
    b = new RequestCheckpointNotify(&service_, cq.get(), handler, req_id);
375
  } else {
Q
qiaolongfei 已提交
376
    PADDLE_ENFORCE(false, "not supported rpc");
G
gongweibao 已提交
377 378
  }

379
  reqs[req_id] = b;
380

M
minqiyang 已提交
381
  VLOG(4) << "Create RequestSend status:" << b->Status();
382 383
}

X
Xin Pan 已提交
384
void AsyncGRPCServer::HandleRequest(
385 386
    ::grpc::ServerCompletionQueue* cq, const std::string& rpc_name,
    std::function<void(const std::string&, int)> TryToRegisterNewOne) {
G
gongweibao 已提交
387 388
  void* tag = NULL;
  bool ok = false;
389

G
gongweibao 已提交
390
  while (true) {
M
minqiyang 已提交
391
    VLOG(4) << "HandleRequest " << rpc_name << " wait next";
G
gongweibao 已提交
392
    if (!cq->Next(&tag, &ok)) {
M
minqiyang 已提交
393
      VLOG(3) << "CompletionQueue " << rpc_name << " shutdown!";
G
gongweibao 已提交
394 395
      break;
    }
Q
qiaolongfei 已提交
396

397
    int req_id = static_cast<int>(reinterpret_cast<intptr_t>(tag));
M
minqiyang 已提交
398 399
    VLOG(4) << "HandleRequest " << rpc_name << ", req_id:" << req_id
            << " get next";
G
gongweibao 已提交
400

401
    auto& reqs = rpc_reqs_[rpc_name];
X
Xin Pan 已提交
402 403
    RequestBase* base = nullptr;
    {
404 405 406
      PADDLE_ENFORCE(req_id >= 0 && req_id < kRequestBufSize);
      std::unique_lock<std::mutex> lock(cq_mutex_);
      base = reqs[req_id];
X
Xin Pan 已提交
407
    }
408

M
minqiyang 已提交
409
    VLOG(3) << base->Status2String(rpc_name);
G
gongweibao 已提交
410

G
gongweibao 已提交
411 412 413 414
    // reference:
    // https://github.com/tensorflow/tensorflow/issues/5596
    // https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
    // https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
G
gongweibao 已提交
415
    if (!ok) {
416
      LOG(WARNING) << "completion queue:" << rpc_name
G
gongweibao 已提交
417 418
                   << " recv no regular event"
                   << " context:" << base->Status2String(rpc_name);
419
      TryToRegisterNewOne(rpc_name, req_id);
G
gongweibao 已提交
420 421 422 423 424 425 426 427 428 429
      delete base;
      continue;
    }

    switch (base->Status()) {
      case PROCESS: {
        base->Process();
        break;
      }
      case FINISH: {
430
        TryToRegisterNewOne(rpc_name, req_id);
G
gongweibao 已提交
431 432 433 434 435 436 437 438
        delete base;
        break;
      }
      default: { assert(false); }
    }
  }
}

}  // namespace distributed
}  // namespace operators
}  // namespace paddle