grpc_server.cc 14.2 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
G
gongweibao 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/detail/grpc_server.h"
16 17 18

#include <limits>
#include <string>
G
gongweibao 已提交
19

20
using ::grpc::ServerAsyncResponseWriter;
G
gongweibao 已提交
21 22 23 24

namespace paddle {
namespace operators {
namespace detail {
X
Xin Pan 已提交
25 26 27
namespace {
// Number of HandleRequest worker threads that RunSyncUpdate() starts for the
// send, get and prefetch completion queues respectively.
constexpr int kNumHandleSendThreads = 20;
constexpr int kNumHandleGetThreads = 20;
constexpr int kNumHandlePrefetchThreads = 1;
}  // namespace
G
gongweibao 已提交
30 31 32 33 34 35
// Stage of a request object as seen by HandleRequest(): a request starts in
// PROCESS and is switched to FINISH by its Process() implementation.
enum CallStatus {
  PROCESS = 0,
  FINISH,
};

// reference:
// https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server
class RequestBase {
 public:
36
  explicit RequestBase(GrpcService::AsyncService* service,
Q
qiaolongfei 已提交
37
                       ::grpc::ServerCompletionQueue* cq, bool sync_mode,
38
                       const platform::DeviceContext* dev_ctx)
Q
qiaolongfei 已提交
39 40 41 42 43
      : service_(service),
        cq_(cq),
        sync_mode_(sync_mode),
        status_(PROCESS),
        dev_ctx_(dev_ctx) {
G
gongweibao 已提交
44 45
    PADDLE_ENFORCE(cq_);
  }
G
gongweibao 已提交
46 47 48 49 50
  virtual ~RequestBase() {}
  virtual void Process() { assert(false); }

  CallStatus Status() { return status_; }
  void SetStatus(CallStatus status) { status_ = status; }
T
typhoonzero 已提交
51 52 53 54
  virtual std::string GetReqName() {
    assert(false);
    return "";
  }
G
gongweibao 已提交
55 56

