/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <stdlib.h>

#include <limits>
#include <string>
#include <vector>

#include "glog/logging.h"  // For VLOG
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_client.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_serde.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/platform/profiler.h"

26 27
DECLARE_bool(rpc_disable_reuse_port);

G
gongweibao 已提交
28 29
namespace paddle {
namespace operators {
30
namespace distributed {
G
gongweibao 已提交
31

32
void GRPCClient::InitImpl() {
W
Wu Yi 已提交
33 34
  // start the client process thread
  // TODO(wuyi): can make this in a threadpool
35 36
  PADDLE_ENFORCE(client_thread_ == nullptr,
                 "please not re init proceed thread");
G
gongweibao 已提交
37
  client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
W
Wu Yi 已提交
38 39
}

Y
Yancey1989 已提交
40
void GRPCClient::SendComplete() {
Y
Yancey1989 已提交
41 42 43
  std::unique_lock<std::mutex> lk(completed_mutex_);
  if (!completed_) {
    for (auto& it : channels_) {
M
minqiyang 已提交
44
      VLOG(3) << "send complete message to " << it.first;
Y
Yancey1989 已提交
45 46 47 48
      this->AsyncSendComplete(it.first);
    }
    PADDLE_ENFORCE(this->Wait(), "internal grpc error");
    completed_ = true;
W
Wu Yi 已提交
49 50 51
  }
}

G
gongweibao 已提交
52
GRPCClient::~GRPCClient() {
M
minqiyang 已提交
53
  stopped_ = true;
W
Wu Yi 已提交
54 55 56 57 58 59 60
  Wait();
  cq_.Shutdown();
  {
    std::lock_guard<std::mutex> guard(chan_mutex_);
    for (auto& it : channels_) {
      it.second.reset();
    }
M
minqiyang 已提交
61
    channels_.clear();
W
Wu Yi 已提交
62 63
  }
  client_thread_->join();
Y
Yancey1989 已提交
64 65
}

66 67 68 69 70
VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
                                      const platform::DeviceContext& ctx,
                                      const framework::Scope& scope,
                                      const std::string& var_name,
                                      int64_t time_out) {
71 72 73 74
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string var_name_val = var_name;
  const framework::Scope* p_scope = &scope;
Y
Yancey1989 已提交
75
  const auto ch = GetChannel(ep_val);
76
  SendProcessor* s = new SendProcessor(ch);
77
  const std::string method = kSendRPC;
G
gongweibao 已提交
78
  VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
79
  s->Prepare(h, time_out);
80

G
gongweibao 已提交
81
  framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
82
    auto* var = p_scope->FindVar(var_name_val);
83 84

    ::grpc::ByteBuffer req;
W
Wu Yi 已提交
85
    SerializeToByteBuffer(var_name_val, var, *p_ctx, &req, "", trainer_id_);
86

M
minqiyang 已提交
87
    VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
88 89

    // stub context
T
typhoonzero 已提交
90
    s->response_call_back_ = nullptr;
91

92
    platform::RecordRPCEvent record_event(method);
G
gongweibao 已提交
93

94 95
    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
96
    call->StartCall();
Y
Yi Wang 已提交
97
    call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
G
gongweibao 已提交
98 99 100 101

    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
102
  });
G
gongweibao 已提交
103 104
  req_count_++;

105
  return h;
G
gongweibao 已提交
106 107 108
}

void ProcGetResponse(const VarHandle& var_h,
109
                     const ::grpc::ByteBuffer& ret_msg) {
110
  VLOG(4) << "ProcGetResponse";
T
typhoonzero 已提交
111
  framework::Variable* outvar = nullptr;
W
Wu Yi 已提交
112 113 114 115
  // get response's trainer_id is not used
  int trainer_id;
  DeserializeFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar,
                            &trainer_id);
116 117 118 119 120
}

template <typename T>
void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
  ::grpc::Slice slice(proto.ByteSizeLong());
Q
qiaolongfei 已提交
121
  proto.SerializeWithCachedSizesToArray(const_cast<uint8_t*>(slice.begin()));
122 123
  ::grpc::ByteBuffer tmp(&slice, 1);
  result->Swap(&tmp);
G
gongweibao 已提交
124 125
}

126 127 128 129
VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
                                     const platform::DeviceContext& ctx,
                                     const framework::Scope& scope,
                                     const std::string& var_name,
130
                                     const std::string& out_varname,
Q
Qiao Longfei 已提交
131
                                     const std::string& table_name,
132
                                     int64_t time_out) {
133
  return _AsyncGetVar(ep, ctx, scope, kGetRPC, var_name, out_varname,
Q
Qiao Longfei 已提交
134 135
                      "/sendrecv.SendRecvService/GetVariable", table_name,
                      time_out);
136 137
}

