grpc_server.cc 14.4 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
G
gongweibao 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/detail/grpc_server.h"
16 17 18

#include <limits>
#include <string>
G
gongweibao 已提交
19

20
using ::grpc::ServerAsyncResponseWriter;
G
gongweibao 已提交
21

X
Xin Pan 已提交
22 23 24 25 26 27 28
// Handler thread-pool sizes for the RPC server. RunSyncUpdate() starts this
// many HandleRequest() threads on the send, get and prefetch completion
// queues respectively, and joins them after the server stops.
DEFINE_int32(rpc_server_handle_send_threads, 20,
             "Number of threads used to handle send at rpc server.");
DEFINE_int32(rpc_server_handle_get_threads, 20,
             "Number of threads used to handle get at rpc server.");
DEFINE_int32(rpc_server_handle_prefetch_threads, 1,
             "Number of threads used to handle prefetch at rpc server.");

G
gongweibao 已提交
29 30 31 32 33 34 35 36 37
namespace paddle {
namespace operators {
namespace detail {
// Lifecycle of a request handler object. A newly registered handler starts in
// PROCESS. Its Process() switches it to FINISH before calling the responder's
// Finish(). When HandleRequest() later sees a completion event for a handler
// in FINISH, it registers a replacement for that slot and deletes the handler.
enum CallStatus { PROCESS = 0, FINISH };

// reference:
// https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server
class RequestBase {
 public:
38
  explicit RequestBase(GrpcService::AsyncService* service,
Q
qiaolongfei 已提交
39
                       ::grpc::ServerCompletionQueue* cq, bool sync_mode,
40
                       const platform::DeviceContext* dev_ctx)
Q
qiaolongfei 已提交
41 42 43 44 45
      : service_(service),
        cq_(cq),
        sync_mode_(sync_mode),
        status_(PROCESS),
        dev_ctx_(dev_ctx) {
G
gongweibao 已提交
46 47
    PADDLE_ENFORCE(cq_);
  }
G
gongweibao 已提交
48 49 50 51 52
  virtual ~RequestBase() {}
  virtual void Process() { assert(false); }

  CallStatus Status() { return status_; }
  void SetStatus(CallStatus status) { status_ = status; }
T
typhoonzero 已提交
53 54 55 56
  virtual std::string GetReqName() {
    assert(false);
    return "";
  }
G
gongweibao 已提交
57 58

