You need to sign in or sign up before continuing.
grpc_server.cc 13.2 KB
Newer Older
1
/*Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
G
gongweibao 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15 16
#include <limits>
#include <string>
G
gongweibao 已提交
17

18
#include "paddle/fluid/operators/distributed/grpc_server.h"
G
gongweibao 已提交
19

20
using ::grpc::ServerAsyncResponseWriter;
X
Xin Pan 已提交
21

G
gongweibao 已提交
22 23
namespace paddle {
namespace operators {
24
namespace distributed {
G
gongweibao 已提交
25 26 27 28 29 30
// Lifecycle of a RequestBase: it starts in PROCESS and is switched to FINISH
// by RequestBase::Finish() when the reply is handed to gRPC.
enum CallStatus {
  PROCESS = 0,
  FINISH = 1,
};

// reference:
// https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server
class RequestBase {
 public:
31
  explicit RequestBase(GrpcService::AsyncService* service,
32 33
                       ::grpc::ServerCompletionQueue* cq,
                       RequestHandler* request_handler, int req_id)
Q
qiaolongfei 已提交
34 35 36
      : service_(service),
        cq_(cq),
        status_(PROCESS),
37 38
        request_handler_(request_handler),
        req_id_(req_id) {
G
gongweibao 已提交
39 40
    PADDLE_ENFORCE(cq_);
  }
G
gongweibao 已提交
41
  virtual ~RequestBase() {}
42
  virtual void Process() = 0;
G
gongweibao 已提交
43

G
gongweibao 已提交
44 45 46 47 48 49 50 51 52 53 54 55 56
  std::string Status2String(const std::string& method) {
    std::string status = "Process";
    if (status_ == FINISH) {
      status = "Finish";
    }

    std::ostringstream s;
    s << method << " name:[" << GetReqName() << "]"
      << ", ep:[" << ctx_.peer() << "]"
      << " " << status << " using req_id:" << req_id_;
    return s.str();
  }

X
Xin Pan 已提交
57 58 59 60 61 62 63 64 65 66 67 68
  CallStatus Status() const {
    std::lock_guard<std::mutex> l(status_mu_);
    return status_;
  }

  template <typename T>
  void Finish(const T& reply, ServerAsyncResponseWriter<T>* responder) {
    std::lock_guard<std::mutex> l(status_mu_);
    status_ = FINISH;
    responder->Finish(reply, ::grpc::Status::OK,
                      reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
  }
69
  virtual std::string GetReqName() = 0;
G
gongweibao 已提交
70 71

 protected:
X
Xin Pan 已提交
72
  mutable std::mutex status_mu_;
73 74 75
  ::grpc::ServerContext ctx_;
  GrpcService::AsyncService* service_;
  ::grpc::ServerCompletionQueue* cq_;
G
gongweibao 已提交
76
  CallStatus status_;
77 78
  RequestHandler* request_handler_;
  int req_id_;
G
gongweibao 已提交
79 80 81 82
};

class RequestSend final : public RequestBase {
 public:
83
  explicit RequestSend(GrpcService::AsyncService* service,
84 85 86 87 88 89
                       ::grpc::ServerCompletionQueue* cq,
                       RequestHandler* request_handler, int req_id)
      : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
    request_.reset(new VariableResponse(request_handler->scope(),
                                        request_handler->dev_ctx(),
                                        !request_handler->sync_mode()));
90
    int method_id = static_cast<int>(distributed::GrpcMethod::kSendVariable);
X
Xin Pan 已提交
91 92
    service_->RequestAsyncUnary(
        method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
X
Xin Pan 已提交
93
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
G
gongweibao 已提交
94 95
  }
  virtual ~RequestSend() {}
96 97 98 99
  std::string GetReqName() override { return request_->Varname(); }

  void Process() override {
    std::string varname = GetReqName();
W
Wu Yi 已提交
100
    VLOG(4) << "RequestSend var_name:" << varname;
G
gongweibao 已提交
101

102 103 104 105 106
    auto scope = request_->GetMutableLocalScope();
    auto invar = request_->GetVar();
    framework::Variable* outvar = nullptr;

    request_handler_->Handle(varname, scope, invar, &outvar);
X
Xin Pan 已提交
107
    Finish(reply_, &responder_);
G
gongweibao 已提交
108 109 110
  }

 protected:
X
Xin Pan 已提交
111
  sendrecv::VoidMessage reply_;
112
  std::shared_ptr<VariableResponse> request_;
G
gongweibao 已提交
113 114 115 116 117
  ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
};

class RequestGet final : public RequestBase {
 public:
118
  explicit RequestGet(GrpcService::AsyncService* service,
119 120 121
                      ::grpc::ServerCompletionQueue* cq,
                      RequestHandler* request_handler, int req_id)
      : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
122
    auto method_id = static_cast<int>(distributed::GrpcMethod::kGetVariable);
X
Xin Pan 已提交
123 124
    service_->RequestAsyncUnary(
        method_id, &ctx_, &request_, &responder_, cq_, cq_,
125
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
G
gongweibao 已提交
126 127 128 129
  }

  virtual ~RequestGet() {}

130
  std::string GetReqName() override { return request_.varname(); }
G
gongweibao 已提交
131

132
  void Process() override {
G
gongweibao 已提交
133
    // proc request.
134
    std::string varname = request_.varname();
W
Wu Yi 已提交
135
    VLOG(4) << "RequestGet " << varname;
136 137 138 139

    auto scope = request_handler_->scope();
    auto invar = scope->FindVar(varname);
    framework::Variable* outvar = nullptr;
140

141 142 143 144 145
    request_handler_->Handle(varname, scope, invar, &outvar);

    if (outvar) {
      SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
                            &reply_);
146
    }
X
Xin Pan 已提交
147
    Finish(reply_, &responder_);
G
gongweibao 已提交
148 149 150 151
  }

 protected:
  sendrecv::VariableMessage request_;
X
Xin Pan 已提交
152
  ::grpc::ByteBuffer reply_;
153
  ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
G
gongweibao 已提交
154 155
};

156 157 158
class RequestPrefetch final : public RequestBase {
 public:
  explicit RequestPrefetch(GrpcService::AsyncService* service,
159 160 161
                           ::grpc::ServerCompletionQueue* cq,
                           RequestHandler* request_handler, int req_id)
      : RequestBase(service, cq, request_handler, req_id),
162
        responder_(&ctx_),
163 164 165
        local_scope_(nullptr) {
    request_.reset(new VariableResponse(request_handler->scope(),
                                        request_handler->dev_ctx(), true));
166 167
    int method_id =
        static_cast<int>(distributed::GrpcMethod::kPrefetchVariable);
X
Xin Pan 已提交
168 169
    service_->RequestAsyncUnary(
        method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
170
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
171 172 173 174
  }

  virtual ~RequestPrefetch() {}

175
  std::string GetReqName() override { return request_->Varname(); }
176

177
  void Process() override {
178
    // prefetch process...
179 180
    std::string in_var_name = request_->Varname();
    std::string out_var_name = request_->OutVarname();
W
Wu Yi 已提交
181
    VLOG(4) << "RequestPrefetch, in_var_name: " << in_var_name
182
            << " out_var_name: " << out_var_name;
183 184

    auto scope = request_->GetMutableLocalScope();
185
    auto invar = scope->FindVar(in_var_name);
186
    // out var must be created in local scope!
Q
qiaolongfei 已提交
187
    framework::Variable* outvar = scope->Var(out_var_name);
188

Q
qiaolongfei 已提交
189
    request_handler_->Handle(in_var_name, scope, invar, &outvar, out_var_name);
Y
Yancey1989 已提交
190

191
    SerializeToByteBuffer(out_var_name, outvar, *request_handler_->dev_ctx(),
192
                          &reply_);
X
Xin Pan 已提交
193
    Finish(reply_, &responder_);
194 195 196
  }

 protected:
Y
Yancey1989 已提交
197
  std::shared_ptr<VariableResponse> request_;
X
Xin Pan 已提交
198
  ::grpc::ByteBuffer reply_;
199
  ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
200
  framework::Scope* local_scope_;
201 202
};

T
tangwei12 已提交
203 204 205 206 207
class RequestCheckpointNotify final : public RequestBase {
 public:
  explicit RequestCheckpointNotify(GrpcService::AsyncService* service,
                                   ::grpc::ServerCompletionQueue* cq,
                                   RequestHandler* request_handler, int req_id)
T
tangwei12 已提交
208
      : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
T
tangwei12 已提交
209
    request_.reset(new VariableResponse(request_handler->scope(),
T
tangwei12 已提交
210
                                        request_handler->dev_ctx()));
T
tangwei12 已提交
211 212
    int method_id =
        static_cast<int>(distributed::GrpcMethod::kCheckpointNotify);
T
tangwei12 已提交
213 214 215 216 217 218 219
    service_->RequestAsyncUnary(
        method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
  }

  virtual ~RequestCheckpointNotify() {}

220
  std::string GetReqName() override { return request_->Varname(); }
T
tangwei12 已提交
221 222 223

  void Process() override {
    auto scope = request_->GetMutableLocalScope();
224 225

    std::string checkpoint_notify = request_->Varname();
T
tangwei12 已提交
226
    std::string checkpoint_dir = request_->OutVarname();
227

T
tangwei12 已提交
228 229 230
    VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify
            << ", dir: " << checkpoint_dir;

T
tangwei12 已提交
231
    request_handler_->Handle(checkpoint_notify, scope, nullptr, nullptr,
232
                             checkpoint_dir);
T
tangwei12 已提交
233 234
    Finish(reply_, &responder_);
  }
T
tangwei12 已提交
235 236

 protected:
T
tangwei12 已提交
237
  std::shared_ptr<VariableResponse> request_;
T
tangwei12 已提交
238 239
  sendrecv::VoidMessage reply_;
  ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
T
tangwei12 已提交
240
};
T
tangwei12 已提交
241

T
done  
typhoonzero 已提交
242
void AsyncGRPCServer::WaitServerReady() {
W
Wu Yi 已提交
243
  VLOG(4) << "AsyncGRPCServer is wait server ready";
T
update  
typhoonzero 已提交
244
  std::unique_lock<std::mutex> lock(this->mutex_ready_);
T
done  
typhoonzero 已提交
245
  condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
W
Wu Yi 已提交
246
  VLOG(4) << "AsyncGRPCServer WaitSeverReady";
T
update  
typhoonzero 已提交
247 248
}

249
void AsyncGRPCServer::StartServer() {
250
  ::grpc::ServerBuilder builder;
251
  builder.AddListeningPort(bind_address_, ::grpc::InsecureServerCredentials(),
T
typhoonzero 已提交
252
                           &selected_port_);
253

G
gongweibao 已提交
254 255
  builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
G
gongweibao 已提交
256 257
  builder.RegisterService(&service_);

258 259 260
  for (auto t : rpc_call_map_) {
    rpc_cq_[t.first].reset(builder.AddCompletionQueue().release());
  }
Y
Yancey 已提交
261

G
gongweibao 已提交
262
  server_ = builder.BuildAndStart();
263
  LOG(INFO) << "Server listening on " << bind_address_
T
typhoonzero 已提交
264
            << " selected port: " << selected_port_;
G
gongweibao 已提交
265

266 267 268
  std::function<void(const std::string&, int)> f =
      std::bind(&AsyncGRPCServer::TryToRegisterNewOne, this,
                std::placeholders::_1, std::placeholders::_2);
X
Xin Pan 已提交
269

270 271 272 273 274
  for (auto& t : rpc_call_map_) {
    auto& rpc_name = t.first;
    auto& cq = rpc_cq_[rpc_name];
    auto threadnum = rpc_thread_num_[rpc_name];
    auto& reqs = rpc_reqs_[rpc_name];
X
Xin Pan 已提交
275

276 277 278
    reqs.reserve(kRequestBufSize);

    for (int i = 0; i < kRequestBufSize; i++) {
T
tangwei12 已提交
279
      VLOG(6) << "TryToRegisterNewOne on RPC NAME: " << rpc_name << " I: " << i;
280 281 282 283 284 285
      TryToRegisterNewOne(rpc_name, i);
    }

    for (int i = 0; i < threadnum; i++) {
      rpc_threads_[rpc_name].emplace_back(new std::thread(std::bind(
          &AsyncGRPCServer::HandleRequest, this, cq.get(), rpc_name, f)));
W
Wu Yi 已提交
286
      VLOG(4) << t.first << " creates threads!";
287
    }
X
Xin Pan 已提交
288
  }
289

T
wip  
typhoonzero 已提交
290 291 292 293 294
  {
    std::lock_guard<std::mutex> lock(this->mutex_ready_);
    ready_ = 1;
  }
  condition_ready_.notify_all();
295

G
gongweibao 已提交
296 297
  // wait server
  server_->Wait();
298 299 300 301 302

  for (auto& t : rpc_threads_) {
    auto& threads = t.second;
    for (size_t i = 0; i < threads.size(); ++i) {
      threads[i]->join();
W
Wu Yi 已提交
303
      VLOG(4) << t.first << " threads ends!";
304
    }
X
Xin Pan 已提交
305
  }
G
gongweibao 已提交
306 307 308
}

void AsyncGRPCServer::ShutdownQueue() {
309 310
  for (auto& t : rpc_cq_) {
    t.second->Shutdown();
W
Wu Yi 已提交
311
    VLOG(4) << t.first << " queue shutdown!";
312
  }
G
gongweibao 已提交
313 314
}

315 316
void AsyncGRPCServer::ShutDownImpl() {
  std::unique_lock<std::mutex> lock(cq_mutex_);
T
typhoonzero 已提交
317
  is_shut_down_ = true;
G
gongweibao 已提交
318
  ShutdownQueue();
319

W
Wu Yi 已提交
320
  VLOG(4) << "server_ shutdown!";
T
typhoonzero 已提交
321
  server_->Shutdown();
G
gongweibao 已提交
322 323
}

324 325
void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
                                          int req_id) {
G
gongweibao 已提交
326 327
  std::unique_lock<std::mutex> lock(cq_mutex_);
  if (is_shut_down_) {
W
Wu Yi 已提交
328
    VLOG(4) << "shutdown, do not TryToRegisterNewSendOne";
G
gongweibao 已提交
329 330 331
    return;
  }

T
tangwei12 已提交
332 333
  VLOG(4) << "TryToRegisterNewOne on RPC NAME: " << rpc_name
          << " REQ ID: " << req_id;
T
tangwei12 已提交
334

335 336 337 338 339 340 341 342 343 344 345
  auto& reqs = rpc_reqs_[rpc_name];
  auto& handler = rpc_call_map_[rpc_name];
  auto& cq = rpc_cq_[rpc_name];

  RequestBase* b = nullptr;
  if (rpc_name == kRequestSend) {
    b = new RequestSend(&service_, cq.get(), handler, req_id);
  } else if (rpc_name == kRequestGet) {
    b = new RequestGet(&service_, cq.get(), handler, req_id);
  } else if (rpc_name == kRequestPrefetch) {
    b = new RequestPrefetch(&service_, cq.get(), handler, req_id);
T
tangwei12 已提交
346
  } else if (rpc_name == kRequestCheckpoint) {
T
tangwei12 已提交
347
    b = new RequestCheckpointNotify(&service_, cq.get(), handler, req_id);
348
  } else {
Q
qiaolongfei 已提交
349
    PADDLE_ENFORCE(false, "not supported rpc");
G
gongweibao 已提交
350 351
  }

352
  reqs[req_id] = b;
353

354
  VLOG(4) << "Create RequestSend status:" << b->Status();
355 356
}

X
Xin Pan 已提交
357
void AsyncGRPCServer::HandleRequest(
358 359
    ::grpc::ServerCompletionQueue* cq, const std::string& rpc_name,
    std::function<void(const std::string&, int)> TryToRegisterNewOne) {
G
gongweibao 已提交
360 361
  void* tag = NULL;
  bool ok = false;
362

G
gongweibao 已提交
363
  while (true) {
G
gongweibao 已提交
364
    VLOG(4) << "HandleRequest " << rpc_name << " wait next";
G
gongweibao 已提交
365
    if (!cq->Next(&tag, &ok)) {
T
tangwei12 已提交
366
      VLOG(3) << "CompletionQueue " << rpc_name << " shutdown!";
G
gongweibao 已提交
367 368
      break;
    }
Q
qiaolongfei 已提交
369

370
    int req_id = static_cast<int>(reinterpret_cast<intptr_t>(tag));
G
gongweibao 已提交
371
    VLOG(4) << "HandleRequest " << rpc_name << ", req_id:" << req_id
372
            << " get next";
G
gongweibao 已提交
373

374
    auto& reqs = rpc_reqs_[rpc_name];
X
Xin Pan 已提交
375 376
    RequestBase* base = nullptr;
    {
377 378 379
      PADDLE_ENFORCE(req_id >= 0 && req_id < kRequestBufSize);
      std::unique_lock<std::mutex> lock(cq_mutex_);
      base = reqs[req_id];
X
Xin Pan 已提交
380
    }
381

G
gongweibao 已提交
382 383
    VLOG(3) << base->Status2String(rpc_name);

G
gongweibao 已提交
384 385 386 387
    // reference:
    // https://github.com/tensorflow/tensorflow/issues/5596
    // https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
    // https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
G
gongweibao 已提交
388
    if (!ok) {
389
      LOG(WARNING) << "completion queue:" << rpc_name
G
gongweibao 已提交
390 391
                   << " recv no regular event"
                   << " context:" << base->Status2String(rpc_name);
392
      TryToRegisterNewOne(rpc_name, req_id);
G
gongweibao 已提交
393 394 395 396 397 398 399 400 401 402
      delete base;
      continue;
    }

    switch (base->Status()) {
      case PROCESS: {
        base->Process();
        break;
      }
      case FINISH: {
403
        TryToRegisterNewOne(rpc_name, req_id);
G
gongweibao 已提交
404 405 406 407 408 409 410 411
        delete base;
        break;
      }
      default: { assert(false); }
    }
  }
}

412
}  // namespace distributed
G
gongweibao 已提交
413 414
}  // namespace operators
}  // namespace paddle