 protected:
57 58 59
  ::grpc::ServerContext ctx_;
  GrpcService::AsyncService* service_;
  ::grpc::ServerCompletionQueue* cq_;
Q
qiaolongfei 已提交
60
  const bool sync_mode_;
G
gongweibao 已提交
61
  CallStatus status_;
62
  const platform::DeviceContext* dev_ctx_;
G
gongweibao 已提交
63 64 65 66
};

class RequestSend final : public RequestBase {
 public:
67
  explicit RequestSend(GrpcService::AsyncService* service,
Q
qiaolongfei 已提交
68
                       ::grpc::ServerCompletionQueue* cq, bool sync_mode,
69
                       framework::Scope* scope, ReceivedQueue* queue,
X
Xin Pan 已提交
70
                       const platform::DeviceContext* dev_ctx, int req_id)
Q
qiaolongfei 已提交
71 72
      : RequestBase(service, cq, sync_mode, dev_ctx),
        queue_(queue),
X
Xin Pan 已提交
73
        responder_(&ctx_),
X
Xin Pan 已提交
74
        req_id_(req_id) {
Q
qiaolongfei 已提交
75
    if (sync_mode_) {
76
      request_.reset(new VariableResponse(scope, dev_ctx_, false));
Q
qiaolongfei 已提交
77
    } else {
78
      request_.reset(new VariableResponse(scope, dev_ctx_, true));
Q
qiaolongfei 已提交
79
    }
80
    int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
X
Xin Pan 已提交
81 82
    service_->RequestAsyncUnary(
        method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
X
Xin Pan 已提交
83
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
G
gongweibao 已提交
84 85 86 87
  }

  virtual ~RequestSend() {}

88
  virtual std::string GetReqName() { return request_->Varname(); }
G
gongweibao 已提交
89

G
gongweibao 已提交
90
  virtual void Process() {
Q
qiaolongfei 已提交
91 92 93
    std::string var_name = GetReqName();
    VLOG(3) << "RequestSend " << var_name;
    queue_->Push(std::make_pair(var_name, request_));
94

G
gongweibao 已提交
95
    status_ = FINISH;
X
Xin Pan 已提交
96
    responder_.Finish(reply_, ::grpc::Status::OK,
X
Xin Pan 已提交
97
                      reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
G
gongweibao 已提交
98 99 100
  }

 protected:
X
Xin Pan 已提交
101
  sendrecv::VoidMessage reply_;
102 103
  std::shared_ptr<VariableResponse> request_;
  ReceivedQueue* queue_;
G
gongweibao 已提交
104
  ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
X
Xin Pan 已提交
105
  int req_id_;
G
gongweibao 已提交
106 107 108 109
};

class RequestGet final : public RequestBase {
 public:
110
  explicit RequestGet(GrpcService::AsyncService* service,
Q
qiaolongfei 已提交
111
                      ::grpc::ServerCompletionQueue* cq, bool sync_mode,
112
                      framework::Scope* scope,
T
typhoonzero 已提交
113
                      const platform::DeviceContext* dev_ctx,
X
Xin Pan 已提交
114 115
                      framework::BlockingQueue<MessageWithName>* queue,
                      int req_id)
Q
qiaolongfei 已提交
116
      : RequestBase(service, cq, sync_mode, dev_ctx),
Y
Yancey1989 已提交
117 118
        responder_(&ctx_),
        scope_(scope),
X
Xin Pan 已提交
119
        queue_(queue),
X
Xin Pan 已提交
120
        req_id_(req_id) {
Q
qiaolongfei 已提交
121
    auto method_id = static_cast<int>(detail::GrpcMethod::kGetVariable);
X
Xin Pan 已提交
122 123
    service_->RequestAsyncUnary(
        method_id, &ctx_, &request_, &responder_, cq_, cq_,
X
Xin Pan 已提交
124
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
G
gongweibao 已提交
125 126 127 128
  }

  virtual ~RequestGet() {}

G
gongweibao 已提交
129 130
  virtual std::string GetReqName() { return request_.varname(); }

G
gongweibao 已提交
131 132 133
  virtual void Process() {
    // proc request.
    std::string var_name = request_.varname();
Q
qiaolongfei 已提交
134
    VLOG(3) << "RequestGet " << var_name;
G
gongweibao 已提交
135
    auto* var = scope_->FindVar(var_name);
136

137
    if (var_name != FETCH_BARRIER_MESSAGE) {
X
Xin Pan 已提交
138
      SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_);
139
    }
140

G
gongweibao 已提交
141
    status_ = FINISH;
X
Xin Pan 已提交
142
    responder_.Finish(reply_, ::grpc::Status::OK,
X
Xin Pan 已提交
143
                      reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
144 145 146 147 148 149

    if (var_name == FETCH_BARRIER_MESSAGE) {
      sendrecv::VariableMessage msg;
      MessageWithName msg_with_name = std::make_pair(var_name, msg);
      queue_->Push(msg_with_name);
    }
G
gongweibao 已提交
150 151 152 153
  }

 protected:
  sendrecv::VariableMessage request_;
X
Xin Pan 已提交
154
  ::grpc::ByteBuffer reply_;
155
  ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
G
gongweibao 已提交
156
  framework::Scope* scope_;
T
typhoonzero 已提交
157
  framework::BlockingQueue<MessageWithName>* queue_;
X
Xin Pan 已提交
158
  int req_id_;
G
gongweibao 已提交
159 160
};

161 162 163
class RequestPrefetch final : public RequestBase {
 public:
  explicit RequestPrefetch(GrpcService::AsyncService* service,
Q
qiaolongfei 已提交
164
                           ::grpc::ServerCompletionQueue* cq, bool sync_mode,
165 166 167
                           framework::Scope* scope,
                           const platform::DeviceContext* dev_ctx,
                           framework::Executor* executor,
Y
Yancey1989 已提交
168
                           framework::ProgramDesc* program,
X
Xin Pan 已提交
169
                           framework::ExecutorPrepareContext* prefetch_ctx,
X
Xin Pan 已提交
170
                           int req_id)
Q
qiaolongfei 已提交
171
      : RequestBase(service, cq, sync_mode, dev_ctx),
172 173 174 175
        responder_(&ctx_),
        scope_(scope),
        executor_(executor),
        program_(program),
X
Xin Pan 已提交
176
        prefetch_ctx_(prefetch_ctx),
X
Xin Pan 已提交
177
        req_id_(req_id) {
Q
qiaolongfei 已提交
178
    if (sync_mode_) {
179
      request_.reset(new VariableResponse(scope, dev_ctx_, false));
Q
qiaolongfei 已提交
180
    } else {
181
      request_.reset(new VariableResponse(scope, dev_ctx_, true));
Q
qiaolongfei 已提交
182
    }
183
    int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
X
Xin Pan 已提交
184 185 186
    service_->RequestAsyncUnary(
        method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
187 188 189 190
  }

  virtual ~RequestPrefetch() {}

Y
Yancey1989 已提交
191
  virtual std::string GetReqName() { return request_->Varname(); }
192 193 194 195

  virtual void Process() {
    // prefetch process...

Y
Yancey1989 已提交
196
    std::string var_name = request_->OutVarname();
Q
qiaolongfei 已提交
197
    VLOG(3) << "RequestPrefetch " << var_name;
Y
Yancey1989 已提交
198 199 200 201
    auto var_desc = program_->Block(0).FindVar(var_name);
    framework::Scope* local_scope = &scope_->NewScope();
    auto* var = local_scope->FindVar(var_name);
    InitializeVariable(var, var_desc->GetType());
W
Wu Yi 已提交
202
    executor_->RunPreparedContext(prefetch_ctx_, scope_);
Y
Yancey1989 已提交
203

X
Xin Pan 已提交
204
    SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_);
Q
qiaolongfei 已提交
205

206
    status_ = FINISH;
X
Xin Pan 已提交
207 208
    responder_.Finish(reply_, ::grpc::Status::OK,
                      reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
209 210 211
  }

 protected:
Y
Yancey1989 已提交
212
  std::shared_ptr<VariableResponse> request_;
X
Xin Pan 已提交
213
  ::grpc::ByteBuffer reply_;
214 215 216 217
  ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
  framework::Scope* scope_;
  framework::Executor* executor_;
  framework::ProgramDesc* program_;
Y
Yancey1989 已提交
218
  framework::ExecutorPrepareContext* prefetch_ctx_;
X
Xin Pan 已提交
219
  int req_id_;
220 221
};

T
typhoonzero 已提交
222
// Blocks until `count` messages named FETCH_BARRIER_MESSAGE have been popped
// from var_get_queue_. Messages with any other name are popped and ignored.
// Returns immediately when count <= 0.
void AsyncGRPCServer::WaitClientGet(int count) {
  int fetch_barriers = 0;
  while (fetch_barriers < count) {
    auto msg = var_get_queue_.Pop();
    if (msg.first == FETCH_BARRIER_MESSAGE) {
      ++fetch_barriers;
    }
  }
}

T
done  
typhoonzero 已提交
232
void AsyncGRPCServer::WaitServerReady() {
T
update  
typhoonzero 已提交
233
  std::unique_lock<std::mutex> lock(this->mutex_ready_);
T
done  
typhoonzero 已提交
234
  condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
T
update  
typhoonzero 已提交
235 236
}

G
gongweibao 已提交
237
void AsyncGRPCServer::RunSyncUpdate() {
238
  ::grpc::ServerBuilder builder;
T
typhoonzero 已提交
239 240
  builder.AddListeningPort(address_, ::grpc::InsecureServerCredentials(),
                           &selected_port_);
G
gongweibao 已提交
241 242
  builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
G
gongweibao 已提交
243 244 245 246
  builder.RegisterService(&service_);

  cq_send_ = builder.AddCompletionQueue();
  cq_get_ = builder.AddCompletionQueue();
247
  cq_prefetch_ = builder.AddCompletionQueue();
Y
Yancey 已提交
248

G
gongweibao 已提交
249
  server_ = builder.BuildAndStart();
T
typhoonzero 已提交
250 251
  LOG(INFO) << "Server listening on " << address_
            << " selected port: " << selected_port_;
G
gongweibao 已提交
252

X
Xin Pan 已提交
253 254 255 256 257 258 259 260 261 262 263 264 265 266
  std::function<void(int)> send_register = std::bind(
      &AsyncGRPCServer::TryToRegisterNewSendOne, this, std::placeholders::_1);
  std::function<void(int)> get_register = std::bind(
      &AsyncGRPCServer::TryToRegisterNewGetOne, this, std::placeholders::_1);
  std::function<void(int)> prefetch_register =
      std::bind(&AsyncGRPCServer::TryToRegisterNewPrefetchOne, this,
                std::placeholders::_1);

  for (int i = 0; i < kSendReqsBufSize; ++i) {
    TryToRegisterNewSendOne(i);
  }
  for (int i = 0; i < kGetReqsBufSize; ++i) {
    TryToRegisterNewGetOne(i);
  }
X
Xin Pan 已提交
267 268 269
  for (int i = 0; i < kPrefetchReqsBufSize; ++i) {
    TryToRegisterNewPrefetchOne(i);
  }
X
Xin Pan 已提交
270 271 272 273 274 275 276 277 278 279 280

  for (int i = 0; i < kNumHandleSendThreads; ++i) {
    t_sends_.emplace_back(
        new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
                                  cq_send_.get(), "cq_send", send_register)));
  }
  for (int i = 0; i < kNumHandleGetThreads; ++i) {
    t_gets_.emplace_back(
        new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
                                  cq_get_.get(), "cq_get", get_register)));
  }
X
Xin Pan 已提交
281 282 283 284 285
  for (int i = 0; i < kNumHandlePrefetchThreads; ++i) {
    t_prefetchs_.emplace_back(new std::thread(
        std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(),
                  "cq_prefetch", prefetch_register)));
  }
T
wip  
typhoonzero 已提交
286 287 288 289 290
  {
    std::lock_guard<std::mutex> lock(this->mutex_ready_);
    ready_ = 1;
  }
  condition_ready_.notify_all();
G
gongweibao 已提交
291 292
  // wait server
  server_->Wait();
X
Xin Pan 已提交
293 294 295 296 297 298
  for (int i = 0; i < kNumHandleSendThreads; ++i) {
    t_sends_[i]->join();
  }
  for (int i = 0; i < kNumHandleGetThreads; ++i) {
    t_gets_[i]->join();
  }
X
Xin Pan 已提交
299 300 301
  for (int i = 0; i < kNumHandlePrefetchThreads; ++i) {
    t_prefetchs_[i]->join();
  }
G
gongweibao 已提交
302 303 304 305 306 307
}

void AsyncGRPCServer::ShutdownQueue() {
  std::unique_lock<std::mutex> lock(cq_mutex_);
  cq_send_->Shutdown();
  cq_get_->Shutdown();
308
  cq_prefetch_->Shutdown();
G
gongweibao 已提交
309 310 311 312
}

// NOTE: shutting down an async gRPC server is complicated. The URL this
// comment originally pointed to is missing here.
void AsyncGRPCServer::ShutDown() {
T
typhoonzero 已提交
313
  is_shut_down_ = true;
G
gongweibao 已提交
314
  ShutdownQueue();
T
typhoonzero 已提交
315
  server_->Shutdown();
G
gongweibao 已提交
316 317
}

X
Xin Pan 已提交
318
void AsyncGRPCServer::TryToRegisterNewSendOne(int i) {
G
gongweibao 已提交
319 320
  std::unique_lock<std::mutex> lock(cq_mutex_);
  if (is_shut_down_) {
321
    VLOG(3) << "shutdown, do not TryToRegisterNewSendOne";
G
gongweibao 已提交
322 323
    return;
  }
Q
qiaolongfei 已提交
324
  RequestSend* send = new RequestSend(&service_, cq_send_.get(), sync_mode_,
X
Xin Pan 已提交
325 326
                                      scope_, &var_recv_queue_, dev_ctx_, i);
  send_reqs_[i] = static_cast<RequestBase*>(send);
Y
Yancey 已提交
327
  VLOG(4) << "Create RequestSend status:" << send->Status();
G
gongweibao 已提交
328 329
}

X
Xin Pan 已提交
330
void AsyncGRPCServer::TryToRegisterNewGetOne(int req_id) {
G
gongweibao 已提交
331 332
  std::unique_lock<std::mutex> lock(cq_mutex_);
  if (is_shut_down_) {
333
    VLOG(3) << "shutdown, do not TryToRegisterNewGetOne";
G
gongweibao 已提交
334 335
    return;
  }
Q
qiaolongfei 已提交
336
  RequestGet* get = new RequestGet(&service_, cq_get_.get(), sync_mode_, scope_,
X
Xin Pan 已提交
337 338
                                   dev_ctx_, &var_get_queue_, req_id);
  get_reqs_[req_id] = static_cast<RequestBase*>(get);
Y
Yancey 已提交
339
  VLOG(4) << "Create RequestGet status:" << get->Status();
G
gongweibao 已提交
340 341
}

X
Xin Pan 已提交
342
void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int req_id) {
343 344
  std::unique_lock<std::mutex> lock(cq_mutex_);
  if (is_shut_down_) {
345
    VLOG(3) << "shutdown, do not TryToRegisterNewPrefetchOne";
346 347
    return;
  }
X
Xin Pan 已提交
348 349
  RequestPrefetch* prefetch = new RequestPrefetch(
      &service_, cq_prefetch_.get(), sync_mode_, scope_, dev_ctx_, executor_,
X
Xin Pan 已提交
350
      program_, prefetch_ctx_.get(), req_id);
X
Xin Pan 已提交
351
  prefetch_reqs_[req_id] = static_cast<RequestBase*>(prefetch);
352 353 354 355

  VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
}

Y
Yancey 已提交
356
// FIXME(typhoonzero): change cq_name to enum.
X
Xin Pan 已提交
357 358 359
// Event loop run by each handler thread.
//
// cq:                  completion queue to poll; the loop ends when Next()
//                      returns false.
// cq_name:             "cq_send", "cq_get" or "cq_prefetch"; selects the
//                      request table and, in sync mode, the barrier to wait
//                      on.
// TryToRegisterNewOne: callback that registers a fresh request for a slot.
//
// Each event's tag is the slot index of the request it belongs to. A request
// in PROCESS state is processed. A request in FINISH state, or one whose
// event reported !ok, is replaced by a new registration and deleted.
void AsyncGRPCServer::HandleRequest(
    ::grpc::ServerCompletionQueue* cq, const std::string& cq_name,
    std::function<void(int)> TryToRegisterNewOne) {
  void* tag = nullptr;
  bool ok = false;

  while (true) {
    VLOG(3) << "HandleRequest for " << cq_name << " wait Next";
    if (!cq->Next(&tag, &ok)) {
      LOG(INFO) << cq_name << " CompletionQueue shutdown!";
      break;
    }
    VLOG(3) << "HandleRequest for " << cq_name << " get Next";
    const int req_id = static_cast<int>(reinterpret_cast<intptr_t>(tag));

    if (sync_mode_) {
      // FIXME(typhoonzero): de-couple the barriers with recv_op
      if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1);
      if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
      VLOG(3) << "HandleRequest for " << cq_name << " after WaitCond";
    }

    // Look up the request object for this slot under the same mutex the
    // registration functions hold when they write the tables.
    RequestBase* base = nullptr;
    {
      std::lock_guard<std::mutex> l(cq_mutex_);
      if (cq_name == "cq_get") {
        base = get_reqs_[req_id];
      } else if (cq_name == "cq_send") {
        base = send_reqs_[req_id];
      } else if (cq_name == "cq_prefetch") {
        base = prefetch_reqs_[req_id];
      }
    }
    if (base == nullptr) {
      // Unknown queue name or empty slot: there is nothing to process or
      // delete, so skip the event instead of dereferencing a null pointer.
      LOG(ERROR) << cq_name << " has no request object for tag " << req_id;
      continue;
    }

    // reference:
    // https://github.com/tensorflow/tensorflow/issues/5596
    // https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
    // https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
    if (!ok) {
      LOG(WARNING) << cq_name << " recv no regular event:argument name["
                   << base->GetReqName() << "]";
      TryToRegisterNewOne(req_id);
      delete base;
      continue;
    }

    switch (base->Status()) {
      case PROCESS: {
        // Log before Process(): Process() posts the FINISH tag, after which
        // another handler thread may delete `base`, so it must not be touched
        // once Process() has been called.
        VLOG(4) << cq_name << " PROCESS status:" << base->Status();
        base->Process();
        break;
      }
      case FINISH: {
        TryToRegisterNewOne(req_id);
        VLOG(4) << cq_name << " FINISH status:" << base->Status();
        delete base;
        break;
      }
      default: { assert(false); }
    }
  }
}

T
typhoonzero 已提交
419 420 421 422
// Blocks the caller until barrier_cond_step_ equals `cond`. SetCond() changes
// the value and wakes the waiters.
void AsyncGRPCServer::WaitCond(int cond) {
  std::unique_lock<std::mutex> lock(this->barrier_mutex_);
  barrier_condition_.wait(
      lock, [this, cond] { return this->barrier_cond_step_ == cond; });
}

T
typhoonzero 已提交
425
void AsyncGRPCServer::SetCond(int cond) {
G
gongweibao 已提交
426
  {
T
typhoonzero 已提交
427 428
    std::lock_guard<std::mutex> lock(this->barrier_mutex_);
    barrier_cond_step_ = cond;
G
gongweibao 已提交
429
  }
T
typhoonzero 已提交
430
  barrier_condition_.notify_all();
G
gongweibao 已提交
431 432 433 434 435
}

}  // namespace detail
}  // namespace operators
}  // namespace paddle