grpc_client.cc 12.2 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include <sys/time.h>
Y
Yi Wang 已提交
16 17
#include <limits>

G
gongweibao 已提交
18
#include "glog/logging.h"  // For VLOG
Y
Yi Wang 已提交
19
#include "paddle/fluid/framework/threadpool.h"
G
gongweibao 已提交
20
#include "paddle/fluid/operators/distributed/grpc_client.h"
21
#include "paddle/fluid/operators/distributed/grpc_serde.h"
22
#include "paddle/fluid/operators/distributed/request_handler.h"
X
Xin Pan 已提交
23
#include "paddle/fluid/platform/profiler.h"
24

G
gongweibao 已提交
25 26
namespace paddle {
namespace operators {
27
namespace distributed {
G
gongweibao 已提交
28

G
gongweibao 已提交
29
// Client-specific initialization: all it does is start the event loop that
// polls the completion queue.
void GRPCClient::InitImpl() {
  InitEventLoop();
}
Y
Yancey1989 已提交
30

G
gongweibao 已提交
31
void GRPCClient::InitEventLoop() {
W
Wu Yi 已提交
32 33
  // start the client process thread
  // TODO(wuyi): can make this in a threadpool
G
gongweibao 已提交
34
  client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
W
Wu Yi 已提交
35 36
}

Y
Yancey1989 已提交
37
void GRPCClient::SendComplete() {
Y
Yancey1989 已提交
38 39 40 41 42 43 44 45
  std::unique_lock<std::mutex> lk(completed_mutex_);
  if (!completed_) {
    for (auto& it : channels_) {
      VLOG(3) << "send complete message to " << it.first;
      this->AsyncSendComplete(it.first);
    }
    PADDLE_ENFORCE(this->Wait(), "internal grpc error");
    completed_ = true;
W
Wu Yi 已提交
46 47 48
  }
}

G
gongweibao 已提交
49
// Tears the client down: waits for outstanding requests, stops the polling
// loop, shuts the completion queue down, drops all cached channels and joins
// the client thread.
GRPCClient::~GRPCClient() {
  // Drain in-flight requests while Proceed() is still consuming events.
  // Proceed() leaves its loop as soon as it observes stopped_. If the flag
  // were raised first, completions still pending would never decrement
  // req_count_, and this Wait() could block forever.
  Wait();
  stopped_ = true;
  // After Shutdown(), cq_.Next() returns false once the queue is empty, which
  // lets the polling thread leave its loop.
  cq_.Shutdown();
  {
    std::lock_guard<std::mutex> guard(chan_mutex_);
    for (auto& it : channels_) {
      it.second.reset();
    }
    channels_.clear();
  }
  // client_thread_ is created in InitEventLoop(); skip the join if the client
  // was never initialised.
  if (client_thread_ && client_thread_->joinable()) {
    client_thread_->join();
  }
}

63 64 65 66 67
// Asynchronously sends variable `var_name` from `scope` to endpoint `ep`.
//
// Serialization and call setup run on the AsyncIO pool. The completion event
// is handled later by Proceed(). The returned handle describes the request;
// callers can block on it, or on Wait() for all outstanding requests.
// `ctx` and `scope` are captured by pointer, so they must outlive the request.
VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
                                      const platform::DeviceContext& ctx,
                                      const framework::Scope& scope,
                                      const std::string& var_name,
                                      int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string var_name_val = var_name;
  const framework::Scope* p_scope = &scope;
  const auto ch = GetChannel(ep_val);
  // Proceed() deletes the processor once its completion event is handled.
  SendProcessor* s = new SendProcessor(ch);
  const std::string method = "SendRPC";
  VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
  s->Prepare(h, time_out);

  // Count the request before it can complete. Proceed() decrements
  // req_count_ under sync_mutex_ on the client thread. Incrementing only
  // after dispatch lets the counter dip below zero and can make a concurrent
  // Wait() miss its wake-up.
  {
    std::lock_guard<std::mutex> lk(sync_mutex_);
    req_count_++;
  }

  framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
    auto* var = p_scope->FindVar(var_name_val);

    ::grpc::ByteBuffer req;
    SerializeToByteBuffer(var_name_val, var, *p_ctx, &req);

    VLOG(3) << s->GetVarHandlePtr()->String() << " begin";

