grpc_server.cc 14.5 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
G
gongweibao 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/detail/grpc_server.h"
16 17 18

#include <limits>
#include <string>
G
gongweibao 已提交
19

20
using ::grpc::ServerAsyncResponseWriter;
G
gongweibao 已提交
21

X
Xin Pan 已提交
22 23 24 25 26 27 28
// Number of HandleRequest threads that RunSyncUpdate() starts on the send
// completion queue ("cq_send").
DEFINE_int32(rpc_server_handle_send_threads, 20,
             "Number of threads used to handle send at rpc server.");
// Number of HandleRequest threads that RunSyncUpdate() starts on the get
// completion queue ("cq_get").
DEFINE_int32(rpc_server_handle_get_threads, 20,
             "Number of threads used to handle get at rpc server.");
// Number of HandleRequest threads that RunSyncUpdate() starts on the prefetch
// completion queue ("cq_prefetch").
DEFINE_int32(rpc_server_handle_prefetch_threads, 1,
             "Number of threads used to handle prefetch at rpc server.");

G
gongweibao 已提交
29 30 31 32 33 34 35 36 37
namespace paddle {
namespace operators {
namespace detail {
// Life-cycle state of a request object. A request is created in PROCESS; its
// Process() implementation switches it to FINISH right before replying.
// HandleRequest() dispatches on this value: PROCESS -> call Process(),
// FINISH -> register a replacement request and delete this one.
enum CallStatus { PROCESS = 0, FINISH };

// reference:
// https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server
class RequestBase {
 public:
38
  explicit RequestBase(GrpcService::AsyncService* service,
Q
qiaolongfei 已提交
39
                       ::grpc::ServerCompletionQueue* cq, bool sync_mode,
40
                       const platform::DeviceContext* dev_ctx)
Q
qiaolongfei 已提交
41 42 43 44 45
      : service_(service),
        cq_(cq),
        sync_mode_(sync_mode),
        status_(PROCESS),
        dev_ctx_(dev_ctx) {
G
gongweibao 已提交
46 47
    PADDLE_ENFORCE(cq_);
  }
G
gongweibao 已提交
48 49 50 51 52
  virtual ~RequestBase() {}
  virtual void Process() { assert(false); }

  CallStatus Status() { return status_; }
  void SetStatus(CallStatus status) { status_ = status; }
T
typhoonzero 已提交
53 54 55 56
  virtual std::string GetReqName() {
    assert(false);
    return "";
  }
G
gongweibao 已提交
57 58

