grpc_server.cc 10.3 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
G
gongweibao 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/detail/grpc_server.h"
G
gongweibao 已提交
16

17
using ::grpc::ServerAsyncResponseWriter;
G
gongweibao 已提交
18 19 20 21 22 23 24 25 26 27 28

namespace paddle {
namespace operators {
namespace detail {

// Lifecycle state of an in-flight async call, checked by HandleRequest()
// each time the call's tag comes back from its completion queue:
//   PROCESS - the call was just registered; the next event means it arrived
//             and must be handled via Process().
//   FINISH  - a response has been sent; the next event releases the object.
enum CallStatus { PROCESS = 0, FINISH = 1 };

// reference:
// https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server
class RequestBase {
 public:
29 30 31 32
  explicit RequestBase(GrpcService::AsyncService* service,
                       ::grpc::ServerCompletionQueue* cq,
                       const platform::DeviceContext* dev_ctx)
      : service_(service), cq_(cq), status_(PROCESS), dev_ctx_(dev_ctx) {
G
gongweibao 已提交
33 34
    PADDLE_ENFORCE(cq_);
  }
G
gongweibao 已提交
35 36 37 38 39
  virtual ~RequestBase() {}
  virtual void Process() { assert(false); }

  CallStatus Status() { return status_; }
  void SetStatus(CallStatus status) { status_ = status; }
T
typhoonzero 已提交
40 41 42 43
  virtual std::string GetReqName() {
    assert(false);
    return "";
  }
G
gongweibao 已提交
44 45

 protected:
46 47 48
  ::grpc::ServerContext ctx_;
  GrpcService::AsyncService* service_;
  ::grpc::ServerCompletionQueue* cq_;
G
gongweibao 已提交
49
  CallStatus status_;
50
  const platform::DeviceContext* dev_ctx_;
G
gongweibao 已提交
51 52 53 54
};

class RequestSend final : public RequestBase {
 public:
55 56 57 58 59 60 61 62 63
  explicit RequestSend(GrpcService::AsyncService* service,
                       ::grpc::ServerCompletionQueue* cq,
                       framework::Scope* scope, ReceivedQueue* queue,
                       const platform::DeviceContext* dev_ctx)
      : RequestBase(service, cq, dev_ctx), queue_(queue), responder_(&ctx_) {
    request_.reset(new VariableResponse(scope, dev_ctx_));
    int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
    service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
                                cq_, cq_, this);
G
gongweibao 已提交
64 65 66 67
  }

  virtual ~RequestSend() {}

68
  virtual std::string GetReqName() { return request_->Varname(); }
G
gongweibao 已提交
69

G
gongweibao 已提交
70
  virtual void Process() {
71 72 73 74
    queue_->Push(std::make_pair(request_->Varname(), request_));

    sendrecv::VoidMessage reply;
    responder_.Finish(reply, ::grpc::Status::OK, this);
G
gongweibao 已提交
75
    status_ = FINISH;
G
gongweibao 已提交
76 77 78
  }

 protected:
79 80
  std::shared_ptr<VariableResponse> request_;
  ReceivedQueue* queue_;
G
gongweibao 已提交
81 82 83 84 85
  ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
};

class RequestGet final : public RequestBase {
 public:
86 87 88
  explicit RequestGet(GrpcService::AsyncService* service,
                      ::grpc::ServerCompletionQueue* cq,
                      framework::Scope* scope,
T
typhoonzero 已提交
89
                      const platform::DeviceContext* dev_ctx,
90
                      SimpleBlockQueue<MessageWithName>* queue)
91
      : RequestBase(service, cq, dev_ctx),
Y
Yancey1989 已提交
92 93
        responder_(&ctx_),
        scope_(scope),
T
typhoonzero 已提交
94
        queue_(queue) {
95 96 97
    int method_id = static_cast<int>(detail::GrpcMethod::kGetVariable);
    service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_,
                                cq_, this);
G
gongweibao 已提交
98 99 100 101
  }

  virtual ~RequestGet() {}

G
gongweibao 已提交
102 103
  virtual std::string GetReqName() { return request_.varname(); }

G
gongweibao 已提交
104 105 106 107
  virtual void Process() {
    // proc request.
    std::string var_name = request_.varname();
    auto* var = scope_->FindVar(var_name);
108 109

    ::grpc::ByteBuffer reply;
110
    if (var_name != FETCH_BARRIER_MESSAGE) {
111
      SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply);
112
    }
113 114