138 139 140 141 142 143 144 145 146
VarHandlePtr GRPCClient::AsyncGetVarNoBarrier(
    const std::string& ep, const platform::DeviceContext& ctx,
    const framework::Scope& scope, const std::string& var_name,
    const std::string& out_varname, int64_t time_out) {
  std::string var_name_no_barrier =
      string::Sprintf("%s%s", var_name, WITHOUT_BARRIER_MESSAGE);

  return _AsyncGetVar(
      ep, ctx, scope, kGetNoBarrierRPC, var_name_no_barrier, out_varname,
Q
Qiao Longfei 已提交
147
      "/sendrecv.SendRecvService/GetVariableNoBarrier", "", time_out);
148 149
}

150 151 152 153
VarHandlePtr GRPCClient::AsyncGetMonomerVariable(
    const std::string& ep, const platform::DeviceContext& ctx,
    const framework::Scope& scope, const std::string& var_name,
    int64_t time_out) {
154
  return _AsyncGetVar(ep, ctx, scope, kGetMonomerRPC, var_name, var_name,
Q
Qiao Longfei 已提交
155 156
                      "/sendrecv.SendRecvService/GetMonomerVariable", "",
                      time_out);
157 158
}

159 160 161 162
VarHandlePtr GRPCClient::_AsyncGetVar(
    const std::string& ep, const platform::DeviceContext& ctx,
    const framework::Scope& scope, const std::string& method,
    const std::string& var_name, const std::string& out_varname,
Q
Qiao Longfei 已提交
163 164
    const std::string& rpc_path, const std::string& table_name,
    int64_t time_out) {
165 166 167
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string var_name_val = var_name;
168
  const std::string out_varname_val = out_varname;
Q
Qiao Longfei 已提交
169
  const std::string table_name_val = table_name;
170
  const framework::Scope* p_scope = &scope;
Y
Yancey1989 已提交
171
  const auto ch = GetChannel(ep_val);
172
  GetProcessor* s = new GetProcessor(ch);
173 174

  VarHandlePtr h(new VarHandle(ep, method, out_varname_val, p_ctx, p_scope));
175
  s->Prepare(h, time_out);
176

Q
Qiao Longfei 已提交
177 178 179 180 181 182 183 184 185 186
  framework::AsyncIO([var_name_val, out_varname_val, table_name_val, s, method,
                      p_ctx, h, rpc_path, this] {
    // prepare input
    sendrecv::VariableMessage req;
    req.set_varname(var_name_val);
    req.set_out_varname(out_varname_val);
    req.set_trainer_id(trainer_id_);
    req.set_table_name(table_name_val);
    ::grpc::ByteBuffer buf;
    RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
187

Q
Qiao Longfei 已提交
188
    VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
189

Q
Qiao Longfei 已提交
190 191
    // stub context
    s->response_call_back_ = ProcGetResponse;
192

Q
Qiao Longfei 已提交
193
    platform::RecordRPCEvent record_event(method);
G
gongweibao 已提交
194

Q
Qiao Longfei 已提交
195 196 197 198
    auto call =
        s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_);
    call->StartCall();
    call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
G
gongweibao 已提交
199

Q
Qiao Longfei 已提交
200 201 202 203
    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
  });
G
gongweibao 已提交
204 205 206

  req_count_++;

207
  return h;
G
gongweibao 已提交
208 209
}

210 211 212 213 214
VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
                                          const platform::DeviceContext& ctx,
                                          const framework::Scope& scope,
                                          const std::string& in_var_name,
                                          const std::string& out_var_name,
Q
Qiao Longfei 已提交
215
                                          const std::string& table_name,
216
                                          int64_t time_out) {
Q
Qiao Longfei 已提交
217 218 219 220
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string in_var_name_val = in_var_name;
  const std::string out_var_name_val = out_var_name;
Q
Qiao Longfei 已提交
221
  const std::string table_name_val = table_name;
Q
Qiao Longfei 已提交
222
  const framework::Scope* p_scope = &scope;
Y
Yancey1989 已提交
223
  const auto ch = GetChannel(ep_val);
224
  GetProcessor* s = new GetProcessor(ch);
G
gongweibao 已提交
225

226
  const std::string method = kPrefetchRPC;
G
gongweibao 已提交
227 228

  VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
229
  s->Prepare(h, time_out);
Q
Qiao Longfei 已提交
230

T
wip  
typhoonzero 已提交
231
  framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
Q
Qiao Longfei 已提交
232
                      s, method, h, table_name_val, this] {
Q
Qiao Longfei 已提交
233 234 235
    auto* var = p_scope->FindVar(in_var_name_val);

    ::grpc::ByteBuffer req;
Q
Qiao Longfei 已提交
236
    SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val,
Q
Qiao Longfei 已提交
237
                          0, table_name_val);
