/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/detail/grpc_server.h"
16 17 18

#include <limits>
#include <string>
G
gongweibao 已提交
19

20
using ::grpc::ServerAsyncResponseWriter;
G
gongweibao 已提交
21 22 23 24

namespace paddle {
namespace operators {
namespace detail {
X
Xin Pan 已提交
25 26 27 28
namespace {
// Number of worker threads RunSyncUpdate() starts to poll the send
// completion queue.
constexpr int kNumHandleSendThreads = 20;
// Number of worker threads RunSyncUpdate() starts to poll the get
// completion queue.
constexpr int kNumHandleGetThreads = 20;
}  // namespace
G
gongweibao 已提交
29 30 31 32 33 34
// State of an asynchronous request object. A request starts in PROCESS and is
// switched to FINISH by its Process() implementation when the reply is handed
// to gRPC; HandleRequest() dispatches on this value.
enum CallStatus {
  PROCESS = 0,
  FINISH,
};

// reference:
// https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server
class RequestBase {
 public:
35
  explicit RequestBase(GrpcService::AsyncService* service,
Q
qiaolongfei 已提交
36
                       ::grpc::ServerCompletionQueue* cq, bool sync_mode,
37
                       const platform::DeviceContext* dev_ctx)
Q
qiaolongfei 已提交
38 39 40 41 42
      : service_(service),
        cq_(cq),
        sync_mode_(sync_mode),
        status_(PROCESS),
        dev_ctx_(dev_ctx) {
G
gongweibao 已提交
43 44
    PADDLE_ENFORCE(cq_);
  }
G
gongweibao 已提交
45 46 47 48 49
  virtual ~RequestBase() {}
  virtual void Process() { assert(false); }

  CallStatus Status() { return status_; }
  void SetStatus(CallStatus status) { status_ = status; }
T
typhoonzero 已提交
50 51 52 53
  virtual std::string GetReqName() {
    assert(false);
    return "";
  }
G
gongweibao 已提交
54 55

 protected:
56 57 58
  ::grpc::ServerContext ctx_;
  GrpcService::AsyncService* service_;
  ::grpc::ServerCompletionQueue* cq_;
Q
qiaolongfei 已提交
59
  const bool sync_mode_;
G
gongweibao 已提交
60
  CallStatus status_;
61
  const platform::DeviceContext* dev_ctx_;
G
gongweibao 已提交
62 63 64 65
};

class RequestSend final : public RequestBase {
 public:
66
  explicit RequestSend(GrpcService::AsyncService* service,
Q
qiaolongfei 已提交
67
                       ::grpc::ServerCompletionQueue* cq, bool sync_mode,
68
                       framework::Scope* scope, ReceivedQueue* queue,
X
Xin Pan 已提交
69
                       const platform::DeviceContext* dev_ctx, int i)
Q
qiaolongfei 已提交
70 71
      : RequestBase(service, cq, sync_mode, dev_ctx),
        queue_(queue),
X
Xin Pan 已提交
72 73
        responder_(&ctx_),
        i_(i) {
Q
qiaolongfei 已提交
74
    if (sync_mode_) {
75
      request_.reset(new VariableResponse(scope, dev_ctx_, false));
Q
qiaolongfei 已提交
76
    } else {
77
      request_.reset(new VariableResponse(scope, dev_ctx_, true));
Q
qiaolongfei 已提交
78
    }
79
    int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
X
Xin Pan 已提交
80 81 82
    service_->RequestAsyncUnary(
        method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
        reinterpret_cast<void*>(static_cast<intptr_t>(i)));
G
gongweibao 已提交
83 84 85 86
  }

  virtual ~RequestSend() {}

87
  virtual std::string GetReqName() { return request_->Varname(); }
G
gongweibao 已提交
88

G
gongweibao 已提交
89
  virtual void Process() {
Q
qiaolongfei 已提交
90 91 92
    std::string var_name = GetReqName();
    VLOG(3) << "RequestSend " << var_name;
    queue_->Push(std::make_pair(var_name, request_));
93

G
gongweibao 已提交
94
    status_ = FINISH;
X
Xin Pan 已提交
95 96
    responder_.Finish(reply_, ::grpc::Status::OK,
                      reinterpret_cast<void*>(static_cast<intptr_t>(i_)));
G
gongweibao 已提交
97 98 99
  }