    responder_.Finish(reply, ::grpc::Status::OK, this);
G
gongweibao 已提交
115
    status_ = FINISH;
116 117 118 119 120 121

    if (var_name == FETCH_BARRIER_MESSAGE) {
      sendrecv::VariableMessage msg;
      MessageWithName msg_with_name = std::make_pair(var_name, msg);
      queue_->Push(msg_with_name);
    }
G
gongweibao 已提交
122 123 124 125
  }

 protected:
  sendrecv::VariableMessage request_;
126
  ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
G
gongweibao 已提交
127
  framework::Scope* scope_;
128
  SimpleBlockQueue<MessageWithName>* queue_;
G
gongweibao 已提交
129 130
};

131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171
class RequestPrefetch final : public RequestBase {
 public:
  explicit RequestPrefetch(GrpcService::AsyncService* service,
                           ::grpc::ServerCompletionQueue* cq,
                           framework::Scope* scope,
                           const platform::DeviceContext* dev_ctx,
                           framework::Executor* executor,
                           framework::ProgramDesc* program, int blkid)
      : RequestBase(service, cq, dev_ctx),
        responder_(&ctx_),
        scope_(scope),
        executor_(executor),
        program_(program),
        blkid_(blkid) {
    int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
    service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_,
                                cq_, this);
  }

  virtual ~RequestPrefetch() {}

  virtual std::string GetReqName() { return request_.varname(); }

  virtual void Process() {
    // prefetch process...
    ::grpc::ByteBuffer relay;
    // TODO(Yancey1989): execute the Block which containers prefetch ops

    responder_.Finish(relay, ::grpc::Status::OK, this);
    status_ = FINISH;
  }

 protected:
  sendrecv::VariableMessage request_;
  ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
  framework::Scope* scope_;
  framework::Executor* executor_;
  framework::ProgramDesc* program_;
  int blkid_;
};

T
typhoonzero 已提交
172
// Blocks until `count` FETCH_BARRIER_MESSAGE entries have been popped from
// var_get_queue_. Any other entries popped along the way are discarded.
void AsyncGRPCServer::WaitClientGet(int count) {
  for (int barriers_seen = 0; barriers_seen < count;) {
    if (var_get_queue_.Pop().first == FETCH_BARRIER_MESSAGE) {
      ++barriers_seen;
    }
  }
}

G
gongweibao 已提交
182
void AsyncGRPCServer::RunSyncUpdate() {
183 184
  ::grpc::ServerBuilder builder;
  builder.AddListeningPort(address_, ::grpc::InsecureServerCredentials());
G
gongweibao 已提交
185 186
  builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
G
gongweibao 已提交
187 188 189 190
  builder.RegisterService(&service_);

  cq_send_ = builder.AddCompletionQueue();
  cq_get_ = builder.AddCompletionQueue();
191
  cq_prefetch_ = builder.AddCompletionQueue();
Y
Yancey 已提交
192

G
gongweibao 已提交
193 194 195 196 197 198 199
  server_ = builder.BuildAndStart();
  LOG(INFO) << "Server listening on " << address_ << std::endl;

  std::function<void()> send_register =
      std::bind(&AsyncGRPCServer::TryToRegisterNewSendOne, this);
  std::function<void()> get_register =
      std::bind(&AsyncGRPCServer::TryToRegisterNewGetOne, this);
200 201
  std::function<void()> prefetch_register =
      std::bind(&AsyncGRPCServer::TryToRegisterNewPrefetchOne, this);
G
gongweibao 已提交
202 203

  t_send_.reset(
Y
Yancey 已提交
204
      new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
G
gongweibao 已提交
205 206 207
                                cq_send_.get(), "cq_send", send_register)));

  t_get_.reset(
Y
Yancey 已提交
208
      new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
G
gongweibao 已提交
209
                                cq_get_.get(), "cq_get", get_register)));
210 211 212
  t_prefetch_.reset(new std::thread(
      std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(),
                "cq_prefetch", prefetch_register)));
G
gongweibao 已提交
213 214 215 216
  // wait server
  server_->Wait();
  t_send_->join();
  t_get_->join();
217
  t_prefetch_->join();
G
gongweibao 已提交
218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237
}

void AsyncGRPCServer::ShutdownQueue() {
  std::unique_lock<std::mutex> lock(cq_mutex_);
  cq_send_->Shutdown();
  cq_get_->Shutdown();
  is_shut_down_ = true;
}