Q
Qiao Longfei 已提交
238

M
minqiyang 已提交
239
    VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
Q
Qiao Longfei 已提交
240 241 242 243

    // stub context
    s->response_call_back_ = ProcGetResponse;

244
    platform::RecordRPCEvent record_event(method);
G
gongweibao 已提交
245

Q
Qiao Longfei 已提交
246
    auto call = s->stub_g_.PrepareUnaryCall(
247 248
        s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
        &cq_);
Q
Qiao Longfei 已提交
249
    call->StartCall();
250
    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
G
gongweibao 已提交
251 252 253 254

    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
Q
Qiao Longfei 已提交
255 256 257
  });

  req_count_++;
258
  return h;
Q
Qiao Longfei 已提交
259 260
}

261 262
VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
                                               int64_t time_out) {
Y
Yancey1989 已提交
263
  const auto ch = GetChannel(ep);
Y
Yancey 已提交
264 265

  BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
266
  const std::string method = kBatchBarrierRPC;
G
gongweibao 已提交
267 268
  VarHandlePtr h(
      new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr));
269
  s->Prepare(h, time_out);
Y
Yancey 已提交
270 271 272

  sendrecv::VariableMessage req;
  req.set_varname(BATCH_BARRIER_MESSAGE);
G
gongweibao 已提交
273

274
  platform::RecordRPCEvent record_event(method);
G
gongweibao 已提交
275

Y
Yancey 已提交
276
  auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
Y
Yi Wang 已提交
277
  rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
Y
Yancey 已提交
278
  req_count_++;
G
gongweibao 已提交
279 280 281 282 283

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

284
  return h;
285
}
Y
Yancey 已提交
286

287 288
VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
                                               int64_t time_out) {
Y
Yancey1989 已提交
289
  const auto ch = GetChannel(ep);
290
  FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
291
  const std::string method = kFetchBarrierRPC;
G
gongweibao 已提交
292 293
  VarHandlePtr h(
      new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
294
  s->Prepare(h, time_out);
295 296 297

  sendrecv::VariableMessage req;
  req.set_varname(FETCH_BARRIER_MESSAGE);
G
gongweibao 已提交
298

299
  platform::RecordRPCEvent record_event(method);
G
gongweibao 已提交
300

301
  auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
Y
Yi Wang 已提交
302
  rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
303
  req_count_++;
G
gongweibao 已提交
304 305 306 307 308

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

309
  return h;
Y
Yancey 已提交
310 311
}

312 313 314 315 316
VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
                                                const std::string& var_name,
                                                int64_t time_out) {
  const auto ch = GetChannel(ep);
  BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
317
  const std::string method = kSendMonomerFetchBarrierRPC;
318
  VarHandlePtr h(new VarHandle(ep, method, var_name, nullptr, nullptr));
319 320 321 322 323 324 325
  s->Prepare(h, time_out);

  VLOG(30) << s->GetVarHandlePtr()->String() << " begin";

  sendrecv::VariableMessage req;
  req.set_varname(var_name);

326
  platform::RecordRPCEvent record_event(method);
327 328 329 330 331 332 333 334 335 336 337 338

  auto rpc = s->stub_->AsyncGetMonomerBarrier(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
  req_count_++;

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

339 340
VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
                                           int64_t time_out) {
W
Wu Yi 已提交
341 342 343
  const auto ch = GetChannel(ep);

  BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
344
  const std::string method = kSendCompleteRPC;
G
gongweibao 已提交
345
  VarHandlePtr h(new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr));
346
  s->Prepare(h, time_out);
W
Wu Yi 已提交
347 348

  sendrecv::VariableMessage req;
Y
Yancey1989 已提交
349
  req.set_varname(COMPLETE_MESSAGE);
G
gongweibao 已提交
350

351
  platform::RecordRPCEvent record_event(method);
G
gongweibao 已提交
352

W
Wu Yi 已提交
353 354
  auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
Y
Yancey1989 已提交
355
  req_count_++;
G
gongweibao 已提交
356 357 358 359 360

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

361
  return h;
Y
Yancey1989 已提交
362 363
}

364 365 366
VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
                                               const std::string& dir,
                                               int64_t time_out) {
T
tangwei12 已提交
367
  const auto ch = GetChannel(ep);
368

T
tangwei12 已提交
369
  CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);
G
gongweibao 已提交
370

371
  const std::string method = kCheckPointNotifyRPC;
G
gongweibao 已提交
372 373 374

  VarHandlePtr h(
      new VarHandle(ep, method, CHECKPOINT_SAVE_MESSAGE, nullptr, nullptr));
375
  s->Prepare(h, time_out);