    // A send expects no payload back, so no response callback is installed.
    s->response_call_back_ = nullptr;

    platform::RecordEvent record_event(method, p_ctx);

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
    call->StartCall();
    // Tag with a BaseProcessor*, because Proceed() static_casts the void*
    // tag back to BaseProcessor*.
    call->Finish(&s->reply_, &s->status_,
                 static_cast<void*>(static_cast<BaseProcessor*>(s)));

    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
  });

  return h;
}

void ProcGetResponse(const VarHandle& var_h,
106
                     const ::grpc::ByteBuffer& ret_msg) {
T
typhoonzero 已提交
107
  framework::Variable* outvar = nullptr;
108
  DeserializeFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar);
109 110 111 112 113
}

template <typename T>
void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
  ::grpc::Slice slice(proto.ByteSizeLong());
Q
qiaolongfei 已提交
114
  proto.SerializeWithCachedSizesToArray(const_cast<uint8_t*>(slice.begin()));
115 116
  ::grpc::ByteBuffer tmp(&slice, 1);
  result->Swap(&tmp);
G
gongweibao 已提交
117 118
}

119 120 121 122 123
// Asynchronously fetches variable `var_name` from endpoint `ep` into `scope`.
//
// The request carries only the variable name. ProcGetResponse is installed
// as the response callback so the reply is deserialized into the handle's
// scope (presumably invoked from the processor's Process(); confirm in
// GetProcessor). `ctx` and `scope` are captured by pointer, so they must
// outlive the request.
VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
                                     const platform::DeviceContext& ctx,
                                     const framework::Scope& scope,
                                     const std::string& var_name,
                                     int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string var_name_val = var_name;
  const framework::Scope* p_scope = &scope;
  const auto ch = GetChannel(ep_val);
  // Proceed() deletes the processor once its completion event is handled.
  GetProcessor* s = new GetProcessor(ch);
  const std::string method = "GetRPC";
  VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
  s->Prepare(h, time_out);

  // Count the request before it can complete. Proceed() decrements
  // req_count_ under sync_mutex_, so incrementing after dispatch could let
  // the counter go negative and a concurrent Wait() miss its wake-up.
  {
    std::lock_guard<std::mutex> lk(sync_mutex_);
    req_count_++;
  }

  framework::AsyncIO([var_name_val, s, method, p_ctx, h, this] {
    // prepare input
    sendrecv::VariableMessage req;
    req.set_varname(var_name_val);
    ::grpc::ByteBuffer buf;
    RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);

    VLOG(3) << s->GetVarHandlePtr()->String() << " begin";

    // stub context
    s->response_call_back_ = ProcGetResponse;

    platform::RecordEvent record_event(method, p_ctx);

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_);
    call->StartCall();
    // Tag with a BaseProcessor*, because Proceed() static_casts the void*
    // tag back to BaseProcessor*.
    call->Finish(&s->reply_, &s->status_,
                 static_cast<void*>(static_cast<BaseProcessor*>(s)));

    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
  });

  return h;
}

163 164 165 166 167 168
// Asynchronously sends `in_var_name` from `scope` to endpoint `ep` and asks
// for the result to come back as `out_var_name`.
//
// The reply is handled by ProcGetResponse, which deserializes it into the
// handle's scope. `ctx` and `scope` are captured by pointer, so they must
// outlive the request.
VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
                                          const platform::DeviceContext& ctx,
                                          const framework::Scope& scope,
                                          const std::string& in_var_name,
                                          const std::string& out_var_name,
                                          int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string in_var_name_val = in_var_name;
  const std::string out_var_name_val = out_var_name;
  const framework::Scope* p_scope = &scope;
  const auto ch = GetChannel(ep_val);
  // Proceed() deletes the processor once its completion event is handled.
  GetProcessor* s = new GetProcessor(ch);

  const std::string method = "PrefetchRPC";

  VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
  s->Prepare(h, time_out);

  // Count the request before it can complete. Proceed() decrements
  // req_count_ under sync_mutex_, so incrementing after dispatch could let
  // the counter go negative and a concurrent Wait() miss its wake-up.
  {
    std::lock_guard<std::mutex> lk(sync_mutex_);
    req_count_++;
  }

  framework::AsyncIO([in_var_name_val, out_var_name_val, p_scope, p_ctx, s,
                      method, h, this] {
    auto* var = p_scope->FindVar(in_var_name_val);

    ::grpc::ByteBuffer req;
    SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val);

    VLOG(3) << s->GetVarHandlePtr()->String() << " begin";

    // stub context
    s->response_call_back_ = ProcGetResponse;

    platform::RecordEvent record_event(method, p_ctx);

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
        &cq_);
    call->StartCall();
    // Tag with a BaseProcessor*, because Proceed() static_casts the void*
    // tag back to BaseProcessor*.
    call->Finish(&s->reply_, &s->status_,
                 static_cast<void*>(static_cast<BaseProcessor*>(s)));

    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
  });

  return h;
}