 protected:
X
Xin Pan 已提交
100
  sendrecv::VoidMessage reply_;
101 102
  std::shared_ptr<VariableResponse> request_;
  ReceivedQueue* queue_;
G
gongweibao 已提交
103
  ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
X
Xin Pan 已提交
104
  int i_;
G
gongweibao 已提交
105 106 107 108
};

class RequestGet final : public RequestBase {
 public:
109
  explicit RequestGet(GrpcService::AsyncService* service,
Q
qiaolongfei 已提交
110
                      ::grpc::ServerCompletionQueue* cq, bool sync_mode,
111
                      framework::Scope* scope,
T
typhoonzero 已提交
112
                      const platform::DeviceContext* dev_ctx,
X
Xin Pan 已提交
113
                      framework::BlockingQueue<MessageWithName>* queue, int i)
Q
qiaolongfei 已提交
114
      : RequestBase(service, cq, sync_mode, dev_ctx),
Y
Yancey1989 已提交
115 116
        responder_(&ctx_),
        scope_(scope),
X
Xin Pan 已提交
117 118
        queue_(queue),
        i_(i) {
Q
qiaolongfei 已提交
119
    auto method_id = static_cast<int>(detail::GrpcMethod::kGetVariable);
X
Xin Pan 已提交
120 121 122
    service_->RequestAsyncUnary(
        method_id, &ctx_, &request_, &responder_, cq_, cq_,
        reinterpret_cast<void*>(static_cast<intptr_t>(i)));
G
gongweibao 已提交
123 124 125 126
  }

  virtual ~RequestGet() {}

G
gongweibao 已提交
127 128
  virtual std::string GetReqName() { return request_.varname(); }

G
gongweibao 已提交
129 130 131
  virtual void Process() {
    // proc request.
    std::string var_name = request_.varname();
Q
qiaolongfei 已提交
132
    VLOG(3) << "RequestGet " << var_name;
G
gongweibao 已提交
133
    auto* var = scope_->FindVar(var_name);
134

135
    if (var_name != FETCH_BARRIER_MESSAGE) {
X
Xin Pan 已提交
136
      SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_);
137
    }
138

G
gongweibao 已提交
139
    status_ = FINISH;
X
Xin Pan 已提交
140 141
    responder_.Finish(reply_, ::grpc::Status::OK,
                      reinterpret_cast<void*>(static_cast<intptr_t>(i_)));
142 143 144 145 146 147

    if (var_name == FETCH_BARRIER_MESSAGE) {
      sendrecv::VariableMessage msg;
      MessageWithName msg_with_name = std::make_pair(var_name, msg);
      queue_->Push(msg_with_name);
    }
G
gongweibao 已提交
148 149 150 151
  }

 protected:
  sendrecv::VariableMessage request_;
X
Xin Pan 已提交
152
  ::grpc::ByteBuffer reply_;
153
  ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
G
gongweibao 已提交
154
  framework::Scope* scope_;
T
typhoonzero 已提交
155
  framework::BlockingQueue<MessageWithName>* queue_;
X
Xin Pan 已提交
156
  int i_;
G
gongweibao 已提交
157 158
};

159 160 161
class RequestPrefetch final : public RequestBase {
 public:
  explicit RequestPrefetch(GrpcService::AsyncService* service,
Q
qiaolongfei 已提交
162
                           ::grpc::ServerCompletionQueue* cq, bool sync_mode,
163 164 165
                           framework::Scope* scope,
                           const platform::DeviceContext* dev_ctx,
                           framework::Executor* executor,
Y
Yancey1989 已提交
166
                           framework::ProgramDesc* program,
X
Xin Pan 已提交
167 168
                           framework::ExecutorPrepareContext* prefetch_ctx,
                           int i)
Q
qiaolongfei 已提交
169
      : RequestBase(service, cq, sync_mode, dev_ctx),
170 171 172 173
        responder_(&ctx_),
        scope_(scope),
        executor_(executor),
        program_(program),
X
Xin Pan 已提交
174 175
        prefetch_ctx_(prefetch_ctx),
        i_(i) {
Q
qiaolongfei 已提交
176
    if (sync_mode_) {
177
      request_.reset(new VariableResponse(scope, dev_ctx_, false));
Q
qiaolongfei 已提交
178
    } else {
179
      request_.reset(new VariableResponse(scope, dev_ctx_, true));
Q
qiaolongfei 已提交
180
    }
181
    int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
Y
Yancey1989 已提交
182 183
    service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
                                cq_, cq_, this);
184 185 186 187
  }

