grpc_server.cc 13.8 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
G
gongweibao 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/detail/grpc_server.h"
16 17 18

#include <limits>
#include <string>
G
gongweibao 已提交
19

20
using ::grpc::ServerAsyncResponseWriter;
G
gongweibao 已提交
21 22 23 24

namespace paddle {
namespace operators {
namespace detail {
X
Xin Pan 已提交
25 26 27 28
namespace {
// Number of HandleRequest worker threads RunSyncUpdate starts for the
// "cq_send" completion queue.
constexpr int kNumHandleSendThreads = 20;
// Number of HandleRequest worker threads RunSyncUpdate starts for the
// "cq_get" completion queue.
constexpr int kNumHandleGetThreads = 20;
}  // namespace
G
gongweibao 已提交
29 30 31 32 33 34
// State of a request object as seen by HandleRequest.
enum CallStatus {
  PROCESS = 0,  // Process() still has to be called for this request.
  FINISH        // Process() has issued the response; the object can be freed.
};

// reference:
// https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server
class RequestBase {
 public:
35
  explicit RequestBase(GrpcService::AsyncService* service,
Q
qiaolongfei 已提交
36
                       ::grpc::ServerCompletionQueue* cq, bool sync_mode,
37
                       const platform::DeviceContext* dev_ctx)
Q
qiaolongfei 已提交
38 39 40 41 42
      : service_(service),
        cq_(cq),
        sync_mode_(sync_mode),
        status_(PROCESS),
        dev_ctx_(dev_ctx) {
G
gongweibao 已提交
43 44
    PADDLE_ENFORCE(cq_);
  }
G
gongweibao 已提交
45 46 47 48 49
  virtual ~RequestBase() {}
  virtual void Process() { assert(false); }

  CallStatus Status() { return status_; }
  void SetStatus(CallStatus status) { status_ = status; }
T
typhoonzero 已提交
50 51 52 53
  virtual std::string GetReqName() {
    assert(false);
    return "";
  }
G
gongweibao 已提交
54 55

