/*Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15 16
#include <limits>
#include <string>
G
gongweibao 已提交
17

18
#include "paddle/fluid/operators/distributed/grpc_serde.h"
19
#include "paddle/fluid/operators/distributed/grpc_server.h"
G
gongweibao 已提交
20

21
using ::grpc::ServerAsyncResponseWriter;
X
Xin Pan 已提交
22

G
gongweibao 已提交
23 24
namespace paddle {
namespace operators {
25
namespace distributed {
G
gongweibao 已提交
26 27 28 29 30 31
// Lifecycle state of a pending RPC request object.
//   PROCESS: the request has been registered and has not been answered yet.
//   FINISH:  the reply has been handed to the responder; the request object
//            can be recycled once its completion event is delivered.
enum CallStatus {
  PROCESS = 0,
  FINISH = 1,
};

// reference:
// https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server
class RequestBase {
 public:
32
  explicit RequestBase(GrpcService::AsyncService* service,
33 34
                       ::grpc::ServerCompletionQueue* cq,
                       RequestHandler* request_handler, int req_id)
Q
qiaolongfei 已提交
35 36 37
      : service_(service),
        cq_(cq),
        status_(PROCESS),
38 39
        request_handler_(request_handler),
        req_id_(req_id) {
G
gongweibao 已提交
40 41
    PADDLE_ENFORCE(cq_);
  }
G
gongweibao 已提交
42
  virtual ~RequestBase() {}
43
  virtual void Process() = 0;
G
gongweibao 已提交
44

G
gongweibao 已提交
45 46 47 48 49 50 51 52 53 54 55 56 57
  std::string Status2String(const std::string& method) {
    std::string status = "Process";
    if (status_ == FINISH) {
      status = "Finish";
    }

    std::ostringstream s;
    s << method << " name:[" << GetReqName() << "]"
      << ", ep:[" << ctx_.peer() << "]"
      << " " << status << " using req_id:" << req_id_;
    return s.str();
  }

X
Xin Pan 已提交
58 59 60 61 62 63 64 65 66 67 68 69
  CallStatus Status() const {
    std::lock_guard<std::mutex> l(status_mu_);
    return status_;
  }

  template <typename T>
  void Finish(const T& reply, ServerAsyncResponseWriter<T>* responder) {
    std::lock_guard<std::mutex> l(status_mu_);
    status_ = FINISH;
    responder->Finish(reply, ::grpc::Status::OK,
                      reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
  }
70
  virtual std::string GetReqName() = 0;
G
gongweibao 已提交
71 72

 protected:
X
Xin Pan 已提交
73
  mutable std::mutex status_mu_;
74 75 76
  ::grpc::ServerContext ctx_;
  GrpcService::AsyncService* service_;
  ::grpc::ServerCompletionQueue* cq_;
G
gongweibao 已提交
77
  CallStatus status_;
78 79
  RequestHandler* request_handler_;
  int req_id_;
G
gongweibao 已提交
80 81 82 83
};

class RequestSend final : public RequestBase {
 public:
84
  explicit RequestSend(GrpcService::AsyncService* service,
85 86 87
                       ::grpc::ServerCompletionQueue* cq,
                       RequestHandler* request_handler, int req_id)
      : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
88 89 90
    request_.reset(new GRPCVariableResponse(request_handler->scope(),
                                            request_handler->dev_ctx(),
                                            !request_handler->sync_mode()));
91
    int method_id = static_cast<int>(distributed::GrpcMethod::kSendVariable);
X
Xin Pan 已提交
92 93
    service_->RequestAsyncUnary(
        method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
X
Xin Pan 已提交
94
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
G
gongweibao 已提交
95 96
  }
  virtual ~RequestSend() {}
97 98 99 100
  std::string GetReqName() override { return request_->Varname(); }

  void Process() override {
    std::string varname = GetReqName();
W
Wu Yi 已提交
101
    VLOG(4) << "RequestSend var_name:" << varname;
G
gongweibao 已提交
102

103 104 105 106 107
    auto scope = request_->GetMutableLocalScope();
    auto invar = request_->GetVar();
    framework::Variable* outvar = nullptr;

    request_handler_->Handle(varname, scope, invar, &outvar);
X
Xin Pan 已提交
108
    Finish(reply_, &responder_);
G
gongweibao 已提交
109 110 111
  }

 protected:
X
Xin Pan 已提交
112
  sendrecv::VoidMessage reply_;
113
  std::shared_ptr<GRPCVariableResponse> request_;
G
gongweibao 已提交
114 115 116 117 118
  ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
};

class RequestGet final : public RequestBase {
 public:
119
  explicit RequestGet(GrpcService::AsyncService* service,
120 121 122
                      ::grpc::ServerCompletionQueue* cq,
                      RequestHandler* request_handler, int req_id)
      : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
123
    auto method_id = static_cast<int>(distributed::GrpcMethod::kGetVariable);
X
Xin Pan 已提交
124 125
    service_->RequestAsyncUnary(
        method_id, &ctx_, &request_, &responder_, cq_, cq_,
126
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
G
gongweibao 已提交
127 128 129 130
  }

  virtual ~RequestGet() {}

131
  std::string GetReqName() override { return request_.varname(); }
G
gongweibao 已提交
132

133
  void Process() override {
G
gongweibao 已提交
134
    // proc request.
135
    std::string varname = request_.varname();
W
Wu Yi 已提交
136
    VLOG(4) << "RequestGet " << varname;
137 138 139 140

    auto scope = request_handler_->scope();
    auto invar = scope->FindVar(varname);
    framework::Variable* outvar = nullptr;
141

142 143 144 145 146
    request_handler_->Handle(varname, scope, invar, &outvar);

    if (outvar) {
      SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
                            &reply_);
147
    }
X
Xin Pan 已提交
148
    Finish(reply_, &responder_);
G
gongweibao 已提交
149 150 151 152
  }

 protected:
  sendrecv::VariableMessage request_;
X
Xin Pan 已提交
153
  ::grpc::ByteBuffer reply_;
154
  ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
G
gongweibao 已提交
155 156
};

157 158 159
class RequestPrefetch final : public RequestBase {
 public:
  explicit RequestPrefetch(GrpcService::AsyncService* service,
160 161 162
                           ::grpc::ServerCompletionQueue* cq,
                           RequestHandler* request_handler, int req_id)
      : RequestBase(service, cq, request_handler, req_id),
163
        responder_(&ctx_),
164
        local_scope_(nullptr) {
165 166
    request_.reset(new GRPCVariableResponse(request_handler->scope(),
                                            request_handler->dev_ctx(), true));
167 168
    int method_id =
        static_cast<int>(distributed::GrpcMethod::kPrefetchVariable);
X
Xin Pan 已提交
169 170
    service_->RequestAsyncUnary(
        method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
171
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
172 173 174 175
  }

  virtual ~RequestPrefetch() {}

176
  std::string GetReqName() override { return request_->Varname(); }
177

178
  void Process() override {
179
    // prefetch process...
180 181
    std::string in_var_name = request_->Varname();
    std::string out_var_name = request_->OutVarname();
W
Wu Yi 已提交
182
    VLOG(4) << "RequestPrefetch, in_var_name: " << in_var_name
183
            << " out_var_name: " << out_var_name;
184 185

    auto scope = request_->GetMutableLocalScope();
186
    auto invar = scope->FindVar(in_var_name);
187
    // out var must be created in local scope!
Q
qiaolongfei 已提交
188
    framework::Variable* outvar = scope->Var(out_var_name);
189

Q
qiaolongfei 已提交
190
    request_handler_->Handle(in_var_name, scope, invar, &outvar, out_var_name);
Y
Yancey1989 已提交
191

192
    SerializeToByteBuffer(out_var_name, outvar, *request_handler_->dev_ctx(),
193
                          &reply_);
X
Xin Pan 已提交
194
    Finish(reply_, &responder_);
195 196 197
  }

 protected:
198
  std::shared_ptr<GRPCVariableResponse> request_;
X
Xin Pan 已提交
199
  ::grpc::ByteBuffer reply_;
200
  ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
201
  framework::Scope* local_scope_;
202 203
};

T
tangwei12 已提交
204 205 206 207 208
class RequestCheckpointNotify final : public RequestBase {
 public:
  explicit RequestCheckpointNotify(GrpcService::AsyncService* service,
                                   ::grpc::ServerCompletionQueue* cq,
                                   RequestHandler* request_handler, int req_id)
T
tangwei12 已提交
209
      : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
210 211
    request_.reset(new GRPCVariableResponse(request_handler->scope(),
                                            request_handler->dev_ctx()));
T
tangwei12 已提交
212 213
    int method_id =
        static_cast<int>(distributed::GrpcMethod::kCheckpointNotify);
T
tangwei12 已提交
214 215 216 217 218 219 220
    service_->RequestAsyncUnary(
        method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
  }

  virtual ~RequestCheckpointNotify() {}

221
  std::string GetReqName() override { return request_->Varname(); }
T
tangwei12 已提交
222 223 224

  void Process() override {
    auto scope = request_->GetMutableLocalScope();
225 226

    std::string checkpoint_notify = request_->Varname();
T
tangwei12 已提交
227
    std::string checkpoint_dir = request_->OutVarname();
228

T
tangwei12 已提交
229 230 231
    VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify
            << ", dir: " << checkpoint_dir;

T
tangwei12 已提交
232
    request_handler_->Handle(checkpoint_notify, scope, nullptr, nullptr,
233
                             checkpoint_dir);
T
tangwei12 已提交
234 235
    Finish(reply_, &responder_);
  }
T
tangwei12 已提交
236 237

 protected:
238
  std::shared_ptr<GRPCVariableResponse> request_;
T
tangwei12 已提交
239 240
  sendrecv::VoidMessage reply_;
  ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
T
tangwei12 已提交
241
};
T
tangwei12 已提交
242

T
done  
typhoonzero 已提交
243
void AsyncGRPCServer::WaitServerReady() {
W
Wu Yi 已提交
244
  VLOG(4) << "AsyncGRPCServer is wait server ready";
T
update  
typhoonzero 已提交
245
  std::unique_lock<std::mutex> lock(this->mutex_ready_);
T
done  
typhoonzero 已提交
246
  condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
W
Wu Yi 已提交
247
  VLOG(4) << "AsyncGRPCServer WaitSeverReady";
T
update  
typhoonzero 已提交
248 249
}

250
void AsyncGRPCServer::StartServer() {
251
  ::grpc::ServerBuilder builder;
252
  builder.AddListeningPort(bind_address_, ::grpc::InsecureServerCredentials(),
T
typhoonzero 已提交
253
                           &selected_port_);
254

G
gongweibao 已提交
255 256
  builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
G
gongweibao 已提交
257 258
  builder.RegisterService(&service_);

259 260 261
  for (auto t : rpc_call_map_) {
    rpc_cq_[t.first].reset(builder.AddCompletionQueue().release());
  }
Y
Yancey 已提交
262

G
gongweibao 已提交
263
  server_ = builder.BuildAndStart();
264
  LOG(INFO) << "Server listening on " << bind_address_
T
typhoonzero 已提交
265
            << " selected port: " << selected_port_;
G
gongweibao 已提交
266

267 268 269
  std::function<void(const std::string&, int)> f =
      std::bind(&AsyncGRPCServer::TryToRegisterNewOne, this,
                std::placeholders::_1, std::placeholders::_2);
X
Xin Pan 已提交
270

271 272 273 274 275
  for (auto& t : rpc_call_map_) {
    auto& rpc_name = t.first;
    auto& cq = rpc_cq_[rpc_name];
    auto threadnum = rpc_thread_num_[rpc_name];
    auto& reqs = rpc_reqs_[rpc_name];
X
Xin Pan 已提交
276

277 278 279
    reqs.reserve(kRequestBufSize);

    for (int i = 0; i < kRequestBufSize; i++) {
T
tangwei12 已提交
280
      VLOG(6) << "TryToRegisterNewOne on RPC NAME: " << rpc_name << " I: " << i;
281 282 283 284 285 286
      TryToRegisterNewOne(rpc_name, i);
    }

    for (int i = 0; i < threadnum; i++) {
      rpc_threads_[rpc_name].emplace_back(new std::thread(std::bind(
          &AsyncGRPCServer::HandleRequest, this, cq.get(), rpc_name, f)));
W
Wu Yi 已提交
287
      VLOG(4) << t.first << " creates threads!";
288
    }
X
Xin Pan 已提交
289
  }
290

T
wip  
typhoonzero 已提交
291 292 293 294 295
  {
    std::lock_guard<std::mutex> lock(this->mutex_ready_);
    ready_ = 1;
  }
  condition_ready_.notify_all();
296

G
gongweibao 已提交
297 298
  // wait server
  server_->Wait();
299 300 301 302 303

  for (auto& t : rpc_threads_) {
    auto& threads = t.second;
    for (size_t i = 0; i < threads.size(); ++i) {
      threads[i]->join();
W
Wu Yi 已提交
304
      VLOG(4) << t.first << " threads ends!";
305
    }
X
Xin Pan 已提交
306
  }
G
gongweibao 已提交
307 308 309
}

void AsyncGRPCServer::ShutdownQueue() {
310 311
  for (auto& t : rpc_cq_) {
    t.second->Shutdown();
W
Wu Yi 已提交
312
    VLOG(4) << t.first << " queue shutdown!";
313
  }
G
gongweibao 已提交
314 315
}

316 317
void AsyncGRPCServer::ShutDownImpl() {
  std::unique_lock<std::mutex> lock(cq_mutex_);
T
typhoonzero 已提交
318
  is_shut_down_ = true;
G
gongweibao 已提交
319
  ShutdownQueue();
320

W
Wu Yi 已提交
321
  VLOG(4) << "server_ shutdown!";
T
typhoonzero 已提交
322
  server_->Shutdown();
G
gongweibao 已提交
323 324
}

325 326
void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
                                          int req_id) {
G
gongweibao 已提交
327 328
  std::unique_lock<std::mutex> lock(cq_mutex_);
  if (is_shut_down_) {
W
Wu Yi 已提交
329
    VLOG(4) << "shutdown, do not TryToRegisterNewSendOne";
G
gongweibao 已提交
330 331 332
    return;
  }

T
tangwei12 已提交
333 334
  VLOG(4) << "TryToRegisterNewOne on RPC NAME: " << rpc_name
          << " REQ ID: " << req_id;
T
tangwei12 已提交
335

336 337 338 339 340 341 342 343 344 345 346
  auto& reqs = rpc_reqs_[rpc_name];
  auto& handler = rpc_call_map_[rpc_name];
  auto& cq = rpc_cq_[rpc_name];

  RequestBase* b = nullptr;
  if (rpc_name == kRequestSend) {
    b = new RequestSend(&service_, cq.get(), handler, req_id);
  } else if (rpc_name == kRequestGet) {
    b = new RequestGet(&service_, cq.get(), handler, req_id);
  } else if (rpc_name == kRequestPrefetch) {
    b = new RequestPrefetch(&service_, cq.get(), handler, req_id);
T
tangwei12 已提交
347
  } else if (rpc_name == kRequestCheckpoint) {
T
tangwei12 已提交
348
    b = new RequestCheckpointNotify(&service_, cq.get(), handler, req_id);
349
  } else {
Q
qiaolongfei 已提交
350
    PADDLE_ENFORCE(false, "not supported rpc");
G
gongweibao 已提交
351 352
  }

353
  reqs[req_id] = b;
354

355
  VLOG(4) << "Create RequestSend status:" << b->Status();
356 357
}

X
Xin Pan 已提交
358
void AsyncGRPCServer::HandleRequest(
359 360
    ::grpc::ServerCompletionQueue* cq, const std::string& rpc_name,
    std::function<void(const std::string&, int)> TryToRegisterNewOne) {
G
gongweibao 已提交
361 362
  void* tag = NULL;
  bool ok = false;
363

G
gongweibao 已提交
364
  while (true) {
G
gongweibao 已提交
365
    VLOG(4) << "HandleRequest " << rpc_name << " wait next";
G
gongweibao 已提交
366
    if (!cq->Next(&tag, &ok)) {
T
tangwei12 已提交
367
      VLOG(3) << "CompletionQueue " << rpc_name << " shutdown!";
G
gongweibao 已提交
368 369
      break;
    }
Q
qiaolongfei 已提交
370

371
    int req_id = static_cast<int>(reinterpret_cast<intptr_t>(tag));
G
gongweibao 已提交
372
    VLOG(4) << "HandleRequest " << rpc_name << ", req_id:" << req_id
373
            << " get next";
G
gongweibao 已提交
374

375
    auto& reqs = rpc_reqs_[rpc_name];
X
Xin Pan 已提交
376 377
    RequestBase* base = nullptr;
    {
378 379 380
      PADDLE_ENFORCE(req_id >= 0 && req_id < kRequestBufSize);
      std::unique_lock<std::mutex> lock(cq_mutex_);
      base = reqs[req_id];
X
Xin Pan 已提交
381
    }
382

G
gongweibao 已提交
383 384
    VLOG(3) << base->Status2String(rpc_name);

G
gongweibao 已提交
385 386 387 388
    // reference:
    // https://github.com/tensorflow/tensorflow/issues/5596
    // https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
    // https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
G
gongweibao 已提交
389
    if (!ok) {
390
      LOG(WARNING) << "completion queue:" << rpc_name
G
gongweibao 已提交
391 392
                   << " recv no regular event"
                   << " context:" << base->Status2String(rpc_name);
393
      TryToRegisterNewOne(rpc_name, req_id);
G
gongweibao 已提交
394 395 396 397 398 399 400 401 402 403
      delete base;
      continue;
    }

    switch (base->Status()) {
      case PROCESS: {
        base->Process();
        break;
      }
      case FINISH: {
404
        TryToRegisterNewOne(rpc_name, req_id);
G
gongweibao 已提交
405 406 407 408 409 410 411 412
        delete base;
        break;
      }
      default: { assert(false); }
    }
  }
}

413
}  // namespace distributed
G
gongweibao 已提交
414 415
}  // namespace operators
}  // namespace paddle