// This URL explains why shutdown is complicate:
void AsyncGRPCServer::ShutDown() {
  server_->Shutdown();
  ShutdownQueue();
}

void AsyncGRPCServer::TryToRegisterNewSendOne() {
  std::unique_lock<std::mutex> lock(cq_mutex_);
  if (is_shut_down_) {
    return;
  }
238 239
  RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_,
                                      &var_recv_queue_, dev_ctx_);
Y
Yancey 已提交
240
  VLOG(4) << "Create RequestSend status:" << send->Status();
G
gongweibao 已提交
241 242 243 244 245 246 247
}

void AsyncGRPCServer::TryToRegisterNewGetOne() {
  std::unique_lock<std::mutex> lock(cq_mutex_);
  if (is_shut_down_) {
    return;
  }
T
typhoonzero 已提交
248 249
  RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_,
                                   &var_get_queue_);
Y
Yancey 已提交
250
  VLOG(4) << "Create RequestGet status:" << get->Status();
G
gongweibao 已提交
251 252
}

253 254 255 256 257 258 259 260 261 262 263 264
// Posts a fresh RequestPrefetch on cq_prefetch_ unless the queues are already
// shut down; cq_mutex_ orders this against ShutdownQueue(). The heap object
// is later deleted by HandleRequest() when its call finishes or fails.
void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
  std::lock_guard<std::mutex> guard(cq_mutex_);
  if (!is_shut_down_) {
    auto* pending =
        new RequestPrefetch(&service_, cq_prefetch_.get(), scope_, dev_ctx_,
                            executor_, program_, prefetch_blk_id_);
    VLOG(4) << "Create RequestPrefetch status:" << pending->Status();
  }
}

Y
Yancey 已提交
265
// FIXME(typhoonzero): change cq_name to enum.
266
void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
G
gongweibao 已提交
267 268 269 270 271 272 273 274 275 276 277 278
                                    std::string cq_name,
                                    std::function<void()> TryToRegisterNewOne) {
  TryToRegisterNewOne();

  void* tag = NULL;
  bool ok = false;
  while (true) {
    if (!cq->Next(&tag, &ok)) {
      LOG(INFO) << cq_name << " get CompletionQueue shutdown!";
      break;
    }

G
gongweibao 已提交
279
    PADDLE_ENFORCE(tag);
T
typhoonzero 已提交
280 281
    // FIXME(typhoonzero): de-couple the barriers with recv_op
    if (cq_name == "cq_get") WaitCond(1);
T
typhoonzero 已提交
282
    if (cq_name == "cq_send") WaitCond(0);
G
gongweibao 已提交
283 284

    RequestBase* base = (RequestBase*)tag;
G
gongweibao 已提交
285 286 287 288
    // reference:
    // https://github.com/tensorflow/tensorflow/issues/5596
    // https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
    // https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
G
gongweibao 已提交
289
    if (!ok) {
G
gongweibao 已提交
290 291
      LOG(WARNING) << cq_name << " recv no regular event:argument name"
                   << base->GetReqName();
G
gongweibao 已提交
292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313
      TryToRegisterNewOne();
      delete base;
      continue;
    }

    switch (base->Status()) {
      case PROCESS: {
        VLOG(4) << cq_name << " status:" << base->Status();
        TryToRegisterNewOne();
        base->Process();
        break;
      }
      case FINISH: {
        VLOG(4) << cq_name << " status:" << base->Status();
        delete base;
        break;
      }
      default: { assert(false); }
    }
  }
}

T
typhoonzero 已提交
314 315 316 317
void AsyncGRPCServer::WaitCond(int cond) {
  std::unique_lock<std::mutex> lock(this->barrier_mutex_);
  barrier_condition_.wait(lock,
                          [=] { return this->barrier_cond_step_ == cond; });
G
gongweibao 已提交
318 319
}

T
typhoonzero 已提交
320
void AsyncGRPCServer::SetCond(int cond) {
G
gongweibao 已提交
321
  {
T
typhoonzero 已提交
322 323
    std::lock_guard<std::mutex> lock(this->barrier_mutex_);
    barrier_cond_step_ = cond;
G
gongweibao 已提交
324
  }
T
typhoonzero 已提交
325
  barrier_condition_.notify_all();
G
gongweibao 已提交
326 327 328 329 330
}

}  // namespace detail
}  // namespace operators
}  // namespace paddle