 protected:
56 57 58
  ::grpc::ServerContext ctx_;
  GrpcService::AsyncService* service_;
  ::grpc::ServerCompletionQueue* cq_;
Q
qiaolongfei 已提交
59
  const bool sync_mode_;
G
gongweibao 已提交
60
  CallStatus status_;
61
  const platform::DeviceContext* dev_ctx_;
G
gongweibao 已提交
62 63 64 65
};

class RequestSend final : public RequestBase {
 public:
66
  explicit RequestSend(GrpcService::AsyncService* service,
Q
qiaolongfei 已提交
67
                       ::grpc::ServerCompletionQueue* cq, bool sync_mode,
68
                       framework::Scope* scope, ReceivedQueue* queue,
X
Xin Pan 已提交
69
                       const platform::DeviceContext* dev_ctx, int req_id)
Q
qiaolongfei 已提交
70 71
      : RequestBase(service, cq, sync_mode, dev_ctx),
        queue_(queue),
X
Xin Pan 已提交
72
        responder_(&ctx_),
X
Xin Pan 已提交
73
        req_id_(req_id) {
Q
qiaolongfei 已提交
74
    if (sync_mode_) {
75
      request_.reset(new VariableResponse(scope, dev_ctx_, false));
Q
qiaolongfei 已提交
76
    } else {
77
      request_.reset(new VariableResponse(scope, dev_ctx_, true));
Q
qiaolongfei 已提交
78
    }
79
    int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
X
Xin Pan 已提交
80 81
    service_->RequestAsyncUnary(
        method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
X
Xin Pan 已提交
82
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
G
gongweibao 已提交
83 84 85 86
  }

  virtual ~RequestSend() {}

87
  virtual std::string GetReqName() { return request_->Varname(); }
G
gongweibao 已提交
88

G
gongweibao 已提交
89
  virtual void Process() {
Q
qiaolongfei 已提交
90 91 92
    std::string var_name = GetReqName();
    VLOG(3) << "RequestSend " << var_name;
    queue_->Push(std::make_pair(var_name, request_));
93

G
gongweibao 已提交
94
    status_ = FINISH;
X
Xin Pan 已提交
95
    responder_.Finish(reply_, ::grpc::Status::OK,
X
Xin Pan 已提交
96
                      reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
G
gongweibao 已提交
97 98 99
  }

 protected:
X
Xin Pan 已提交
100
  sendrecv::VoidMessage reply_;
101 102
  std::shared_ptr<VariableResponse> request_;
  ReceivedQueue* queue_;
G
gongweibao 已提交
103
  ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
X
Xin Pan 已提交
104
  int req_id_;
G
gongweibao 已提交
105 106 107 108
};

class RequestGet final : public RequestBase {
 public:
109
  explicit RequestGet(GrpcService::AsyncService* service,
Q
qiaolongfei 已提交
110
                      ::grpc::ServerCompletionQueue* cq, bool sync_mode,
111
                      framework::Scope* scope,
T
typhoonzero 已提交
112
                      const platform::DeviceContext* dev_ctx,
X
Xin Pan 已提交
113 114
                      framework::BlockingQueue<MessageWithName>* queue,
                      int req_id)
Q
qiaolongfei 已提交
115
      : RequestBase(service, cq, sync_mode, dev_ctx),
Y
Yancey1989 已提交
116 117
        responder_(&ctx_),
        scope_(scope),
X
Xin Pan 已提交
118
        queue_(queue),
X
Xin Pan 已提交
119
        req_id_(req_id) {
Q
qiaolongfei 已提交
120
    auto method_id = static_cast<int>(detail::GrpcMethod::kGetVariable);
X
Xin Pan 已提交
121 122
    service_->RequestAsyncUnary(
        method_id, &ctx_, &request_, &responder_, cq_, cq_,
X
Xin Pan 已提交
123
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
G
gongweibao 已提交
124 125 126 127
  }

  virtual ~RequestGet() {}

G
gongweibao 已提交
128 129
  virtual std::string GetReqName() { return request_.varname(); }

G
gongweibao 已提交
130 131 132
  virtual void Process() {
    // proc request.
    std::string var_name = request_.varname();
Q
qiaolongfei 已提交
133
    VLOG(3) << "RequestGet " << var_name;
G
gongweibao 已提交
134
    auto* var = scope_->FindVar(var_name);
135

136
    if (var_name != FETCH_BARRIER_MESSAGE) {
X
Xin Pan 已提交
137
      SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_);
138
    }
139

G
gongweibao 已提交
140
    status_ = FINISH;
X
Xin Pan 已提交
141
    responder_.Finish(reply_, ::grpc::Status::OK,
X
Xin Pan 已提交
142
                      reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
143 144 145 146 147 148

    if (var_name == FETCH_BARRIER_MESSAGE) {
      sendrecv::VariableMessage msg;
      MessageWithName msg_with_name = std::make_pair(var_name, msg);
      queue_->Push(msg_with_name);
    }
G
gongweibao 已提交
149 150 151 152
  }

 protected:
  sendrecv::VariableMessage request_;
X
Xin Pan 已提交
153
  ::grpc::ByteBuffer reply_;
154
  ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
G
gongweibao 已提交
155
  framework::Scope* scope_;
T
typhoonzero 已提交
156
  framework::BlockingQueue<MessageWithName>* queue_;
X
Xin Pan 已提交
157
  int req_id_;
G
gongweibao 已提交
158 159
};

160 161 162
class RequestPrefetch final : public RequestBase {
 public:
  explicit RequestPrefetch(GrpcService::AsyncService* service,
Q
qiaolongfei 已提交
163
                           ::grpc::ServerCompletionQueue* cq, bool sync_mode,
164 165 166
                           framework::Scope* scope,
                           const platform::DeviceContext* dev_ctx,
                           framework::Executor* executor,
Y
Yancey1989 已提交
167
                           framework::ProgramDesc* program,
X
Xin Pan 已提交
168
                           framework::ExecutorPrepareContext* prefetch_ctx,
X
Xin Pan 已提交
169
                           int req_id)
Q
qiaolongfei 已提交
170
      : RequestBase(service, cq, sync_mode, dev_ctx),
171 172 173 174
        responder_(&ctx_),
        scope_(scope),
        executor_(executor),
        program_(program),
X
Xin Pan 已提交
175
        prefetch_ctx_(prefetch_ctx),
X
Xin Pan 已提交
176
        req_id_(req_id) {
Q
qiaolongfei 已提交
177
    if (sync_mode_) {
178
      request_.reset(new VariableResponse(scope, dev_ctx_, false));
Q
qiaolongfei 已提交
179
    } else {
180
      request_.reset(new VariableResponse(scope, dev_ctx_, true));
Q
qiaolongfei 已提交
181
    }
182
    int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
Y
Yancey1989 已提交
183 184
    service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
                                cq_, cq_, this);
185 186 187 188
  }

  virtual ~RequestPrefetch() {}

Y
Yancey1989 已提交
189
  virtual std::string GetReqName() { return request_->Varname(); }
190 191 192

  virtual void Process() {
    // prefetch process...
Y
Yancey1989 已提交
193
    ::grpc::ByteBuffer reply;
194

Y
Yancey1989 已提交
195
    std::string var_name = request_->OutVarname();
Q
qiaolongfei 已提交
196
    VLOG(3) << "RequestPrefetch " << var_name;
Y
Yancey1989 已提交
197 198 199 200
    auto var_desc = program_->Block(0).FindVar(var_name);
    framework::Scope* local_scope = &scope_->NewScope();
    auto* var = local_scope->FindVar(var_name);
    InitializeVariable(var, var_desc->GetType());
W
Wu Yi 已提交
201
    executor_->RunPreparedContext(prefetch_ctx_, scope_);
Y
Yancey1989 已提交
202 203

    SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply);
Q
qiaolongfei 已提交
204

X
Xin Pan 已提交
205
    responder_.Finish(reply, ::grpc::Status::OK,
X
Xin Pan 已提交
206
                      reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
207 208 209 210
    status_ = FINISH;
  }

 protected:
Y
Yancey1989 已提交
211
  std::shared_ptr<VariableResponse> request_;
212 213 214 215
  ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
  framework::Scope* scope_;
  framework::Executor* executor_;
  framework::ProgramDesc* program_;
Y
Yancey1989 已提交
216
  framework::ExecutorPrepareContext* prefetch_ctx_;
X
Xin Pan 已提交
217
  int req_id_;
218 219
};

T
typhoonzero 已提交
220
// Pops messages from var_get_queue_ until `count` of them carry the name
// FETCH_BARRIER_MESSAGE. Messages with any other name are discarded.
void AsyncGRPCServer::WaitClientGet(int count) {
  for (int barriers_seen = 0; barriers_seen < count;) {
    auto popped = var_get_queue_.Pop();
    if (popped.first == FETCH_BARRIER_MESSAGE) {
      ++barriers_seen;
    }
  }
}

T
done  
typhoonzero 已提交
230
void AsyncGRPCServer::WaitServerReady() {
T
update  
typhoonzero 已提交
231
  std::unique_lock<std::mutex> lock(this->mutex_ready_);
T
done  
typhoonzero 已提交
232
  condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
T
update  
typhoonzero 已提交
233 234
}

G
gongweibao 已提交
235
void AsyncGRPCServer::RunSyncUpdate() {
236
  ::grpc::ServerBuilder builder;
T
typhoonzero 已提交
237 238
  builder.AddListeningPort(address_, ::grpc::InsecureServerCredentials(),
                           &selected_port_);
G
gongweibao 已提交
239 240
  builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
G
gongweibao 已提交
241 242 243 244
  builder.RegisterService(&service_);

  cq_send_ = builder.AddCompletionQueue();
  cq_get_ = builder.AddCompletionQueue();
245
  cq_prefetch_ = builder.AddCompletionQueue();
Y
Yancey 已提交
246

G
gongweibao 已提交
247
  server_ = builder.BuildAndStart();
T
typhoonzero 已提交
248 249
  LOG(INFO) << "Server listening on " << address_
            << " selected port: " << selected_port_;
G
gongweibao 已提交
250

X
Xin Pan 已提交
251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275
  std::function<void(int)> send_register = std::bind(
      &AsyncGRPCServer::TryToRegisterNewSendOne, this, std::placeholders::_1);
  std::function<void(int)> get_register = std::bind(
      &AsyncGRPCServer::TryToRegisterNewGetOne, this, std::placeholders::_1);
  std::function<void(int)> prefetch_register =
      std::bind(&AsyncGRPCServer::TryToRegisterNewPrefetchOne, this,
                std::placeholders::_1);

  for (int i = 0; i < kSendReqsBufSize; ++i) {
    TryToRegisterNewSendOne(i);
  }
  for (int i = 0; i < kGetReqsBufSize; ++i) {
    TryToRegisterNewGetOne(i);
  }

  for (int i = 0; i < kNumHandleSendThreads; ++i) {
    t_sends_.emplace_back(
        new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
                                  cq_send_.get(), "cq_send", send_register)));
  }
  for (int i = 0; i < kNumHandleGetThreads; ++i) {
    t_gets_.emplace_back(
        new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
                                  cq_get_.get(), "cq_get", get_register)));
  }
G
gongweibao 已提交
276

T
typhoonzero 已提交
277
  // TODO(wuyi): Run these "HandleRequest" in thread pool
278 279 280
  t_prefetch_.reset(new std::thread(
      std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(),
                "cq_prefetch", prefetch_register)));
T
wip  
typhoonzero 已提交
281 282 283 284 285 286