211 212
// Sends the BATCH_BARRIER_MESSAGE to endpoint `ep` via SendVariable and
// returns a handle for the request. The completion is processed by
// Proceed().
VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
                                               int64_t time_out) {
  const auto ch = GetChannel(ep);

  // Proceed() deletes the processor once its completion event is handled.
  BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
  const std::string method = "BatchBarrierRPC";
  VarHandlePtr h(
      new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(BATCH_BARRIER_MESSAGE);

  platform::RecordEvent record_event(method, nullptr);

  // Count the request before issuing it. Proceed() decrements req_count_
  // under sync_mutex_ and may run as soon as Finish() is registered.
  {
    std::lock_guard<std::mutex> lk(sync_mutex_);
    req_count_++;
  }

  auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
  // Tag with a BaseProcessor*, because Proceed() static_casts the void* tag
  // back to BaseProcessor*.
  rpc->Finish(&s->reply_, &s->status_,
              static_cast<void*>(static_cast<BaseProcessor*>(s)));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}
Y
Yancey 已提交
236

237 238
// Sends the FETCH_BARRIER_MESSAGE to endpoint `ep` via GetVariable and
// returns a handle for the request. The completion is processed by
// Proceed().
VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
                                               int64_t time_out) {
  const auto ch = GetChannel(ep);
  // Proceed() deletes the processor once its completion event is handled.
  FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
  const std::string method = "FetchBarrierRPC";
  VarHandlePtr h(
      new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(FETCH_BARRIER_MESSAGE);

  platform::RecordEvent record_event(method, nullptr);

  // Count the request before issuing it. Proceed() decrements req_count_
  // under sync_mutex_ and may run as soon as Finish() is registered.
  {
    std::lock_guard<std::mutex> lk(sync_mutex_);
    req_count_++;
  }

  auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
  // Tag with a BaseProcessor*, because Proceed() static_casts the void* tag
  // back to BaseProcessor*.
  rpc->Finish(&s->reply_, &s->status_,
              static_cast<void*>(static_cast<BaseProcessor*>(s)));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

262 263
// Sends the COMPLETE_MESSAGE to endpoint `ep` via SendVariable, reusing the
// batch-barrier processor, and returns a handle for the request. The
// completion is processed by Proceed().
VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
                                           int64_t time_out) {
  const auto ch = GetChannel(ep);

  // Proceed() deletes the processor once its completion event is handled.
  BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
  const std::string method = "SendCompleteRPC";
  VarHandlePtr h(new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(COMPLETE_MESSAGE);

  platform::RecordEvent record_event(method, nullptr);

  // Count the request before issuing it. Proceed() decrements req_count_
  // under sync_mutex_ and may run as soon as Finish() is registered.
  {
    std::lock_guard<std::mutex> lk(sync_mutex_);
    req_count_++;
  }

  auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
  // Tag with a BaseProcessor*, because Proceed() static_casts the void* tag
  // back to BaseProcessor*.
  rpc->Finish(&s->reply_, &s->status_,
              static_cast<void*>(static_cast<BaseProcessor*>(s)));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

287 288 289
// Sends a CHECKPOINT_SAVE_MESSAGE to endpoint `ep`, with `dir` placed in the
// message's out_varname field, via the CheckpointNotify RPC. Returns a handle
// for the request; the completion is processed by Proceed().
VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
                                               const std::string& dir,
                                               int64_t time_out) {
  const auto ch = GetChannel(ep);

  // Proceed() deletes the processor once its completion event is handled.
  CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);

  const std::string method = "CheckPointNotifyRPC";

  VarHandlePtr h(
      new VarHandle(ep, method, CHECKPOINT_SAVE_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(CHECKPOINT_SAVE_MESSAGE);
  req.set_out_varname(dir);

  platform::RecordEvent record_event(method, nullptr);

  // Count the request before issuing it. Proceed() decrements req_count_
  // under sync_mutex_ and may run as soon as Finish() is registered.
  {
    std::lock_guard<std::mutex> lk(sync_mutex_);
    req_count_++;
  }

  auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_);
  // Tag with a BaseProcessor*, because Proceed() static_casts the void* tag
  // back to BaseProcessor*.
  rpc->Finish(&s->reply_, &s->status_,
              static_cast<void*>(static_cast<BaseProcessor*>(s)));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

Y
Yancey1989 已提交
317
bool GRPCClient::Wait() {
W
Wu Yi 已提交
318
  std::unique_lock<std::mutex> lk(sync_mutex_);
Y
Yancey1989 已提交
319 320
  sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); });
  return ok_;
G
gongweibao 已提交
321 322
}