  virtual ~RequestPrefetch() {}

Y
Yancey1989 已提交
188
  virtual std::string GetReqName() { return request_->Varname(); }
189 190 191

  virtual void Process() {
    // prefetch process...
Y
Yancey1989 已提交
192
    ::grpc::ByteBuffer reply;
193

Y
Yancey1989 已提交
194
    std::string var_name = request_->OutVarname();
Q
qiaolongfei 已提交
195
    VLOG(3) << "RequestPrefetch " << var_name;
Y
Yancey1989 已提交
196 197 198 199
    auto var_desc = program_->Block(0).FindVar(var_name);
    framework::Scope* local_scope = &scope_->NewScope();
    auto* var = local_scope->FindVar(var_name);
    InitializeVariable(var, var_desc->GetType());
W
Wu Yi 已提交
200
    executor_->RunPreparedContext(prefetch_ctx_, scope_);
Y
Yancey1989 已提交
201 202

    SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply);
Q
qiaolongfei 已提交
203

X
Xin Pan 已提交
204 205
    responder_.Finish(reply, ::grpc::Status::OK,
                      reinterpret_cast<void*>(static_cast<intptr_t>(i_)));
206 207 208 209
    status_ = FINISH;
  }

 protected:
Y
Yancey1989 已提交
210
  std::shared_ptr<VariableResponse> request_;
211 212 213 214
  ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
  framework::Scope* scope_;
  framework::Executor* executor_;
  framework::ProgramDesc* program_;
Y
Yancey1989 已提交
215
  framework::ExecutorPrepareContext* prefetch_ctx_;
X
Xin Pan 已提交
216
  int i_;
217 218
};

T
typhoonzero 已提交
219
// Blocks until `count` FETCH_BARRIER_MESSAGE entries have been popped from the
// get queue. Entries with any other name are popped and discarded.
void AsyncGRPCServer::WaitClientGet(int count) {
  int remaining = count;
  while (remaining > 0) {
    auto entry = var_get_queue_.Pop();
    if (entry.first == FETCH_BARRIER_MESSAGE) {
      --remaining;
    }
  }
}

T
done  
typhoonzero 已提交
229
void AsyncGRPCServer::WaitServerReady() {
T
update  
typhoonzero 已提交
230
  std::unique_lock<std::mutex> lock(this->mutex_ready_);
T
done  
typhoonzero 已提交
231
  condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
T
update  
typhoonzero 已提交
232 233
}

G
gongweibao 已提交
234
void AsyncGRPCServer::RunSyncUpdate() {
235
  ::grpc::ServerBuilder builder;
T
typhoonzero 已提交
236 237
  builder.AddListeningPort(address_, ::grpc::InsecureServerCredentials(),
                           &selected_port_);
G
gongweibao 已提交
238 239
  builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
G
gongweibao 已提交
240 241 242 243
  builder.RegisterService(&service_);

  cq_send_ = builder.AddCompletionQueue();
  cq_get_ = builder.AddCompletionQueue();
244
  cq_prefetch_ = builder.AddCompletionQueue();
Y
Yancey 已提交
245

G
gongweibao 已提交
246
  server_ = builder.BuildAndStart();
T
typhoonzero 已提交
247 248
  LOG(INFO) << "Server listening on " << address_
            << " selected port: " << selected_port_;
G
gongweibao 已提交
249

X
Xin Pan 已提交
250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274
  std::function<void(int)> send_register = std::bind(
      &AsyncGRPCServer::TryToRegisterNewSendOne, this, std::placeholders::_1);
  std::function<void(int)> get_register = std::bind(
      &AsyncGRPCServer::TryToRegisterNewGetOne, this, std::placeholders::_1);
  std::function<void(int)> prefetch_register =
      std::bind(&AsyncGRPCServer::TryToRegisterNewPrefetchOne, this,
                std::placeholders::_1);

  for (int i = 0; i < kSendReqsBufSize; ++i) {
    TryToRegisterNewSendOne(i);
  }
  for (int i = 0; i < kGetReqsBufSize; ++i) {
    TryToRegisterNewGetOne(i);
  }

  for (int i = 0; i < kNumHandleSendThreads; ++i) {
    t_sends_.emplace_back(
        new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
                                  cq_send_.get(), "cq_send", send_register)));
  }
  for (int i = 0; i < kNumHandleGetThreads; ++i) {
    t_gets_.emplace_back(
        new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
                                  cq_get_.get(), "cq_get", get_register)));
  }
