/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/detail/grpc_client.h"

#include <sys/time.h>

#include <functional>
#include <limits>
#include <memory>
#include <mutex>
#include <string>
#include <thread>

#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace operators {
namespace detail {

G
gongweibao 已提交
28
void GRPCClient::InitImpl() { InitEventLoop(); }
Y
Yancey1989 已提交
29

G
gongweibao 已提交
30
void GRPCClient::InitEventLoop() {
W
Wu Yi 已提交
31 32
  // start the client process thread
  // TODO(wuyi): can make this in a threadpool
G
gongweibao 已提交
33
  client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
W
Wu Yi 已提交
34 35
}

G
gongweibao 已提交
36
// Tears the client down in this order:
//   1. Wait() blocks until req_count_ drops back to zero, i.e. every request
//      issued through this client has been handled by Proceed().
//   2. cq_.Shutdown() makes cq_.Next() in Proceed() return false once the
//      queue is drained, which ends the client thread's loop.
//   3. Cached channels are released under chan_mutex_.
//   4. The client thread is joined.
GRPCClient::~GRPCClient() {
  Wait();
  cq_.Shutdown();
  {
    std::lock_guard<std::mutex> guard(chan_mutex_);
    for (auto& it : channels_) {
      it.second.reset();
    }
  }
  // client_thread_ is only created by InitEventLoop(); if the client was never
  // initialized there is no thread, and joining would dereference null.
  if (client_thread_ && client_thread_->joinable()) {
    client_thread_->join();
  }
}

// Asynchronously sends variable `var_name` from `scope` to endpoint `ep`
// through the SendVariable RPC.
//
// The request is counted in req_count_ and the serialization + RPC issue are
// handed to framework::AsyncIO (presumably run later on an IO thread pool).
// Completion is handled by Proceed(); call Wait() to block until done.
//
// @param ep        endpoint the channel is opened/looked up for
// @param ctx       device context used to serialize the variable; captured by
//                  pointer, so it must outlive the request
// @param scope     scope the variable is looked up in; captured by pointer,
//                  so it must outlive the request
// @param var_name  name of the variable to send
// @param time_out  forwarded to SendProcessor::Prepare
// @return always true
bool GRPCClient::AsyncSendVar(const std::string& ep,
                              const platform::DeviceContext& ctx,
                              const framework::Scope& scope,
                              const std::string& var_name, int64_t time_out) {
  // Strings are copied into the task; ctx and scope are kept as pointers.
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string var_name_val = var_name;
  const framework::Scope* p_scope = &scope;

  const auto ch = GetChannel(ep_val);

  // Count the request *before* dispatching it. Proceed() decrements
  // req_count_ when the RPC completes; incrementing only after dispatch let
  // that decrement race ahead, so the counter could go below zero and a
  // concurrent Wait() could miss its wake-up. Guarded by sync_mutex_ to match
  // the decrement in Proceed().
  {
    std::lock_guard<std::mutex> guard(sync_mutex_);
    req_count_++;
  }

  framework::AsyncIO([var_name_val, p_ctx, ep_val, p_scope, time_out, ch,
                      this] {
    // NOTE(review): FindVar may return nullptr for an unknown name; whether
    // SerializeToByteBuffer tolerates that is not visible here.
    auto* var = p_scope->FindVar(var_name_val);

    ::grpc::ByteBuffer req;
    SerializeToByteBuffer(var_name_val, var, *p_ctx, &req);

    // varhandle
    VarHandle var_h;
    var_h.ep = ep_val;
    var_h.scope = p_scope;
    var_h.name = var_name_val;
    var_h.ctx = p_ctx;

    // stub context
    SendProcessor* s = new SendProcessor(ch);
    s->Prepare(var_h, time_out);
    // A send has no reply payload to process.
    s->response_call_back_ = nullptr;

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
    call->StartCall();
    // The processor is the completion tag: Proceed() casts it back to
    // BaseProcessor*, runs it and deletes it.
    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
  });

  return true;
}

void ProcGetResponse(const VarHandle& var_h,
88
                     const ::grpc::ByteBuffer& ret_msg) {
T
typhoonzero 已提交
89
  framework::Variable* outvar = nullptr;
Y
Yi Wang 已提交
90
  DeserializeFromByteBuffer(ret_msg, *var_h.ctx, var_h.scope, &outvar);
91 92 93 94 95
}

template <typename T>
void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
  ::grpc::Slice slice(proto.ByteSizeLong());
Q
qiaolongfei 已提交
96
  proto.SerializeWithCachedSizesToArray(const_cast<uint8_t*>(slice.begin()));
97 98
  ::grpc::ByteBuffer tmp(&slice, 1);
  result->Swap(&tmp);
G
gongweibao 已提交
99 100
}