G
gongweibao 已提交
323
void GRPCClient::Proceed() {
W
Wu Yi 已提交
324
  void* tag = nullptr;
G
gongweibao 已提交
325 326
  bool ok = false;

327
  VLOG(3) << "GRPCClient Proceed begin";
M
minqiyang 已提交
328
  while (!stopped_ && cq_.Next(&tag, &ok)) {
W
Wu Yi 已提交
329 330 331
    BaseProcessor* c = static_cast<BaseProcessor*>(tag);
    GPR_ASSERT(ok);
    PADDLE_ENFORCE(c);
G
gongweibao 已提交
332

W
Wu Yi 已提交
333
    if (c->status_.ok()) {
334
      VLOG(3) << c->GetVarHandlePtr()->String() << " process";
W
Wu Yi 已提交
335
      c->Process();
Y
Yancey1989 已提交
336
    } else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) {
G
gongweibao 已提交
337
      // FIXME(gongwb): parse error_details?
338
      LOG(ERROR) << c->GetVarHandlePtr()->String()
G
gongweibao 已提交
339 340 341
                 << " meets grpc error, error_code:" << c->status_.error_code()
                 << " error_message:" << c->status_.error_message()
                 << " error_details:" << c->status_.error_details();
Y
Yancey1989 已提交
342 343 344 345
      {
        std::lock_guard<std::mutex> lk(sync_mutex_);
        ok_ = false;
      }
346
      c->Finish(false);
W
Wu Yi 已提交
347
    } else {
348
      LOG(FATAL) << c->GetVarHandlePtr()->String()
G
gongweibao 已提交
349 350 351 352
                 << " meets grpc error, error_code:" << c->status_.error_code()
                 << " error_message:" << c->status_.error_message()
                 << " error_details:" << c->status_.error_details();

353
      c->Finish(false);
W
Wu Yi 已提交
354
    }
355

G
gongweibao 已提交
356
    bool notify = false;
W
Wu Yi 已提交
357 358 359
    {
      std::lock_guard<std::mutex> lk(sync_mutex_);
      req_count_--;
G
gongweibao 已提交
360 361 362 363 364 365 366
      notify = (req_count_ <= 0 || !c->status_.ok());
    }

    delete c;

    if (notify) {
      sync_cond_.notify_all();
W
Wu Yi 已提交
367
    }
G
gongweibao 已提交
368
  }
369
  VLOG(3) << "GRPCClient Proceed end";
G
gongweibao 已提交
370
}
W
Wu Yi 已提交
371

G
gongweibao 已提交
372
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
W
Wu Yi 已提交
373
  std::lock_guard<std::mutex> guard(chan_mutex_);
Y
Yancey1989 已提交
374
  auto it = channels_.find(ep);
G
gongweibao 已提交
375 376 377 378
  if (it != channels_.end()) {
    return it->second;
  }

W
Wu Yi 已提交
379
  // Channel configurations:
G
gongweibao 已提交
380
  grpc::ChannelArguments args;
W
Wu Yi 已提交
381
  args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 2000);
382
  args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
G
gongweibao 已提交
383 384 385
  args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());

T
typhoonzero 已提交
386 387
  auto ch =
      grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
Y
Yancey1989 已提交
388
  channels_[ep] = ch;
G
gongweibao 已提交
389 390 391
  return ch;
}

392
}  // namespace distributed
G
gongweibao 已提交
393 394
}  // namespace operators
}  // namespace paddle