 protected:
59 60 61
  ::grpc::ServerContext ctx_;
  GrpcService::AsyncService* service_;
  ::grpc::ServerCompletionQueue* cq_;
Q
qiaolongfei 已提交
62
  const bool sync_mode_;
G
gongweibao 已提交
63
  CallStatus status_;
64
  const platform::DeviceContext* dev_ctx_;
G
gongweibao 已提交
65 66 67 68
};

class RequestSend final : public RequestBase {
 public:
69
  explicit RequestSend(GrpcService::AsyncService* service,
Q
qiaolongfei 已提交
70
                       ::grpc::ServerCompletionQueue* cq, bool sync_mode,
71
                       framework::Scope* scope, ReceivedQueue* queue,
X
Xin Pan 已提交
72
                       const platform::DeviceContext* dev_ctx, int req_id)
Q
qiaolongfei 已提交
73 74
      : RequestBase(service, cq, sync_mode, dev_ctx),
        queue_(queue),
X
Xin Pan 已提交
75
        responder_(&ctx_),
X
Xin Pan 已提交
76
        req_id_(req_id) {
Q
qiaolongfei 已提交
77
    if (sync_mode_) {
78
      request_.reset(new VariableResponse(scope, dev_ctx_, false));
Q
qiaolongfei 已提交
79
    } else {
80
      request_.reset(new VariableResponse(scope, dev_ctx_, true));
Q
qiaolongfei 已提交
81
    }
82
    int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
X
Xin Pan 已提交
83 84
    service_->RequestAsyncUnary(
        method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
X
Xin Pan 已提交
85
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
G
gongweibao 已提交
86 87 88 89
  }

  virtual ~RequestSend() {}

90
  virtual std::string GetReqName() { return request_->Varname(); }
G
gongweibao 已提交
91

G
gongweibao 已提交
92
  virtual void Process() {
Q
qiaolongfei 已提交
93 94 95
    std::string var_name = GetReqName();
    VLOG(3) << "RequestSend " << var_name;
    queue_->Push(std::make_pair(var_name, request_));
96

G
gongweibao 已提交
97
    status_ = FINISH;
X
Xin Pan 已提交
98
    responder_.Finish(reply_, ::grpc::Status::OK,
X
Xin Pan 已提交
99
                      reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
G
gongweibao 已提交
100 101 102
  }

 protected:
X
Xin Pan 已提交
103
  sendrecv::VoidMessage reply_;
104 105
  std::shared_ptr<VariableResponse> request_;
  ReceivedQueue* queue_;
G
gongweibao 已提交
106
  ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
X
Xin Pan 已提交
107
  int req_id_;
G
gongweibao 已提交
108 109 110 111
};

class RequestGet final : public RequestBase {
 public:
112
  explicit RequestGet(GrpcService::AsyncService* service,
Q
qiaolongfei 已提交
113
                      ::grpc::ServerCompletionQueue* cq, bool sync_mode,
114
                      framework::Scope* scope,
T
typhoonzero 已提交
115
                      const platform::DeviceContext* dev_ctx,
X
Xin Pan 已提交
116 117
                      framework::BlockingQueue<MessageWithName>* queue,
                      int req_id)
Q
qiaolongfei 已提交
118
      : RequestBase(service, cq, sync_mode, dev_ctx),
Y
Yancey1989 已提交
119 120
        responder_(&ctx_),
        scope_(scope),
X
Xin Pan 已提交
121
        queue_(queue),
X
Xin Pan 已提交
122
        req_id_(req_id) {
Q
qiaolongfei 已提交
123
    auto method_id = static_cast<int>(detail::GrpcMethod::kGetVariable);
X
Xin Pan 已提交
124 125
    service_->RequestAsyncUnary(
        method_id, &ctx_, &request_, &responder_, cq_, cq_,
X
Xin Pan 已提交
126
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
G
gongweibao 已提交
127 128 129 130
  }

  virtual ~RequestGet() {}

G
gongweibao 已提交
131 132
  virtual std::string GetReqName() { return request_.varname(); }

G
gongweibao 已提交
133 134 135
  virtual void Process() {
    // proc request.
    std::string var_name = request_.varname();
Q
qiaolongfei 已提交
136
    VLOG(3) << "RequestGet " << var_name;
G
gongweibao 已提交
137
    auto* var = scope_->FindVar(var_name);
138

139
    if (var_name != FETCH_BARRIER_MESSAGE) {
X
Xin Pan 已提交
140
      SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_);
141
    }
142

G
gongweibao 已提交
143
    status_ = FINISH;
X
Xin Pan 已提交
144
    responder_.Finish(reply_, ::grpc::Status::OK,
X
Xin Pan 已提交
145
                      reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
146 147 148 149 150 151

    if (var_name == FETCH_BARRIER_MESSAGE) {
      sendrecv::VariableMessage msg;
      MessageWithName msg_with_name = std::make_pair(var_name, msg);
      queue_->Push(msg_with_name);
    }
G
gongweibao 已提交
152 153 154 155
  }

 protected:
  sendrecv::VariableMessage request_;
X
Xin Pan 已提交
156
  ::grpc::ByteBuffer reply_;
157
  ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
G
gongweibao 已提交
158
  framework::Scope* scope_;
T
typhoonzero 已提交
159
  framework::BlockingQueue<MessageWithName>* queue_;
X
Xin Pan 已提交
160
  int req_id_;
G
gongweibao 已提交
161 162
};

163 164 165
class RequestPrefetch final : public RequestBase {
 public:
  explicit RequestPrefetch(GrpcService::AsyncService* service,
Q
qiaolongfei 已提交
166
                           ::grpc::ServerCompletionQueue* cq, bool sync_mode,
167 168 169
                           framework::Scope* scope,
                           const platform::DeviceContext* dev_ctx,
                           framework::Executor* executor,
Y
Yancey1989 已提交
170
                           framework::ProgramDesc* program,
X
Xin Pan 已提交
171
                           framework::ExecutorPrepareContext* prefetch_ctx,
X
Xin Pan 已提交
172
                           int req_id)
Q
qiaolongfei 已提交
173
      : RequestBase(service, cq, sync_mode, dev_ctx),
174 175 176 177
        responder_(&ctx_),
        scope_(scope),
        executor_(executor),
        program_(program),
X
Xin Pan 已提交
178
        prefetch_ctx_(prefetch_ctx),
X
Xin Pan 已提交
179
        req_id_(req_id) {
Q
qiaolongfei 已提交
180
    if (sync_mode_) {
181
      request_.reset(new VariableResponse(scope, dev_ctx_, false));
Q
qiaolongfei 已提交
182
    } else {
183
      request_.reset(new VariableResponse(scope, dev_ctx_, true));
Q
qiaolongfei 已提交
184
    }
185
    int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
X
Xin Pan 已提交
186 187 188
    service_->RequestAsyncUnary(
        method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
189 190 191 192
  }

  virtual ~RequestPrefetch() {}

Y
Yancey1989 已提交
193
  virtual std::string GetReqName() { return request_->Varname(); }
194 195 196 197

  virtual void Process() {
    // prefetch process...

Y
Yancey1989 已提交
198
    std::string var_name = request_->OutVarname();
Q
qiaolongfei 已提交
199
    VLOG(3) << "RequestPrefetch " << var_name;
Y
Yancey1989 已提交
200 201 202 203
    auto var_desc = program_->Block(0).FindVar(var_name);
    framework::Scope* local_scope = &scope_->NewScope();
    auto* var = local_scope->FindVar(var_name);
    InitializeVariable(var, var_desc->GetType());
W
Wu Yi 已提交
204
    executor_->RunPreparedContext(prefetch_ctx_, scope_);
Y
Yancey1989 已提交
205

X
Xin Pan 已提交
206
    SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_);
Q
qiaolongfei 已提交
207

208
    status_ = FINISH;
X
Xin Pan 已提交
209 210
    responder_.Finish(reply_, ::grpc::Status::OK,
                      reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
211 212 213
  }

 protected:
Y
Yancey1989 已提交
214
  std::shared_ptr<VariableResponse> request_;
X
Xin Pan 已提交
215
  ::grpc::ByteBuffer reply_;
216 217 218 219
  ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
  framework::Scope* scope_;
  framework::Executor* executor_;
  framework::ProgramDesc* program_;
Y
Yancey1989 已提交
220
  framework::ExecutorPrepareContext* prefetch_ctx_;
X
Xin Pan 已提交
221
  int req_id_;
222 223
};

T
typhoonzero 已提交
224
// Blocks until `count` fetch-barrier messages have been popped from the get
// queue. Popped messages with any other name are discarded.
void AsyncGRPCServer::WaitClientGet(int count) {
  int remaining = count;
  while (remaining > 0) {
    const auto popped = var_get_queue_.Pop();
    if (popped.first == FETCH_BARRIER_MESSAGE) {
      --remaining;
    }
  }
}

T
done  
typhoonzero 已提交
234
void AsyncGRPCServer::WaitServerReady() {
T
update  
typhoonzero 已提交
235
  std::unique_lock<std::mutex> lock(this->mutex_ready_);
T
done  
typhoonzero 已提交
236
  condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
T
update  
typhoonzero 已提交
237 238
}

G
gongweibao 已提交
239
void AsyncGRPCServer::RunSyncUpdate() {
240
  ::grpc::ServerBuilder builder;
T
typhoonzero 已提交
241 242
  builder.AddListeningPort(address_, ::grpc::InsecureServerCredentials(),
                           &selected_port_);
G
gongweibao 已提交
243 244
  builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
G
gongweibao 已提交
245 246 247 248
  builder.RegisterService(&service_);

  cq_send_ = builder.AddCompletionQueue();
  cq_get_ = builder.AddCompletionQueue();
249
  cq_prefetch_ = builder.AddCompletionQueue();
Y
Yancey 已提交
250

G
gongweibao 已提交
251
  server_ = builder.BuildAndStart();
T
typhoonzero 已提交
252 253
  LOG(INFO) << "Server listening on " << address_
            << " selected port: " << selected_port_;
G
gongweibao 已提交
254

X
Xin Pan 已提交
255 256 257 258 259 260 261 262 263 264 265 266 267 268
  std::function<void(int)> send_register = std::bind(
      &AsyncGRPCServer::TryToRegisterNewSendOne, this, std::placeholders::_1);
  std::function<void(int)> get_register = std::bind(
      &AsyncGRPCServer::TryToRegisterNewGetOne, this, std::placeholders::_1);
  std::function<void(int)> prefetch_register =
      std::bind(&AsyncGRPCServer::TryToRegisterNewPrefetchOne, this,
                std::placeholders::_1);

  for (int i = 0; i < kSendReqsBufSize; ++i) {
    TryToRegisterNewSendOne(i);
  }
  for (int i = 0; i < kGetReqsBufSize; ++i) {
    TryToRegisterNewGetOne(i);
  }
X
Xin Pan 已提交
269 270 271
  for (int i = 0; i < kPrefetchReqsBufSize; ++i) {
    TryToRegisterNewPrefetchOne(i);
  }
X
Xin Pan 已提交
272

X
Xin Pan 已提交
273
  for (int i = 0; i < FLAGS_rpc_server_handle_send_threads; ++i) {
X
Xin Pan 已提交
274 275 276 277
    t_sends_.emplace_back(
        new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
                                  cq_send_.get(), "cq_send", send_register)));
  }
X
Xin Pan 已提交
278
  for (int i = 0; i < FLAGS_rpc_server_handle_get_threads; ++i) {
X
Xin Pan 已提交
279 280 281 282
    t_gets_.emplace_back(
        new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
                                  cq_get_.get(), "cq_get", get_register)));
  }
X
Xin Pan 已提交
283
  for (int i = 0; i < FLAGS_rpc_server_handle_prefetch_threads; ++i) {
X
Xin Pan 已提交
284 285 286 287
    t_prefetchs_.emplace_back(new std::thread(
        std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(),
                  "cq_prefetch", prefetch_register)));
  }
T
wip  
typhoonzero 已提交
288 289 290 291 292
  {
    std::lock_guard<std::mutex> lock(this->mutex_ready_);
    ready_ = 1;
  }
  condition_ready_.notify_all();
G
gongweibao 已提交
293 294
  // wait server
  server_->Wait();
X
Xin Pan 已提交
295
  for (int i = 0; i < FLAGS_rpc_server_handle_send_threads; ++i) {
X
Xin Pan 已提交
296 297
    t_sends_[i]->join();
  }
X
Xin Pan 已提交
298
  for (int i = 0; i < FLAGS_rpc_server_handle_get_threads; ++i) {
X
Xin Pan 已提交
299 300
    t_gets_[i]->join();
  }
X
Xin Pan 已提交
301
  for (int i = 0; i < FLAGS_rpc_server_handle_prefetch_threads; ++i) {
X
Xin Pan 已提交
302 303
    t_prefetchs_[i]->join();
  }
G
gongweibao 已提交
304 305 306 307 308 309
}

void AsyncGRPCServer::ShutdownQueue() {
  std::unique_lock<std::mutex> lock(cq_mutex_);
  cq_send_->Shutdown();
  cq_get_->Shutdown();
310
  cq_prefetch_->Shutdown();
G
gongweibao 已提交
311 312 313 314
}

// This URL explains why shutdown is complicated: (TODO: the URL is missing)
void AsyncGRPCServer::ShutDown() {
T
typhoonzero 已提交
315
  is_shut_down_ = true;
G
gongweibao 已提交
316
  ShutdownQueue();
T
typhoonzero 已提交
317
  server_->Shutdown();
G
gongweibao 已提交
318 319
}

X
Xin Pan 已提交
320
void AsyncGRPCServer::TryToRegisterNewSendOne(int i) {
G
gongweibao 已提交
321 322
  std::unique_lock<std::mutex> lock(cq_mutex_);
  if (is_shut_down_) {
323
    VLOG(3) << "shutdown, do not TryToRegisterNewSendOne";
G
gongweibao 已提交
324 325
    return;
  }
Q
qiaolongfei 已提交
326
  RequestSend* send = new RequestSend(&service_, cq_send_.get(), sync_mode_,
X
Xin Pan 已提交
327 328
                                      scope_, &var_recv_queue_, dev_ctx_, i);
  send_reqs_[i] = static_cast<RequestBase*>(send);
Y
Yancey 已提交
329
  VLOG(4) << "Create RequestSend status:" << send->Status();
G
gongweibao 已提交
330 331
}

X
Xin Pan 已提交
332
void AsyncGRPCServer::TryToRegisterNewGetOne(int req_id) {
G
gongweibao 已提交
333 334
  std::unique_lock<std::mutex> lock(cq_mutex_);
  if (is_shut_down_) {
335
    VLOG(3) << "shutdown, do not TryToRegisterNewGetOne";
G
gongweibao 已提交
336 337
    return;
  }
Q
qiaolongfei 已提交
338
  RequestGet* get = new RequestGet(&service_, cq_get_.get(), sync_mode_, scope_,
X
Xin Pan 已提交
339 340
                                   dev_ctx_, &var_get_queue_, req_id);
  get_reqs_[req_id] = static_cast<RequestBase*>(get);
Y
Yancey 已提交
341
  VLOG(4) << "Create RequestGet status:" << get->Status();
G
gongweibao 已提交
342 343
}

X
Xin Pan 已提交
344
void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int req_id) {
345 346
  std::unique_lock<std::mutex> lock(cq_mutex_);
  if (is_shut_down_) {
347
    VLOG(3) << "shutdown, do not TryToRegisterNewPrefetchOne";
348 349
    return;
  }
X
Xin Pan 已提交
350 351
  RequestPrefetch* prefetch = new RequestPrefetch(
      &service_, cq_prefetch_.get(), sync_mode_, scope_, dev_ctx_, executor_,
X
Xin Pan 已提交
352
      program_, prefetch_ctx_.get(), req_id);
X
Xin Pan 已提交
353
  prefetch_reqs_[req_id] = static_cast<RequestBase*>(prefetch);
354 355 356 357

  VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
}

Y
Yancey 已提交
358
// FIXME(typhoonzero): change cq_name to enum.
X
Xin Pan 已提交
359 360 361
// Event loop run by every handler thread. Repeatedly pulls the next event
// from `cq`, maps its tag back to the request slot it belongs to and drives
// that request's state machine until the completion queue is shut down.
//
// cq                   completion queue to poll.
// cq_name              "cq_send", "cq_get" or "cq_prefetch"; selects the slot
//                      table and, in sync mode, the barrier to wait for.
// TryToRegisterNewOne  callback that re-arms the slot with the given id.
void AsyncGRPCServer::HandleRequest(
    ::grpc::ServerCompletionQueue* cq, const std::string& cq_name,
    std::function<void(int)> TryToRegisterNewOne) {
  void* tag = nullptr;
  bool ok = false;

  while (true) {
    VLOG(3) << "HandleRequest for " << cq_name << " wait Next";
    if (!cq->Next(&tag, &ok)) {
      LOG(INFO) << cq_name << " CompletionQueue shutdown!";
      break;
    }
    VLOG(3) << "HandleRequest for " << cq_name << " get Next";
    // The tag is the slot id that was cast to void* at registration time.
    const int req_id = static_cast<int>(reinterpret_cast<intptr_t>(tag));

    if (sync_mode_) {
      // FIXME(typhoonzero): de-couple the barriers with recv_op
      if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1);
      if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
      VLOG(3) << "HandleRequest for " << cq_name << " after WaitCond";
    }

    // Look up the request under cq_mutex_: the TryToRegisterNew*One()
    // functions write the slot tables under the same mutex.
    RequestBase* base = nullptr;
    {
      std::lock_guard<std::mutex> l(cq_mutex_);
      if (cq_name == "cq_get") {
        base = get_reqs_[req_id];
      } else if (cq_name == "cq_send") {
        base = send_reqs_[req_id];
      } else if (cq_name == "cq_prefetch") {
        base = prefetch_reqs_[req_id];
      }
    }
    // reference:
    // https://github.com/tensorflow/tensorflow/issues/5596
    // https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
    // https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
    if (!ok) {
      LOG(WARNING) << cq_name << " recv no regular event:argument name["
                   << base->GetReqName() << "]";
      TryToRegisterNewOne(req_id);
      delete base;
      continue;
    }

    switch (base->Status()) {
      case PROCESS: {
        // Log BEFORE Process(): Process() issues the reply, after which
        // another handler thread on this queue may receive the FINISH event
        // and delete `base`. Touching `base` after Process() returns would
        // therefore be a use-after-free.
        VLOG(4) << cq_name << " PROCESS status:" << base->Status();
        base->Process();
        break;
      }
      case FINISH: {
        // This thread received the FINISH event, so it is the one that
        // re-arms the slot and destroys the completed request.
        TryToRegisterNewOne(req_id);
        VLOG(4) << cq_name << " FINISH status:" << base->Status();
        delete base;
        break;
      }
      default: { assert(false); }
    }
  }
}

T
typhoonzero 已提交
421 422 423 424
void AsyncGRPCServer::WaitCond(int cond) {
  std::unique_lock<std::mutex> lock(this->barrier_mutex_);
  barrier_condition_.wait(lock,
                          [=] { return this->barrier_cond_step_ == cond; });
G
gongweibao 已提交
425 426
}

T
typhoonzero 已提交
427
void AsyncGRPCServer::SetCond(int cond) {
G
gongweibao 已提交
428
  {
T
typhoonzero 已提交
429 430
    std::lock_guard<std::mutex> lock(this->barrier_mutex_);
    barrier_cond_step_ = cond;
G
gongweibao 已提交
431
  }
T
typhoonzero 已提交
432
  barrier_condition_.notify_all();
G
gongweibao 已提交
433 434 435 436 437
}

}  // namespace detail
}  // namespace operators
}  // namespace paddle