  {
    std::lock_guard<std::mutex> lock(this->mutex_ready_);
    ready_ = 1;
  }
  condition_ready_.notify_all();
G
gongweibao 已提交
287 288
  // wait server
  server_->Wait();
X
Xin Pan 已提交
289 290 291 292 293 294
  for (int i = 0; i < kNumHandleSendThreads; ++i) {
    t_sends_[i]->join();
  }
  for (int i = 0; i < kNumHandleGetThreads; ++i) {
    t_gets_[i]->join();
  }
295
  t_prefetch_->join();
G
gongweibao 已提交
296 297 298 299 300 301
}

void AsyncGRPCServer::ShutdownQueue() {
  std::unique_lock<std::mutex> lock(cq_mutex_);
  cq_send_->Shutdown();
  cq_get_->Shutdown();
302
  cq_prefetch_->Shutdown();
G
gongweibao 已提交
303 304 305 306
}

// This URL explains why shutdown is complicate:
void AsyncGRPCServer::ShutDown() {
T
typhoonzero 已提交
307
  is_shut_down_ = true;
G
gongweibao 已提交
308
  ShutdownQueue();
T
typhoonzero 已提交
309
  server_->Shutdown();
G
gongweibao 已提交
310 311
}

X
Xin Pan 已提交
312
void AsyncGRPCServer::TryToRegisterNewSendOne(int i) {
G
gongweibao 已提交
313 314
  std::unique_lock<std::mutex> lock(cq_mutex_);
  if (is_shut_down_) {
315
    VLOG(3) << "shutdown, do not TryToRegisterNewSendOne";
G
gongweibao 已提交
316 317
    return;
  }
Q
qiaolongfei 已提交
318
  RequestSend* send = new RequestSend(&service_, cq_send_.get(), sync_mode_,
X
Xin Pan 已提交
319 320
                                      scope_, &var_recv_queue_, dev_ctx_, i);
  send_reqs_[i] = static_cast<RequestBase*>(send);
Y
Yancey 已提交
321
  VLOG(4) << "Create RequestSend status:" << send->Status();
G
gongweibao 已提交
322 323
}

