/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/detail/grpc_server.h"
G
gongweibao 已提交
16

17
using ::grpc::ServerAsyncResponseWriter;
G
gongweibao 已提交
18 19 20 21 22 23 24 25 26 27 28

namespace paddle {
namespace operators {
namespace detail {

// Lifecycle state of an in-flight RPC handler object.
//   PROCESS: the next completion-queue event for this object means the request
//            has arrived and should be handled by Process().
//   FINISH:  the reply has been handed to Finish(); the next event for this
//            object means it can be deleted.
enum CallStatus {
  PROCESS = 0,
  FINISH = 1,
};

// reference:
// https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server
class RequestBase {
 public:
29 30 31 32
  explicit RequestBase(GrpcService::AsyncService* service,
                       ::grpc::ServerCompletionQueue* cq,
                       const platform::DeviceContext* dev_ctx)
      : service_(service), cq_(cq), status_(PROCESS), dev_ctx_(dev_ctx) {
G
gongweibao 已提交
33 34
    PADDLE_ENFORCE(cq_);
  }
G
gongweibao 已提交
35 36 37 38 39
  virtual ~RequestBase() {}
  virtual void Process() { assert(false); }

  CallStatus Status() { return status_; }
  void SetStatus(CallStatus status) { status_ = status; }
T
typhoonzero 已提交
40 41 42 43
  virtual std::string GetReqName() {
    assert(false);
    return "";
  }
G
gongweibao 已提交
44 45

 protected:
46 47 48
  ::grpc::ServerContext ctx_;
  GrpcService::AsyncService* service_;
  ::grpc::ServerCompletionQueue* cq_;
G
gongweibao 已提交
49
  CallStatus status_;
50
  const platform::DeviceContext* dev_ctx_;
G
gongweibao 已提交
51 52 53 54
};

class RequestSend final : public RequestBase {
 public:
55 56 57 58 59 60 61 62 63
  explicit RequestSend(GrpcService::AsyncService* service,
                       ::grpc::ServerCompletionQueue* cq,
                       framework::Scope* scope, ReceivedQueue* queue,
                       const platform::DeviceContext* dev_ctx)
      : RequestBase(service, cq, dev_ctx), queue_(queue), responder_(&ctx_) {
    request_.reset(new VariableResponse(scope, dev_ctx_));
    int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
    service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
                                cq_, cq_, this);
G
gongweibao 已提交
64 65 66 67
  }

  virtual ~RequestSend() {}

68
  virtual std::string GetReqName() { return request_->Varname(); }
G
gongweibao 已提交
69

G
gongweibao 已提交
70
  virtual void Process() {
71 72 73 74
    queue_->Push(std::make_pair(request_->Varname(), request_));

    sendrecv::VoidMessage reply;
    responder_.Finish(reply, ::grpc::Status::OK, this);
G
gongweibao 已提交
75
    status_ = FINISH;
G
gongweibao 已提交
76 77 78
  }

 protected:
79 80
  std::shared_ptr<VariableResponse> request_;
  ReceivedQueue* queue_;
G
gongweibao 已提交
81 82 83 84 85
  ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
};

class RequestGet final : public RequestBase {
 public:
86 87 88
  explicit RequestGet(GrpcService::AsyncService* service,
                      ::grpc::ServerCompletionQueue* cq,
                      framework::Scope* scope,
T
typhoonzero 已提交
89
                      const platform::DeviceContext* dev_ctx,
90
                      SimpleBlockQueue<MessageWithName>* queue)
91
      : RequestBase(service, cq, dev_ctx),
Y
Yancey1989 已提交
92 93
        responder_(&ctx_),
        scope_(scope),
T
typhoonzero 已提交
94
        queue_(queue) {
95 96 97
    int method_id = static_cast<int>(detail::GrpcMethod::kGetVariable);
    service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_,
                                cq_, this);
G
gongweibao 已提交
98 99 100 101
  }

  virtual ~RequestGet() {}

G
gongweibao 已提交
102 103
  virtual std::string GetReqName() { return request_.varname(); }

G
gongweibao 已提交
104 105 106 107
  virtual void Process() {
    // proc request.
    std::string var_name = request_.varname();
    auto* var = scope_->FindVar(var_name);
108 109

    ::grpc::ByteBuffer reply;
110
    if (var_name != FETCH_BARRIER_MESSAGE) {
111
      SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply);
112
    }
113 114

    responder_.Finish(reply, ::grpc::Status::OK, this);
G
gongweibao 已提交
115
    status_ = FINISH;
116 117 118 119 120 121

    if (var_name == FETCH_BARRIER_MESSAGE) {
      sendrecv::VariableMessage msg;
      MessageWithName msg_with_name = std::make_pair(var_name, msg);
      queue_->Push(msg_with_name);
    }
G
gongweibao 已提交
122 123 124 125
  }

 protected:
  sendrecv::VariableMessage request_;
126
  ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
G
gongweibao 已提交
127
  framework::Scope* scope_;
128
  SimpleBlockQueue<MessageWithName>* queue_;
G
gongweibao 已提交
129 130
};

T
typhoonzero 已提交
131
// Blocks until `count` FETCH_BARRIER_MESSAGE entries have been popped from
// var_get_queue_. Any other entries popped along the way are discarded.
void AsyncGRPCServer::WaitClientGet(int count) {
  for (int barriers_seen = 0; barriers_seen < count;) {
    auto entry = var_get_queue_.Pop();
    if (entry.first == FETCH_BARRIER_MESSAGE) {
      ++barriers_seen;
    }
  }
}