 protected:
59 60 61
  ::grpc::ServerContext ctx_;
  GrpcService::AsyncService* service_;
  ::grpc::ServerCompletionQueue* cq_;
Q
qiaolongfei 已提交
62
  const bool sync_mode_;
G
gongweibao 已提交
63
  CallStatus status_;
64
  const platform::DeviceContext* dev_ctx_;
G
gongweibao 已提交
65 66 67 68
};

class RequestSend final : public RequestBase {
 public:
69
  explicit RequestSend(GrpcService::AsyncService* service,
Q
qiaolongfei 已提交
70
                       ::grpc::ServerCompletionQueue* cq, bool sync_mode,
71
                       framework::Scope* scope, ReceivedQueue* queue,
X
Xin Pan 已提交
72
                       const platform::DeviceContext* dev_ctx, int req_id)
Q
qiaolongfei 已提交
73 74
      : RequestBase(service, cq, sync_mode, dev_ctx),
        queue_(queue),
X
Xin Pan 已提交
75
        responder_(&ctx_),
X
Xin Pan 已提交
76
        req_id_(req_id) {
Q
qiaolongfei 已提交
77
    if (sync_mode_) {
78
      request_.reset(new VariableResponse(scope, dev_ctx_, false));
Q
qiaolongfei 已提交
79
    } else {
80
      request_.reset(new VariableResponse(scope, dev_ctx_, true));
Q
qiaolongfei 已提交
81
    }
82
    int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
X
Xin Pan 已提交
83 84
    service_->RequestAsyncUnary(
        method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
X
Xin Pan 已提交
85
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
G
gongweibao 已提交
86 87 88 89
  }

  virtual ~RequestSend() {}

90
  virtual std::string GetReqName() { return request_->Varname(); }
G
gongweibao 已提交
91

G
gongweibao 已提交
92
  virtual void Process() {
Q
qiaolongfei 已提交
93 94 95
    std::string var_name = GetReqName();
    VLOG(3) << "RequestSend " << var_name;
    queue_->Push(std::make_pair(var_name, request_));
96

G
gongweibao 已提交
97
    status_ = FINISH;
X
Xin Pan 已提交
98
    responder_.Finish(reply_, ::grpc::Status::OK,
X
Xin Pan 已提交
99
                      reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
G
gongweibao 已提交
100 101 102
  }

 protected:
X
Xin Pan 已提交
103
  sendrecv::VoidMessage reply_;
104 105
  std::shared_ptr<VariableResponse> request_;
  ReceivedQueue* queue_;
G
gongweibao 已提交
106
  ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
X
Xin Pan 已提交
107
  int req_id_;
G
gongweibao 已提交
108 109 110 111
};

class RequestGet final : public RequestBase {
 public:
112
  explicit RequestGet(GrpcService::AsyncService* service,
Q
qiaolongfei 已提交
113
                      ::grpc::ServerCompletionQueue* cq, bool sync_mode,
114
                      framework::Scope* scope,
T
typhoonzero 已提交
115
                      const platform::DeviceContext* dev_ctx,
X
Xin Pan 已提交
116 117
                      framework::BlockingQueue<MessageWithName>* queue,
                      int req_id)
Q
qiaolongfei 已提交
118
      : RequestBase(service, cq, sync_mode, dev_ctx),
Y
Yancey1989 已提交
119 120
        responder_(&ctx_),
        scope_(scope),
X
Xin Pan 已提交
121
        queue_(queue),
X
Xin Pan 已提交
122
        req_id_(req_id) {
Q
qiaolongfei 已提交
123
    auto method_id = static_cast<int>(detail::GrpcMethod::kGetVariable);
X
Xin Pan 已提交
124 125
    service_->RequestAsyncUnary(
        method_id, &ctx_, &request_, &responder_, cq_, cq_,
X
Xin Pan 已提交
126
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
G
gongweibao 已提交
127 128 129 130
  }

  virtual ~RequestGet() {}

G
gongweibao 已提交
131 132
  virtual std::string GetReqName() { return request_.varname(); }

G
gongweibao 已提交
133 134 135
  virtual void Process() {
    // proc request.
    std::string var_name = request_.varname();
Q
qiaolongfei 已提交
136
    VLOG(3) << "RequestGet " << var_name;
G
gongweibao 已提交
137
    auto* var = scope_->FindVar(var_name);
138

139
    if (var_name != FETCH_BARRIER_MESSAGE) {
X
Xin Pan 已提交
140
      SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_);
141
    }
142

G
gongweibao 已提交
143
    status_ = FINISH;
X
Xin Pan 已提交
144
    responder_.Finish(reply_, ::grpc::Status::OK,
X
Xin Pan 已提交
145
                      reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
146 147 148 149 150 151

    if (var_name == FETCH_BARRIER_MESSAGE) {
      sendrecv::VariableMessage msg;
      MessageWithName msg_with_name = std::make_pair(var_name, msg);
      queue_->Push(msg_with_name);
    }
G
gongweibao 已提交
152 153 154 155
  }

 protected:
  sendrecv::VariableMessage request_;
X
Xin Pan 已提交
156
  ::grpc::ByteBuffer reply_;
157
  ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
G
gongweibao 已提交
158
  framework::Scope* scope_;
T
typhoonzero 已提交
159
  framework::BlockingQueue<MessageWithName>* queue_;
X
Xin Pan 已提交
160
  int req_id_;
G
gongweibao 已提交
161 162
};

163 164 165
class RequestPrefetch final : public RequestBase {
 public:
  explicit RequestPrefetch(GrpcService::AsyncService* service,
Q
qiaolongfei 已提交
166
                           ::grpc::ServerCompletionQueue* cq, bool sync_mode,
167 168 169
                           framework::Scope* scope,
                           const platform::DeviceContext* dev_ctx,
                           framework::Executor* executor,
Y
Yancey1989 已提交
170
                           framework::ProgramDesc* program,
X
Xin Pan 已提交
171
                           framework::ExecutorPrepareContext* prefetch_ctx,
X
Xin Pan 已提交
172
                           int req_id)
Q
qiaolongfei 已提交
173
      : RequestBase(service, cq, sync_mode, dev_ctx),
174 175 176 177
        responder_(&ctx_),
        scope_(scope),
        executor_(executor),
        program_(program),
X
Xin Pan 已提交
178
        prefetch_ctx_(prefetch_ctx),
X
Xin Pan 已提交
179
        req_id_(req_id) {
Q
qiaolongfei 已提交
180 181
    // prefetch always create a new sub scope
    request_.reset(new VariableResponse(scope, dev_ctx_, true));
182
    int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
X
Xin Pan 已提交
183 184 185
    service_->RequestAsyncUnary(
        method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
        reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
186 187 188 189
  }

  virtual ~RequestPrefetch() {}

Y
Yancey1989 已提交
190
  virtual std::string GetReqName() { return request_->Varname(); }
191 192 193 194

  virtual void Process() {
    // prefetch process...

Y
Yancey1989 已提交
195
    std::string var_name = request_->OutVarname();
Q
qiaolongfei 已提交
196
    VLOG(3) << "RequestPrefetch " << var_name;
Y
Yancey1989 已提交
197
    auto var_desc = program_->Block(0).FindVar(var_name);
Q
qiaolongfei 已提交
198
    framework::Scope* local_scope = request_->GetMutableLocalScope();
Y
Yancey1989 已提交
199 200
    auto* var = local_scope->FindVar(var_name);
    InitializeVariable(var, var_desc->GetType());
Q
qiaolongfei 已提交
201
    executor_->RunPreparedContext(prefetch_ctx_, local_scope);
Y
Yancey1989 已提交
202

X
Xin Pan 已提交
203
    SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_);
Q
qiaolongfei 已提交
204

205
    status_ = FINISH;
X
Xin Pan 已提交
206 207
    responder_.Finish(reply_, ::grpc::Status::OK,
                      reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
208 209 210
  }

 protected:
Y
Yancey1989 已提交
211
  std::shared_ptr<VariableResponse> request_;
X
Xin Pan 已提交
212
  ::grpc::ByteBuffer reply_;
213 214 215 216
  ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
  framework::Scope* scope_;
  framework::Executor* executor_;
  framework::ProgramDesc* program_;
Y
Yancey1989 已提交
217
  framework::ExecutorPrepareContext* prefetch_ctx_;
X
Xin Pan 已提交
218
  int req_id_;
219 220
};

T
typhoonzero 已提交
221
// Blocks until `count` fetch-barrier messages have been popped from
// var_get_queue_. Any other message popped while waiting is dropped.
void AsyncGRPCServer::WaitClientGet(int count) {
  for (int barriers_seen = 0; barriers_seen < count;) {
    const auto message = var_get_queue_.Pop();
    if (message.first == FETCH_BARRIER_MESSAGE) ++barriers_seen;
  }
}

T
done  
typhoonzero 已提交
231
void AsyncGRPCServer::WaitServerReady() {
T
update  
typhoonzero 已提交
232
  std::unique_lock<std::mutex> lock(this->mutex_ready_);
T
done  
typhoonzero 已提交
233
  condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
T
update  
typhoonzero 已提交
234 235
}

G
gongweibao 已提交
236
void AsyncGRPCServer::RunSyncUpdate() {
237
  ::grpc::ServerBuilder builder;
T
typhoonzero 已提交
238 239
  builder.AddListeningPort(address_, ::grpc::InsecureServerCredentials(),
                           &selected_port_);
G
gongweibao 已提交
240 241
  builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
G
gongweibao 已提交
242 243 244 245
  builder.RegisterService(&service_);

  cq_send_ = builder.AddCompletionQueue();
  cq_get_ = builder.AddCompletionQueue();
246
  cq_prefetch_ = builder.AddCompletionQueue();
Y
Yancey 已提交
247

G
gongweibao 已提交
248
  server_ = builder.BuildAndStart();
T
typhoonzero 已提交
249 250
  LOG(INFO) << "Server listening on " << address_
            << " selected port: " << selected_port_;
G
gongweibao 已提交
251

X
Xin Pan 已提交
252 253 254 255 256 257 258 259 260 261 262 263 264 265
  std::function<void(int)> send_register = std::bind(
      &AsyncGRPCServer::TryToRegisterNewSendOne, this, std::placeholders::_1);
  std::function<void(int)> get_register = std::bind(
      &AsyncGRPCServer::TryToRegisterNewGetOne, this, std::placeholders::_1);
  std::function<void(int)> prefetch_register =
      std::bind(&AsyncGRPCServer::TryToRegisterNewPrefetchOne, this,
                std::placeholders::_1);

  for (int i = 0; i < kSendReqsBufSize; ++i) {
    TryToRegisterNewSendOne(i);
  }
  for (int i = 0; i < kGetReqsBufSize; ++i) {
    TryToRegisterNewGetOne(i);
  }
X
Xin Pan 已提交
266 267 268
  for (int i = 0; i < kPrefetchReqsBufSize; ++i) {
    TryToRegisterNewPrefetchOne(i);
  }
X
Xin Pan 已提交
269

X
Xin Pan 已提交
270
  for (int i = 0; i < FLAGS_rpc_server_handle_send_threads; ++i) {
X
Xin Pan 已提交
271 272 273 274
    t_sends_.emplace_back(
        new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
                                  cq_send_.get(), "cq_send", send_register)));
  }
X
Xin Pan 已提交
275
  for (int i = 0; i < FLAGS_rpc_server_handle_get_threads; ++i) {
X
Xin Pan 已提交
276 277 278 279
    t_gets_.emplace_back(
        new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
                                  cq_get_.get(), "cq_get", get_register)));
  }
X
Xin Pan 已提交
280
  for (int i = 0; i < FLAGS_rpc_server_handle_prefetch_threads; ++i) {
X
Xin Pan 已提交
281 282 283 284
    t_prefetchs_.emplace_back(new std::thread(
        std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(),
                  "cq_prefetch", prefetch_register)));
  }
T
wip  
typhoonzero 已提交
285 286 287 288 289
  {
    std::lock_guard<std::mutex> lock(this->mutex_ready_);
    ready_ = 1;
  }
  condition_ready_.notify_all();
G
gongweibao 已提交
290 291
  // wait server
  server_->Wait();
X
Xin Pan 已提交
292
  for (int i = 0; i < FLAGS_rpc_server_handle_send_threads; ++i) {
X
Xin Pan 已提交
293 294
    t_sends_[i]->join();
  }
X
Xin Pan 已提交
295
  for (int i = 0; i < FLAGS_rpc_server_handle_get_threads; ++i) {
X
Xin Pan 已提交
296 297
    t_gets_[i]->join();
  }
X
Xin Pan 已提交
298
  for (int i = 0; i < FLAGS_rpc_server_handle_prefetch_threads; ++i) {
X
Xin Pan 已提交
299 300
    t_prefetchs_[i]->join();
  }
G
gongweibao 已提交
301 302 303 304 305 306
}

void AsyncGRPCServer::ShutdownQueue() {
  std::unique_lock<std::mutex> lock(cq_mutex_);
  cq_send_->Shutdown();
  cq_get_->Shutdown();
307
  cq_prefetch_->Shutdown();
G
gongweibao 已提交
308 309 310 311
}

// This URL explains why shutdown is complicate:
void AsyncGRPCServer::ShutDown() {
T
typhoonzero 已提交
312
  is_shut_down_ = true;
G
gongweibao 已提交
313
  ShutdownQueue();
T
typhoonzero 已提交
314
  server_->Shutdown();
G
gongweibao 已提交
315 316
}

X
Xin Pan 已提交
317
void AsyncGRPCServer::TryToRegisterNewSendOne(int i) {
G
gongweibao 已提交
318 319
  std::unique_lock<std::mutex> lock(cq_mutex_);
  if (is_shut_down_) {
320
    VLOG(3) << "shutdown, do not TryToRegisterNewSendOne";
G
gongweibao 已提交
321 322
    return;
  }
Q
qiaolongfei 已提交
323
  RequestSend* send = new RequestSend(&service_, cq_send_.get(), sync_mode_,
X
Xin Pan 已提交
324 325
                                      scope_, &var_recv_queue_, dev_ctx_, i);
  send_reqs_[i] = static_cast<RequestBase*>(send);
Y
Yancey 已提交
326
  VLOG(4) << "Create RequestSend status:" << send->Status();
G
gongweibao 已提交
327 328
}

X
Xin Pan 已提交
329
void AsyncGRPCServer::TryToRegisterNewGetOne(int req_id) {
G
gongweibao 已提交
330 331
  std::unique_lock<std::mutex> lock(cq_mutex_);
  if (is_shut_down_) {
332
    VLOG(3) << "shutdown, do not TryToRegisterNewGetOne";
G
gongweibao 已提交
333 334
    return;
  }
Q
qiaolongfei 已提交
335
  RequestGet* get = new RequestGet(&service_, cq_get_.get(), sync_mode_, scope_,
X
Xin Pan 已提交
336 337
                                   dev_ctx_, &var_get_queue_, req_id);
  get_reqs_[req_id] = static_cast<RequestBase*>(get);
Y
Yancey 已提交
338
  VLOG(4) << "Create RequestGet status:" << get->Status();
G
gongweibao 已提交
339 340
}

X
Xin Pan 已提交
341
void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int req_id) {
342 343
  std::unique_lock<std::mutex> lock(cq_mutex_);
  if (is_shut_down_) {
344
    VLOG(3) << "shutdown, do not TryToRegisterNewPrefetchOne";
345 346
    return;
  }
X
Xin Pan 已提交
347 348
  RequestPrefetch* prefetch = new RequestPrefetch(
      &service_, cq_prefetch_.get(), sync_mode_, scope_, dev_ctx_, executor_,
X
Xin Pan 已提交
349
      program_, prefetch_ctx_.get(), req_id);
X
Xin Pan 已提交
350
  prefetch_reqs_[req_id] = static_cast<RequestBase*>(prefetch);
351 352 353 354

  VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
}

Y
Yancey 已提交
355
// FIXME(typhoonzero): change cq_name to enum.
X
Xin Pan 已提交
356 357 358
void AsyncGRPCServer::HandleRequest(
    ::grpc::ServerCompletionQueue* cq, const std::string& cq_name,
    std::function<void(int)> TryToRegisterNewOne) {
G
gongweibao 已提交
359 360
  void* tag = NULL;
  bool ok = false;
361

G
gongweibao 已提交
362
  while (true) {
Q
qiaolongfei 已提交
363
    VLOG(3) << "HandleRequest for " << cq_name << " wait Next";
G
gongweibao 已提交
364
    if (!cq->Next(&tag, &ok)) {
T
typhoonzero 已提交
365
      LOG(INFO) << cq_name << " CompletionQueue shutdown!";
G
gongweibao 已提交
366 367
      break;
    }
Q
qiaolongfei 已提交
368
    VLOG(3) << "HandleRequest for " << cq_name << " get Next";
X
Xin Pan 已提交
369
    int req_id = static_cast<int>(reinterpret_cast<intptr_t>(tag));
Q
qiaolongfei 已提交
370

Q
qiaolongfei 已提交
371 372 373 374
    if (sync_mode_) {
      // FIXME(typhoonzero): de-couple the barriers with recv_op
      if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1);
      if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
Q
qiaolongfei 已提交
375
      VLOG(3) << "HandleRequest for " << cq_name << " after WaitCond";
Q
qiaolongfei 已提交
376
    }
G
gongweibao 已提交
377

X
Xin Pan 已提交
378 379 380 381
    RequestBase* base = nullptr;
    {
      std::lock_guard<std::mutex> l(cq_mutex_);
      if (cq_name == "cq_get") {
X
Xin Pan 已提交
382
        base = get_reqs_[req_id];
X
Xin Pan 已提交
383
      } else if (cq_name == "cq_send") {
X
Xin Pan 已提交
384
        base = send_reqs_[req_id];
X
Xin Pan 已提交
385 386
      } else if (cq_name == "cq_prefetch") {
        base = prefetch_reqs_[req_id];
X
Xin Pan 已提交
387 388
      }
    }
G
gongweibao 已提交
389 390 391 392
    // reference:
    // https://github.com/tensorflow/tensorflow/issues/5596
    // https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
    // https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
G
gongweibao 已提交
393
    if (!ok) {
Q
qiaolongfei 已提交
394 395
      LOG(WARNING) << cq_name << " recv no regular event:argument name["
                   << base->GetReqName() << "]";
X
Xin Pan 已提交
396
      TryToRegisterNewOne(req_id);
G
gongweibao 已提交
397 398 399 400 401 402 403
      delete base;
      continue;
    }

    switch (base->Status()) {
      case PROCESS: {
        base->Process();
Q
qiaolongfei 已提交
404
        VLOG(4) << cq_name << " PROCESS status:" << base->Status();
G
gongweibao 已提交
405 406 407
        break;
      }
      case FINISH: {
X
Xin Pan 已提交
408
        TryToRegisterNewOne(req_id);
Q
qiaolongfei 已提交
409
        VLOG(4) << cq_name << " FINISH status:" << base->Status();
G
gongweibao 已提交
410 411 412 413 414 415 416 417
        delete base;
        break;
      }
      default: { assert(false); }
    }
  }
}

T
typhoonzero 已提交
418 419 420 421
// Blocks until SetCond() has set barrier_cond_step_ to `cond`.
void AsyncGRPCServer::WaitCond(int cond) {
  std::unique_lock<std::mutex> lock(this->barrier_mutex_);
  // Explicit captures. Implicit capture of `this` through [=] is deprecated
  // since C++20.
  barrier_condition_.wait(
      lock, [this, cond] { return this->barrier_cond_step_ == cond; });
}

T
typhoonzero 已提交
424
void AsyncGRPCServer::SetCond(int cond) {
G
gongweibao 已提交
425
  {
T
typhoonzero 已提交
426 427
    std::lock_guard<std::mutex> lock(this->barrier_mutex_);
    barrier_cond_step_ = cond;
G
gongweibao 已提交
428
  }
T
typhoonzero 已提交
429
  barrier_condition_.notify_all();
G
gongweibao 已提交
430 431 432 433 434
}

}  // namespace detail
}  // namespace operators
}  // namespace paddle