X
Xin Pan 已提交
324
void AsyncGRPCServer::TryToRegisterNewGetOne(int req_id) {
G
gongweibao 已提交
325 326
  std::unique_lock<std::mutex> lock(cq_mutex_);
  if (is_shut_down_) {
327
    VLOG(3) << "shutdown, do not TryToRegisterNewGetOne";
G
gongweibao 已提交
328 329
    return;
  }
Q
qiaolongfei 已提交
330
  RequestGet* get = new RequestGet(&service_, cq_get_.get(), sync_mode_, scope_,
X
Xin Pan 已提交
331 332
                                   dev_ctx_, &var_get_queue_, req_id);
  get_reqs_[req_id] = static_cast<RequestBase*>(get);
Y
Yancey 已提交
333
  VLOG(4) << "Create RequestGet status:" << get->Status();
G
gongweibao 已提交
334 335
}

X
Xin Pan 已提交
336
void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int req_id) {
337 338
  std::unique_lock<std::mutex> lock(cq_mutex_);
  if (is_shut_down_) {
339
    VLOG(3) << "shutdown, do not TryToRegisterNewPrefetchOne";
340 341
    return;
  }
X
Xin Pan 已提交
342 343
  RequestPrefetch* prefetch = new RequestPrefetch(
      &service_, cq_prefetch_.get(), sync_mode_, scope_, dev_ctx_, executor_,
X
Xin Pan 已提交
344
      program_, prefetch_ctx_.get(), req_id);
345 346 347 348

  VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
}