G
gongweibao 已提交
141
void AsyncGRPCServer::RunSyncUpdate() {
142 143
  ::grpc::ServerBuilder builder;
  builder.AddListeningPort(address_, ::grpc::InsecureServerCredentials());
G
gongweibao 已提交
144 145
  builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
G
gongweibao 已提交
146 147 148 149
  builder.RegisterService(&service_);

  cq_send_ = builder.AddCompletionQueue();
  cq_get_ = builder.AddCompletionQueue();
Y
Yancey 已提交
150

G
gongweibao 已提交
151 152 153 154 155 156 157 158 159
  server_ = builder.BuildAndStart();
  LOG(INFO) << "Server listening on " << address_ << std::endl;

  std::function<void()> send_register =
      std::bind(&AsyncGRPCServer::TryToRegisterNewSendOne, this);
  std::function<void()> get_register =
      std::bind(&AsyncGRPCServer::TryToRegisterNewGetOne, this);

  t_send_.reset(
Y
Yancey 已提交
160
      new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
G
gongweibao 已提交
161 162 163
                                cq_send_.get(), "cq_send", send_register)));

  t_get_.reset(
Y
Yancey 已提交
164
      new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
G
gongweibao 已提交
165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190
                                cq_get_.get(), "cq_get", get_register)));

  // wait server
  server_->Wait();
  t_send_->join();
  t_get_->join();
}

// Shuts down both completion queues and marks the server as shut down.
// Everything happens under cq_mutex_, the same mutex TryToRegisterNewSendOne
// and TryToRegisterNewGetOne hold while checking is_shut_down_.
void AsyncGRPCServer::ShutdownQueue() {
  std::lock_guard<std::mutex> guard(cq_mutex_);
  cq_send_->Shutdown();
  cq_get_->Shutdown();
  is_shut_down_ = true;
}

// This URL explains why shutdown is complicate:
void AsyncGRPCServer::ShutDown() {
  server_->Shutdown();
  ShutdownQueue();
}

void AsyncGRPCServer::TryToRegisterNewSendOne() {
  std::unique_lock<std::mutex> lock(cq_mutex_);
  if (is_shut_down_) {
    return;
  }
191 192
  RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_,
                                      &var_recv_queue_, dev_ctx_);
Y
Yancey 已提交
193
  VLOG(4) << "Create RequestSend status:" << send->Status();
G
gongweibao 已提交
194 195 196 197 198 199 200
}

void AsyncGRPCServer::TryToRegisterNewGetOne() {
  std::unique_lock<std::mutex> lock(cq_mutex_);
  if (is_shut_down_) {
    return;
  }
T
typhoonzero 已提交
201 202
  RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_,
                                   &var_get_queue_);
Y
Yancey 已提交
203
  VLOG(4) << "Create RequestGet status:" << get->Status();
G
gongweibao 已提交
204 205
}

Y
Yancey 已提交
206
// FIXME(typhoonzero): change cq_name to enum.
207
void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
G
gongweibao 已提交
208 209 210 211 212 213 214 215 216 217 218 219
                                    std::string cq_name,
                                    std::function<void()> TryToRegisterNewOne) {
  TryToRegisterNewOne();

  void* tag = NULL;
  bool ok = false;
  while (true) {
    if (!cq->Next(&tag, &ok)) {
      LOG(INFO) << cq_name << " get CompletionQueue shutdown!";
      break;
    }

G
gongweibao 已提交
220
    PADDLE_ENFORCE(tag);
T
typhoonzero 已提交
221 222
    // FIXME(typhoonzero): de-couple the barriers with recv_op
    if (cq_name == "cq_get") WaitCond(1);
T
typhoonzero 已提交
223
    if (cq_name == "cq_send") WaitCond(0);
G
gongweibao 已提交
224 225

    RequestBase* base = (RequestBase*)tag;
G
gongweibao 已提交
226 227 228 229
    // reference:
    // https://github.com/tensorflow/tensorflow/issues/5596
    // https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
    // https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
G
gongweibao 已提交
230
    if (!ok) {
G
gongweibao 已提交
231 232
      LOG(WARNING) << cq_name << " recv no regular event:argument name"
                   << base->GetReqName();
G
gongweibao 已提交
233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254
      TryToRegisterNewOne();
      delete base;
      continue;
    }

    switch (base->Status()) {
      case PROCESS: {
        VLOG(4) << cq_name << " status:" << base->Status();
        TryToRegisterNewOne();
        base->Process();
        break;
      }
      case FINISH: {
        VLOG(4) << cq_name << " status:" << base->Status();
        delete base;
        break;
      }
      default: { assert(false); }
    }
  }
}

T
typhoonzero 已提交
255 256 257 258
void AsyncGRPCServer::WaitCond(int cond) {
  std::unique_lock<std::mutex> lock(this->barrier_mutex_);
  barrier_condition_.wait(lock,
                          [=] { return this->barrier_cond_step_ == cond; });
G
gongweibao 已提交
259 260
}

T
typhoonzero 已提交
261
void AsyncGRPCServer::SetCond(int cond) {
G
gongweibao 已提交
262
  {
T
typhoonzero 已提交
263 264
    std::lock_guard<std::mutex> lock(this->barrier_mutex_);
    barrier_cond_step_ = cond;
G
gongweibao 已提交
265
  }
T
typhoonzero 已提交
266
  barrier_condition_.notify_all();
G
gongweibao 已提交
267 268 269 270 271
}

}  // namespace detail
}  // namespace operators
}  // namespace paddle