G
gongweibao 已提交
275

T
typhoonzero 已提交
276
  // TODO(wuyi): Run these "HandleRequest" in thread pool
277 278 279
  t_prefetch_.reset(new std::thread(
      std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(),
                "cq_prefetch", prefetch_register)));
T
wip  
typhoonzero 已提交
280 281 282 283 284 285

  {
    std::lock_guard<std::mutex> lock(this->mutex_ready_);
    ready_ = 1;
  }
  condition_ready_.notify_all();
G
gongweibao 已提交
286 287
  // wait server
  server_->Wait();
X
Xin Pan 已提交
288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308
  for (int i = 0; i < kNumHandleSendThreads; ++i) {
    t_sends_[i]->join();
  }
  for (int i = 0; i < kNumHandleGetThreads; ++i) {
    t_gets_[i]->join();
  }
  {
    std::lock_guard<std::mutex> l(cq_mutex_);
    for (int i = 0; i < kSendReqsBufSize; ++i) {
      if (send_reqs_[i]) {
        delete send_reqs_[i];
        send_reqs_[i] = nullptr;
      }
    }
    for (int i = 0; i < kGetReqsBufSize; ++i) {
      if (get_reqs_[i]) {
        delete get_reqs_[i];
        get_reqs_[i] = nullptr;
      }
    }
  }
309
  t_prefetch_->join();
G
gongweibao 已提交
310 311 312 313 314 315
}

void AsyncGRPCServer::ShutdownQueue() {
  std::unique_lock<std::mutex> lock(cq_mutex_);
  cq_send_->Shutdown();
  cq_get_->Shutdown();
316
  cq_prefetch_->Shutdown();
G
gongweibao 已提交
317 318 319 320
}

// This URL explains why shutdown is complicate:
void AsyncGRPCServer::ShutDown() {
T
typhoonzero 已提交
321
  is_shut_down_ = true;
G
gongweibao 已提交
322
  ShutdownQueue();
T
typhoonzero 已提交
323
  server_->Shutdown();
G
gongweibao 已提交
324 325
}

X
Xin Pan 已提交
326
void AsyncGRPCServer::TryToRegisterNewSendOne(int i) {
G
gongweibao 已提交
327 328
  std::unique_lock<std::mutex> lock(cq_mutex_);
  if (is_shut_down_) {
329
    VLOG(3) << "shutdown, do not TryToRegisterNewSendOne";
G
gongweibao 已提交
330 331
    return;
  }
Q
qiaolongfei 已提交
332
  RequestSend* send = new RequestSend(&service_, cq_send_.get(), sync_mode_,
X
Xin Pan 已提交
333 334
                                      scope_, &var_recv_queue_, dev_ctx_, i);
  send_reqs_[i] = static_cast<RequestBase*>(send);
Y
Yancey 已提交
335
  VLOG(4) << "Create RequestSend status:" << send->Status();
G
gongweibao 已提交
336 337
}

X
Xin Pan 已提交
338
void AsyncGRPCServer::TryToRegisterNewGetOne(int i) {
G
gongweibao 已提交
339 340
  std::unique_lock<std::mutex> lock(cq_mutex_);
  if (is_shut_down_) {
341
    VLOG(3) << "shutdown, do not TryToRegisterNewGetOne";
G
gongweibao 已提交
342 343
    return;
  }
Q
qiaolongfei 已提交
344
  RequestGet* get = new RequestGet(&service_, cq_get_.get(), sync_mode_, scope_,
X
Xin Pan 已提交
345 346
                                   dev_ctx_, &var_get_queue_, i);
  get_reqs_[i] = static_cast<RequestBase*>(get);
Y
Yancey 已提交
347
  VLOG(4) << "Create RequestGet status:" << get->Status();
G
gongweibao 已提交
348 349
}