G
gongweibao 已提交
101 102 103 104
// Asynchronously fetches variable `var_name` from endpoint `ep` through the
// GetVariable RPC.
//
// The request is counted in req_count_ and building + issuing the RPC is
// handed to framework::AsyncIO (presumably run later on an IO thread pool).
// On success Proceed() runs the processor, whose response callback is
// ProcGetResponse. Call Wait() to block until done.
//
// @param ep        endpoint the channel is opened/looked up for
// @param ctx       device context for deserializing the reply; captured by
//                  pointer, so it must outlive the request
// @param scope     destination scope for the reply; captured by pointer, so it
//                  must outlive the request
// @param var_name  name of the variable to request
// @param time_out  forwarded to GetProcessor::Prepare
// @return always true
bool GRPCClient::AsyncGetVar(const std::string& ep,
                             const platform::DeviceContext& ctx,
                             const framework::Scope& scope,
                             const std::string& var_name, int64_t time_out) {
  // Strings are copied into the task; ctx and scope are kept as pointers.
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string var_name_val = var_name;
  const framework::Scope* p_scope = &scope;

  const auto ch = GetChannel(ep_val);

  // Count the request *before* dispatching it, so the decrement in Proceed()
  // can never run ahead of the increment (which could drive the counter below
  // zero and make a concurrent Wait() miss its wake-up). Guarded by
  // sync_mutex_ to match the decrement.
  {
    std::lock_guard<std::mutex> guard(sync_mutex_);
    req_count_++;
  }

  framework::AsyncIO([var_name_val, ep_val, p_scope, p_ctx, time_out, ch,
                      this] {
    // prepare input: the request carries only the variable name
    sendrecv::VariableMessage req;
    req.set_varname(var_name_val);
    ::grpc::ByteBuffer buf;
    RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);

    // var handle
    VarHandle var_h;
    var_h.ep = ep_val;
    var_h.scope = p_scope;
    var_h.name = var_name_val;
    var_h.ctx = p_ctx;

    // stub context
    GetProcessor* s = new GetProcessor(ch);
    s->Prepare(var_h, time_out);
    s->response_call_back_ = ProcGetResponse;

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_);
    call->StartCall();
    // The processor is the completion tag: Proceed() casts it back to
    // BaseProcessor*, runs it and deletes it.
    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
  });

  return true;
}

// Asynchronously sends variable `in_var_name` to endpoint `ep` through the
// PrefetchVariable RPC. The reply is handled by ProcGetResponse with a
// VarHandle named `out_var_name`.
//
// The request is counted in req_count_ and serialization + RPC issue are
// handed to framework::AsyncIO (presumably run later on an IO thread pool).
// Call Wait() to block until done.
//
// @param ep            endpoint the channel is opened/looked up for
// @param ctx           device context for (de)serialization; captured by
//                      pointer, so it must outlive the request
// @param scope         scope holding the input and receiving the output;
//                      captured by pointer, so it must outlive the request
// @param in_var_name   variable serialized into the request
// @param out_var_name  also passed to SerializeToByteBuffer and used as the
//                      reply's VarHandle name
// @param time_out      forwarded to GetProcessor::Prepare
// @return always true
bool GRPCClient::AsyncPrefetchVar(const std::string& ep,
                                  const platform::DeviceContext& ctx,
                                  const framework::Scope& scope,
                                  const std::string& in_var_name,
                                  const std::string& out_var_name,
                                  int64_t time_out) {
  // Strings are copied into the task; ctx and scope are kept as pointers.
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string in_var_name_val = in_var_name;
  const std::string out_var_name_val = out_var_name;
  const framework::Scope* p_scope = &scope;

  const auto ch = GetChannel(ep_val);

  // Count the request *before* dispatching it, so the decrement in Proceed()
  // can never run ahead of the increment (which could drive the counter below
  // zero and make a concurrent Wait() miss its wake-up). Guarded by
  // sync_mutex_ to match the decrement.
  {
    std::lock_guard<std::mutex> guard(sync_mutex_);
    req_count_++;
  }

  framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
                      time_out, ch, this] {
    auto* var = p_scope->FindVar(in_var_name_val);

    ::grpc::ByteBuffer req;
    SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val);

    // var handle: the reply is attributed to the output variable
    VarHandle var_h;
    var_h.ep = ep_val;
    var_h.scope = p_scope;
    var_h.name = out_var_name_val;
    var_h.ctx = p_ctx;

    // stub context
    GetProcessor* s = new GetProcessor(ch);
    s->Prepare(var_h, time_out);
    s->response_call_back_ = ProcGetResponse;

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
        &cq_);
    call->StartCall();
    // The processor is the completion tag: Proceed() casts it back to
    // BaseProcessor*, runs it and deletes it.
    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
  });

  return true;
}

void GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
                                       int64_t time_out) {
Y
Yancey1989 已提交
187
  const auto ch = GetChannel(ep);
Y
Yancey 已提交
188 189 190 191 192 193 194

  BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
  s->Prepare(time_out);

  sendrecv::VariableMessage req;
  req.set_varname(BATCH_BARRIER_MESSAGE);
  auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
Y
Yi Wang 已提交
195
  rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
Y
Yancey 已提交
196
  req_count_++;
197
}
Y
Yancey 已提交
198

G
gongweibao 已提交
199 200
void GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
                                       int64_t time_out) {
Y
Yancey1989 已提交
201
  const auto ch = GetChannel(ep);
202 203 204 205 206 207
  FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
  s->Prepare(time_out);

  sendrecv::VariableMessage req;
  req.set_varname(FETCH_BARRIER_MESSAGE);
  auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
Y
Yi Wang 已提交
208
  rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
209
  req_count_++;
Y
Yancey 已提交
210 211
}

G
gongweibao 已提交
212
void GRPCClient::Wait() {
W
Wu Yi 已提交
213 214
  std::unique_lock<std::mutex> lk(sync_mutex_);
  sync_cond_.wait(lk, [this] { return req_count_ == 0; });
G
gongweibao 已提交
215 216
}

G
gongweibao 已提交
217
void GRPCClient::Proceed() {
W
Wu Yi 已提交
218
  void* tag = nullptr;
G
gongweibao 已提交
219 220
  bool ok = false;

W
Wu Yi 已提交
221 222 223 224 225 226 227 228 229 230
  while (cq_.Next(&tag, &ok)) {
    BaseProcessor* c = static_cast<BaseProcessor*>(tag);
    GPR_ASSERT(ok);
    PADDLE_ENFORCE(c);
    if (c->status_.ok()) {
      c->Process();
    } else {
      LOG(ERROR) << "var: " << c->var_h_.String()
                 << " grpc error:" << c->status_.error_message();
    }
G
gongweibao 已提交
231
    delete c;
W
Wu Yi 已提交
232 233 234 235 236
    {
      std::lock_guard<std::mutex> lk(sync_mutex_);
      req_count_--;
    }
    sync_cond_.notify_all();
G
gongweibao 已提交
237 238
  }
}
W
Wu Yi 已提交
239

G
gongweibao 已提交
240
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
Y
Yancey1989 已提交
241
  // TODO(Yancey1989): make grpc client completely thread-safe
W
Wu Yi 已提交
242
  std::lock_guard<std::mutex> guard(chan_mutex_);
Y
Yancey1989 已提交
243
  auto it = channels_.find(ep);
G
gongweibao 已提交
244 245 246 247
  if (it != channels_.end()) {
    return it->second;
  }

G
gongweibao 已提交
248
  grpc::ChannelArguments args;
249
  args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
G
gongweibao 已提交
250 251 252
  args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());

T
typhoonzero 已提交
253 254
  auto ch =
      grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
Y
Yancey1989 已提交
255
  channels_[ep] = ch;
G
gongweibao 已提交
256 257 258 259 260 261
  return ch;
}

}  // namespace detail
}  // namespace operators
}  // namespace paddle