T
tangwei12 已提交
376

377 378
  sendrecv::VariableMessage req;
  req.set_varname(CHECKPOINT_SAVE_MESSAGE);
379
  req.set_out_varname(dir);
T
tangwei12 已提交
380

381
  platform::RecordRPCEvent record_event(method);
G
gongweibao 已提交
382

T
bug fix  
tangwei12 已提交
383
  auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_);
T
tangwei12 已提交
384 385
  rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
  req_count_++;
G
gongweibao 已提交
386 387 388 389 390

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

391
  return h;
T
tangwei12 已提交
392 393
}

Y
Yancey1989 已提交
394
bool GRPCClient::Wait() {
W
Wu Yi 已提交
395
  std::unique_lock<std::mutex> lk(sync_mutex_);
Y
Yancey1989 已提交
396 397
  sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); });
  return ok_;
G
gongweibao 已提交
398 399
}

G
gongweibao 已提交
400
void GRPCClient::Proceed() {
W
Wu Yi 已提交
401
  void* tag = nullptr;
G
gongweibao 已提交
402 403
  bool ok = false;

M
minqiyang 已提交
404
  VLOG(3) << "GRPCClient Proceed begin";
M
minqiyang 已提交
405
  while (!stopped_ && cq_.Next(&tag, &ok)) {
W
Wu Yi 已提交
406 407 408
    BaseProcessor* c = static_cast<BaseProcessor*>(tag);
    GPR_ASSERT(ok);
    PADDLE_ENFORCE(c);
G
gongweibao 已提交
409

W
Wu Yi 已提交
410
    if (c->status_.ok()) {
M
minqiyang 已提交
411
      VLOG(3) << c->GetVarHandlePtr()->String() << " process";
W
Wu Yi 已提交
412
      c->Process();
Y
Yancey1989 已提交
413
    } else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) {
414
      LOG(FATAL) << c->GetVarHandlePtr()->String()
G
gongweibao 已提交
415 416 417
                 << " meets grpc error, error_code:" << c->status_.error_code()
                 << " error_message:" << c->status_.error_message()
                 << " error_details:" << c->status_.error_details();
Y
Yancey1989 已提交
418 419 420 421
      {
        std::lock_guard<std::mutex> lk(sync_mutex_);
        ok_ = false;
      }
422
      c->Finish(false);
W
Wu Yi 已提交
423
    } else {
424
      LOG(FATAL) << c->GetVarHandlePtr()->String()
G
gongweibao 已提交
425 426 427 428
                 << " meets grpc error, error_code:" << c->status_.error_code()
                 << " error_message:" << c->status_.error_message()
                 << " error_details:" << c->status_.error_details();

429
      c->Finish(false);
W
Wu Yi 已提交
430
    }
431

G
gongweibao 已提交
432
    bool notify = false;
W
Wu Yi 已提交
433 434 435
    {
      std::lock_guard<std::mutex> lk(sync_mutex_);
      req_count_--;
G
gongweibao 已提交
436 437 438 439 440 441 442
      notify = (req_count_ <= 0 || !c->status_.ok());
    }

    delete c;

    if (notify) {
      sync_cond_.notify_all();
W
Wu Yi 已提交
443
    }
G
gongweibao 已提交
444
  }
445 446 447 448 449 450 451

  // Last log message
  // Avoid using VLOG() and LOG(): in the destructor of google::LogMessage() a
  // static Mutex log_mutex is used for synchronization, which might have been
  // destructed at this moment.
  if (FLAGS_v >= 3) {
    std::string msg("GRPCClient Proceed end");
452
    fwrite(msg.c_str(), msg.length(), 1, stderr);
453
  }
G
gongweibao 已提交
454
}
W
Wu Yi 已提交
455

G
gongweibao 已提交
456
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
W
Wu Yi 已提交
457
  std::lock_guard<std::mutex> guard(chan_mutex_);
Y
Yancey1989 已提交
458
  auto it = channels_.find(ep);
G
gongweibao 已提交
459 460 461 462
  if (it != channels_.end()) {
    return it->second;
  }

W
Wu Yi 已提交
463
  // Channel configurations:
G
gongweibao 已提交
464
  grpc::ChannelArguments args;
W
Wu Yi 已提交
465
  args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 2000);
466 467 468
  if (FLAGS_rpc_disable_reuse_port) {
    args.SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0);
  }
469
  args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
G
gongweibao 已提交
470 471 472
  args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());

T
typhoonzero 已提交
473 474
  auto ch =
      grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
Y
Yancey1989 已提交
475
  channels_[ep] = ch;
G
gongweibao 已提交
476 477 478
  return ch;
}

}  // namespace distributed
}  // namespace operators
}  // namespace paddle