X
Xin Pan 已提交
350
void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int i) {
351 352
  std::unique_lock<std::mutex> lock(cq_mutex_);
  if (is_shut_down_) {
353
    VLOG(3) << "shutdown, do not TryToRegisterNewPrefetchOne";
354 355
    return;
  }
X
Xin Pan 已提交
356 357 358
  RequestPrefetch* prefetch = new RequestPrefetch(
      &service_, cq_prefetch_.get(), sync_mode_, scope_, dev_ctx_, executor_,
      program_, prefetch_ctx_.get(), i);
359 360 361 362

  VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
}

Y
Yancey 已提交
363
// FIXME(typhoonzero): change cq_name to enum.
X
Xin Pan 已提交
364 365 366
void AsyncGRPCServer::HandleRequest(
    ::grpc::ServerCompletionQueue* cq, const std::string& cq_name,
    std::function<void(int)> TryToRegisterNewOne) {
G
gongweibao 已提交
367 368
  void* tag = NULL;
  bool ok = false;
369

G
gongweibao 已提交
370
  while (true) {
Q
qiaolongfei 已提交
371
    VLOG(3) << "HandleRequest for " << cq_name << " wait Next";
G
gongweibao 已提交
372
    if (!cq->Next(&tag, &ok)) {
T
typhoonzero 已提交
373
      LOG(INFO) << cq_name << " CompletionQueue shutdown!";
G
gongweibao 已提交
374 375
      break;
    }
Q
qiaolongfei 已提交
376
    VLOG(3) << "HandleRequest for " << cq_name << " get Next";
X
Xin Pan 已提交
377
    int i = static_cast<int>(reinterpret_cast<intptr_t>(tag));
Q
qiaolongfei 已提交
378

Q
qiaolongfei 已提交
379 380 381 382
    if (sync_mode_) {
      // FIXME(typhoonzero): de-couple the barriers with recv_op
      if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1);
      if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
Q
qiaolongfei 已提交
383
      VLOG(3) << "HandleRequest for " << cq_name << " after WaitCond";
Q
qiaolongfei 已提交
384
    }
G
gongweibao 已提交
385

X
Xin Pan 已提交
386 387 388 389 390 391 392 393 394 395 396
    RequestBase* base = nullptr;
    {
      std::lock_guard<std::mutex> l(cq_mutex_);
      if (cq_name == "cq_get") {
        base = get_reqs_[i];
      } else if (cq_name == "cq_send") {
        base = send_reqs_[i];
      } else {
        CHECK(false);
      }
    }
G
gongweibao 已提交
397 398 399 400
    // reference:
    // https://github.com/tensorflow/tensorflow/issues/5596
    // https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
    // https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
G
gongweibao 已提交
401
    if (!ok) {
Q
qiaolongfei 已提交
402 403
      LOG(WARNING) << cq_name << " recv no regular event:argument name["
                   << base->GetReqName() << "]";
X
Xin Pan 已提交
404
      TryToRegisterNewOne(i);
G
gongweibao 已提交
405 406 407 408 409 410 411
      delete base;
      continue;
    }

    switch (base->Status()) {
      case PROCESS: {
        base->Process();
Q
qiaolongfei 已提交
412
        VLOG(4) << cq_name << " PROCESS status:" << base->Status();
G
gongweibao 已提交
413 414 415
        break;
      }
      case FINISH: {
X
Xin Pan 已提交
416
        TryToRegisterNewOne(i);
Q
qiaolongfei 已提交
417
        VLOG(4) << cq_name << " FINISH status:" << base->Status();
G
gongweibao 已提交
418 419 420 421 422 423 424 425
        delete base;
        break;
      }
      default: { assert(false); }
    }
  }
}

T
typhoonzero 已提交
426 427 428 429
void AsyncGRPCServer::WaitCond(int cond) {
  std::unique_lock<std::mutex> lock(this->barrier_mutex_);
  barrier_condition_.wait(lock,
                          [=] { return this->barrier_cond_step_ == cond; });
G
gongweibao 已提交
430 431
}

T
typhoonzero 已提交
432
void AsyncGRPCServer::SetCond(int cond) {
G
gongweibao 已提交
433
  {
T
typhoonzero 已提交
434 435
    std::lock_guard<std::mutex> lock(this->barrier_mutex_);
    barrier_cond_step_ = cond;
G
gongweibao 已提交
436
  }
T
typhoonzero 已提交
437
  barrier_condition_.notify_all();
G
gongweibao 已提交
438 439 440 441 442
}

}  // namespace detail
}  // namespace operators
}  // namespace paddle