Y
Yancey 已提交
349
// FIXME(typhoonzero): change cq_name to enum.
X
Xin Pan 已提交
350 351 352
void AsyncGRPCServer::HandleRequest(
    ::grpc::ServerCompletionQueue* cq, const std::string& cq_name,
    std::function<void(int)> TryToRegisterNewOne) {
G
gongweibao 已提交
353 354
  void* tag = NULL;
  bool ok = false;
355

G
gongweibao 已提交
356
  while (true) {
Q
qiaolongfei 已提交
357
    VLOG(3) << "HandleRequest for " << cq_name << " wait Next";
G
gongweibao 已提交
358
    if (!cq->Next(&tag, &ok)) {
T
typhoonzero 已提交
359
      LOG(INFO) << cq_name << " CompletionQueue shutdown!";
G
gongweibao 已提交
360 361
      break;
    }
Q
qiaolongfei 已提交
362
    VLOG(3) << "HandleRequest for " << cq_name << " get Next";
X
Xin Pan 已提交
363
    int req_id = static_cast<int>(reinterpret_cast<intptr_t>(tag));
Q
qiaolongfei 已提交
364

Q
qiaolongfei 已提交
365 366 367 368
    if (sync_mode_) {
      // FIXME(typhoonzero): de-couple the barriers with recv_op
      if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1);
      if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
Q
qiaolongfei 已提交
369
      VLOG(3) << "HandleRequest for " << cq_name << " after WaitCond";
Q
qiaolongfei 已提交
370
    }
G
gongweibao 已提交
371

X
Xin Pan 已提交
372 373 374 375
    RequestBase* base = nullptr;
    {
      std::lock_guard<std::mutex> l(cq_mutex_);
      if (cq_name == "cq_get") {
X
Xin Pan 已提交
376
        base = get_reqs_[req_id];
X
Xin Pan 已提交
377
      } else if (cq_name == "cq_send") {
X
Xin Pan 已提交
378
        base = send_reqs_[req_id];
X
Xin Pan 已提交
379 380 381 382
      } else {
        CHECK(false);
      }
    }
G
gongweibao 已提交
383 384 385 386
    // reference:
    // https://github.com/tensorflow/tensorflow/issues/5596
    // https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
    // https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
G
gongweibao 已提交
387
    if (!ok) {
Q
qiaolongfei 已提交
388 389
      LOG(WARNING) << cq_name << " recv no regular event:argument name["
                   << base->GetReqName() << "]";
X
Xin Pan 已提交
390
      TryToRegisterNewOne(req_id);
G
gongweibao 已提交
391 392 393 394 395 396 397
      delete base;
      continue;
    }

    switch (base->Status()) {
      case PROCESS: {
        base->Process();
Q
qiaolongfei 已提交
398
        VLOG(4) << cq_name << " PROCESS status:" << base->Status();
G
gongweibao 已提交
399 400 401
        break;
      }
      case FINISH: {
X
Xin Pan 已提交
402
        TryToRegisterNewOne(req_id);
Q
qiaolongfei 已提交
403
        VLOG(4) << cq_name << " FINISH status:" << base->Status();
G
gongweibao 已提交
404 405 406 407 408 409 410 411
        delete base;
        break;
      }
      default: { assert(false); }
    }
  }
}

T
typhoonzero 已提交
412 413 414 415
void AsyncGRPCServer::WaitCond(int cond) {
  std::unique_lock<std::mutex> lock(this->barrier_mutex_);
  barrier_condition_.wait(lock,
                          [=] { return this->barrier_cond_step_ == cond; });
G
gongweibao 已提交
416 417
}

T
typhoonzero 已提交
418
void AsyncGRPCServer::SetCond(int cond) {
G
gongweibao 已提交
419
  {
T
typhoonzero 已提交
420 421
    std::lock_guard<std::mutex> lock(this->barrier_mutex_);
    barrier_cond_step_ = cond;
G
gongweibao 已提交
422
  }
T
typhoonzero 已提交
423
  barrier_condition_.notify_all();
G
gongweibao 已提交
424 425 426 427 428
}

}  // namespace detail
}  // namespace operators
}  // namespace paddle