grpc_client.cc 21.6 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include <stdlib.h>
Y
Yi Wang 已提交
16 17
#include <limits>

G
gongweibao 已提交
18
#include "glog/logging.h"  // For VLOG
Y
Yi Wang 已提交
19
#include "paddle/fluid/framework/threadpool.h"
W
Wu Yi 已提交
20 21
#include "paddle/fluid/operators/distributed/grpc/grpc_client.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_serde.h"
22
#include "paddle/fluid/operators/distributed/request_handler.h"
P
peizhilin 已提交
23
#include "paddle/fluid/platform/port.h"
X
Xin Pan 已提交
24
#include "paddle/fluid/platform/profiler.h"
25

26 27
DECLARE_bool(rpc_disable_reuse_port);

G
gongweibao 已提交
28 29
namespace paddle {
namespace operators {
30
namespace distributed {
G
gongweibao 已提交
31

32
void GRPCClient::InitImpl() {
W
Wu Yi 已提交
33 34
  // start the client process thread
  // TODO(wuyi): can make this in a threadpool
35 36
  PADDLE_ENFORCE(client_thread_ == nullptr,
                 "please not re init proceed thread");
G
gongweibao 已提交
37
  client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
W
Wu Yi 已提交
38 39
}

Y
Yancey1989 已提交
40
void GRPCClient::SendComplete() {
Y
Yancey1989 已提交
41 42 43
  std::unique_lock<std::mutex> lk(completed_mutex_);
  if (!completed_) {
    for (auto& it : channels_) {
M
minqiyang 已提交
44
      VLOG(3) << "send complete message to " << it.first;
Y
Yancey1989 已提交
45 46 47 48
      this->AsyncSendComplete(it.first);
    }
    PADDLE_ENFORCE(this->Wait(), "internal grpc error");
    completed_ = true;
W
Wu Yi 已提交
49 50 51
  }
}

G
gongweibao 已提交
52
// Tears the client down: flags the Proceed() loop to stop, waits for pending
// requests, shuts the completion queue down, drops every cached channel and
// finally joins the worker thread.
GRPCClient::~GRPCClient() {
  stopped_ = true;
  Wait();
  cq_.Shutdown();
  {
    std::lock_guard<std::mutex> guard(chan_mutex_);
    for (auto& it : channels_) {
      it.second.reset();
    }
    channels_.clear();
  }
  // client_thread_ is only created by InitImpl(); if that was never called the
  // pointer is still null and must not be dereferenced.
  if (client_thread_ != nullptr && client_thread_->joinable()) {
    client_thread_->join();
  }
}

66 67 68 69 70
// Asynchronously sends variable `var_name` looked up in `scope` to `ep`.
//
// Serialization and the SendVariable unary call are performed inside a
// framework::AsyncIO task; the call is bound to cq_. When
// FLAGS_rpc_retry_times > 0 this function blocks on the handle and, while the
// handle is flagged should_retry, re-issues the request with a fresh processor
// and handle (at most FLAGS_rpc_retry_times retries, with a 0-4 ms sleep
// between attempts). Once the retry budget is used up the last request is
// returned without being waited on.
//
// Returns the handle of the most recently issued request.
VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
                                      const platform::DeviceContext& ctx,
                                      const framework::Scope& scope,
                                      const std::string& var_name,
                                      int64_t time_out) {
  // Values captured by copy into the async task below.
  const std::string send_name = var_name;
  const std::string rpc_name = kSendRPC;
  const platform::DeviceContext* dev_ctx = &ctx;
  const framework::Scope* src_scope = &scope;
  const auto channel = GetChannel(ep);

  int attempts = 0;
  for (;;) {
    SendProcessor* processor = new SendProcessor(channel);
    VarHandlePtr handle(
        new VarHandle(ep, rpc_name, send_name, dev_ctx, src_scope));
    processor->Prepare(handle, time_out);

    framework::AsyncIO(
        [this, processor, handle, dev_ctx, src_scope, send_name, rpc_name] {
          auto* variable = src_scope->FindVar(send_name);

          ::grpc::ByteBuffer request;
          SerializeToByteBuffer(send_name, variable, *dev_ctx, &request, "",
                                trainer_id_);

          VLOG(3) << processor->GetVarHandlePtr()->String() << " begin";

          // No response callback is installed for a send.
          processor->response_call_back_ = nullptr;

          platform::RecordRPCEvent record_event(rpc_name);

          auto call = processor->stub_g_.PrepareUnaryCall(
              processor->context_.get(),
              "/sendrecv.SendRecvService/SendVariable", request, &cq_);
          call->StartCall();
          // The processor pointer is the completion-queue tag.
          call->Finish(&processor->reply_, &processor->status_,
                       static_cast<void*>(processor));

          if (UNLIKELY(platform::IsProfileEnabled())) {
            handle->Wait();
          }
        });
    req_count_++;

    const bool retry_allowed =
        FLAGS_rpc_retry_times > 0 && attempts < FLAGS_rpc_retry_times;
    if (!retry_allowed) {
      return handle;
    }

    handle->Wait();
    if (!handle->should_retry) {
      return handle;
    }

    VLOG(3) << "rpc call failed, retry times " << attempts;
    ++attempts;
    std::random_device jitter;
    std::this_thread::sleep_for(std::chrono::milliseconds(jitter() % 5));
  }
}

void ProcGetResponse(const VarHandle& var_h,
126
                     const ::grpc::ByteBuffer& ret_msg) {
127
  VLOG(4) << "ProcGetResponse";
T
typhoonzero 已提交
128
  framework::Variable* outvar = nullptr;
W
Wu Yi 已提交
129 130 131 132
  // get response's trainer_id is not used
  int trainer_id;
  DeserializeFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar,
                            &trainer_id);
133 134
}

135 136 137 138 139 140 141 142 143
// Response callback for send-and-recv RPCs: deserializes `ret_msg` with
// DeserializeRecvFromByteBuffer using the context and scope held by `var_h`.
void ProcGetRecvResponse(const VarHandle& var_h,
                         const ::grpc::ByteBuffer& ret_msg) {
  VLOG(4) << "ProcGetRecvResponse";
  // The trainer id written by the deserializer is not used afterwards.
  int ignored_trainer_id;
  framework::Variable* out_var = nullptr;
  DeserializeRecvFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &out_var,
                                &ignored_trainer_id);
}

144 145 146
template <typename T>
void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
  ::grpc::Slice slice(proto.ByteSizeLong());
Q
qiaolongfei 已提交
147
  proto.SerializeWithCachedSizesToArray(const_cast<uint8_t*>(slice.begin()));
148 149
  ::grpc::ByteBuffer tmp(&slice, 1);
  result->Swap(&tmp);
G
gongweibao 已提交
150 151
}

152 153 154 155
// Asynchronously fetches `var_name` from `ep` through the GetVariable RPC.
// Thin wrapper that forwards to _AsyncGetVar with method kGetRPC.
VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
                                     const platform::DeviceContext& ctx,
                                     const framework::Scope& scope,
                                     const std::string& var_name,
                                     const std::string& out_varname,
                                     const std::string& table_name,
                                     int64_t time_out) {
  const std::string rpc_path = "/sendrecv.SendRecvService/GetVariable";
  return _AsyncGetVar(ep, ctx, scope, kGetRPC, var_name, out_varname, rpc_path,
                      table_name, time_out);
}

164 165 166 167 168 169 170 171 172
// Asynchronously fetches `var_name` from `ep` through the
// GetVariableNoBarrier RPC. The requested name is `var_name` with
// WITHOUT_BARRIER_MESSAGE appended; the table name is left empty.
VarHandlePtr GRPCClient::AsyncGetVarNoBarrier(
    const std::string& ep, const platform::DeviceContext& ctx,
    const framework::Scope& scope, const std::string& var_name,
    const std::string& out_varname, int64_t time_out) {
  const std::string rpc_path = "/sendrecv.SendRecvService/GetVariableNoBarrier";
  const std::string tagged_name =
      string::Sprintf("%s%s", var_name, WITHOUT_BARRIER_MESSAGE);

  return _AsyncGetVar(ep, ctx, scope, kGetNoBarrierRPC, tagged_name,
                      out_varname, rpc_path, "", time_out);
}

176 177 178 179
// Asynchronously fetches `var_name` from `ep` through the GetMonomerVariable
// RPC. `var_name` is used both as the requested and the output variable name;
// the table name is left empty.
VarHandlePtr GRPCClient::AsyncGetMonomerVariable(
    const std::string& ep, const platform::DeviceContext& ctx,
    const framework::Scope& scope, const std::string& var_name,
    int64_t time_out) {
  const std::string rpc_path = "/sendrecv.SendRecvService/GetMonomerVariable";
  return _AsyncGetVar(ep, ctx, scope, kGetMonomerRPC, var_name, var_name,
                      rpc_path, "", time_out);
}

185 186 187 188
// Shared implementation of the get-style RPCs.
//
// Builds a sendrecv::VariableMessage carrying `var_name`, `out_varname`,
// trainer_id_ and `table_name`, and issues it as a unary call on `rpc_path`
// from inside a framework::AsyncIO task. ProcGetResponse is installed as the
// response callback. `method` labels the VarHandle and the RPC profiling
// event.
//
// When FLAGS_rpc_retry_times > 0 this function blocks on the handle and
// re-issues the request while the handle is flagged should_retry (at most
// FLAGS_rpc_retry_times retries, with a 0-4 ms sleep in between).
//
// Returns the handle of the most recently issued request.
VarHandlePtr GRPCClient::_AsyncGetVar(
    const std::string& ep, const platform::DeviceContext& ctx,
    const framework::Scope& scope, const std::string& method,
    const std::string& var_name, const std::string& out_varname,
    const std::string& rpc_path, const std::string& table_name,
    int64_t time_out) {
  // Values captured by copy into the async task below.
  const std::string request_name = var_name;
  const std::string result_name = out_varname;
  const std::string table = table_name;
  const platform::DeviceContext* dev_ctx = &ctx;
  const framework::Scope* dst_scope = &scope;
  const auto channel = GetChannel(ep);

  int attempts = 0;
  for (;;) {
    GetProcessor* processor = new GetProcessor(channel);

    VarHandlePtr handle(
        new VarHandle(ep, method, result_name, dev_ctx, dst_scope));
    processor->Prepare(handle, time_out);

    framework::AsyncIO([this, processor, handle, request_name, result_name,
                        table, method, rpc_path] {
      // prepare input
      sendrecv::VariableMessage message;
      message.set_varname(request_name);
      message.set_out_varname(result_name);
      message.set_trainer_id(trainer_id_);
      message.set_table_name(table);
      ::grpc::ByteBuffer request;
      RequestToByteBuffer<sendrecv::VariableMessage>(message, &request);

      VLOG(3) << processor->GetVarHandlePtr()->String() << " begin";

      // stub context
      processor->response_call_back_ = ProcGetResponse;

      platform::RecordRPCEvent record_event(method);

      auto call = processor->stub_g_.PrepareUnaryCall(
          processor->context_.get(), rpc_path, request, &cq_);
      call->StartCall();
      // The processor pointer is the completion-queue tag.
      call->Finish(&processor->reply_, &processor->status_,
                   static_cast<void*>(processor));

      if (UNLIKELY(platform::IsProfileEnabled())) {
        handle->Wait();
      }
    });
    req_count_++;

    const bool retry_allowed =
        FLAGS_rpc_retry_times > 0 && attempts < FLAGS_rpc_retry_times;
    if (!retry_allowed) {
      return handle;
    }

    handle->Wait();
    if (!handle->should_retry) {
      return handle;
    }

    VLOG(3) << "rpc call failed, retry times " << attempts;
    ++attempts;
    std::random_device jitter;
    std::this_thread::sleep_for(std::chrono::milliseconds(jitter() % 5));
  }
}

251 252 253 254 255
// Asynchronously issues a PrefetchVariable call to `ep`: serializes
// `in_var_name` from `scope` together with `out_var_name` and `table_name`,
// and installs ProcGetResponse as the response callback.
//
// NOTE(review): `time_out` is not consulted here; the processor is prepared
// with kPrefetchTimeout instead. Likewise the literal 0 is passed to
// SerializeToByteBuffer in the position where the other senders in this file
// pass trainer_id_. Both are kept as-is; confirm they are intentional.
//
// When FLAGS_rpc_retry_times > 0 this function blocks on the handle and
// re-issues the request while the handle is flagged should_retry (at most
// FLAGS_rpc_retry_times retries, with a 0-4 ms sleep in between).
//
// Returns the handle of the most recently issued request.
VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
                                          const platform::DeviceContext& ctx,
                                          const framework::Scope& scope,
                                          const std::string& in_var_name,
                                          const std::string& out_var_name,
                                          const std::string& table_name,
                                          int64_t time_out) {
  // Values captured by copy into the async task below.
  const std::string input_name = in_var_name;
  const std::string output_name = out_var_name;
  const std::string table = table_name;
  const std::string rpc_name = kPrefetchRPC;
  const platform::DeviceContext* dev_ctx = &ctx;
  const framework::Scope* var_scope = &scope;
  const auto channel = GetChannel(ep);

  int attempts = 0;
  for (;;) {
    GetProcessor* processor = new GetProcessor(channel);
    VarHandlePtr handle(
        new VarHandle(ep, rpc_name, output_name, dev_ctx, var_scope));
    processor->Prepare(handle, kPrefetchTimeout);

    framework::AsyncIO([this, processor, handle, dev_ctx, var_scope,
                        input_name, output_name, table, rpc_name] {
      auto* variable = var_scope->FindVar(input_name);

      ::grpc::ByteBuffer request;
      SerializeToByteBuffer(input_name, variable, *dev_ctx, &request,
                            output_name, 0, table);

      VLOG(3) << processor->GetVarHandlePtr()->String() << " begin";

      // stub context
      processor->response_call_back_ = ProcGetResponse;

      platform::RecordRPCEvent record_event(rpc_name);

      auto call = processor->stub_g_.PrepareUnaryCall(
          processor->context_.get(),
          "/sendrecv.SendRecvService/PrefetchVariable", request, &cq_);
      call->StartCall();
      // The processor pointer is the completion-queue tag.
      call->Finish(&processor->reply_, &processor->status_,
                   static_cast<void*>(processor));

      if (UNLIKELY(platform::IsProfileEnabled())) {
        handle->Wait();
      }
    });
    req_count_++;

    const bool retry_allowed =
        FLAGS_rpc_retry_times > 0 && attempts < FLAGS_rpc_retry_times;
    if (!retry_allowed) {
      return handle;
    }

    handle->Wait();
    if (!handle->should_retry) {
      return handle;
    }

    VLOG(3) << "rpc call failed, retry times " << attempts;
    ++attempts;
    std::random_device jitter;
    std::this_thread::sleep_for(std::chrono::milliseconds(jitter() % 5));
  }
}

316 317
// Sends a batch-barrier message (varname BATCH_BARRIER_MESSAGE) to `ep` via
// the AsyncSendVariable stub call. The request is issued on the calling
// thread; when profiling is enabled the call also waits for completion.
// Returns the handle tracking the request.
VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
                                               int64_t time_out) {
  const auto channel = GetChannel(ep);

  BatchBarrierProcessor* processor = new BatchBarrierProcessor(channel);
  const std::string rpc_name = kBatchBarrierRPC;
  VarHandlePtr handle(
      new VarHandle(ep, rpc_name, BATCH_BARRIER_MESSAGE, nullptr, nullptr));
  processor->Prepare(handle, time_out);

  sendrecv::VariableMessage message;
  message.set_varname(BATCH_BARRIER_MESSAGE);

  platform::RecordRPCEvent record_event(rpc_name);

  auto call = processor->stub_->AsyncSendVariable(processor->context_.get(),
                                                  message, &cq_);
  // The processor pointer is the completion-queue tag.
  call->Finish(&processor->reply_, &processor->status_,
               static_cast<void*>(processor));
  req_count_++;

  if (UNLIKELY(platform::IsProfileEnabled())) {
    handle->Wait();
  }

  return handle;
}
Y
Yancey 已提交
341

342 343
// Sends a fetch-barrier message (varname FETCH_BARRIER_MESSAGE) to `ep` via
// the AsyncGetVariable stub call. The request is issued on the calling thread;
// when profiling is enabled the call also waits for completion.
// Returns the handle tracking the request.
VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
                                               int64_t time_out) {
  const auto channel = GetChannel(ep);
  FetchBarrierProcessor* processor = new FetchBarrierProcessor(channel);
  const std::string rpc_name = kFetchBarrierRPC;
  VarHandlePtr handle(
      new VarHandle(ep, rpc_name, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
  processor->Prepare(handle, time_out);

  sendrecv::VariableMessage message;
  message.set_varname(FETCH_BARRIER_MESSAGE);

  platform::RecordRPCEvent record_event(rpc_name);

  auto call = processor->stub_->AsyncGetVariable(processor->context_.get(),
                                                 message, &cq_);
  // The processor pointer is the completion-queue tag.
  call->Finish(&processor->reply_, &processor->status_,
               static_cast<void*>(processor));
  req_count_++;

  if (UNLIKELY(platform::IsProfileEnabled())) {
    handle->Wait();
  }

  return handle;
}

367 368 369 370 371
// Sends a monomer barrier request for `var_name` to `ep` via the
// AsyncGetMonomerBarrier stub call. The request is issued on the calling
// thread; when profiling is enabled the call also waits for completion.
// Returns the handle tracking the request.
VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
                                                const std::string& var_name,
                                                int64_t time_out) {
  const auto ch = GetChannel(ep);
  BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
  const std::string method = kSendMonomerFetchBarrierRPC;
  VarHandlePtr h(new VarHandle(ep, method, var_name, nullptr, nullptr));
  s->Prepare(h, time_out);

  // Logged at level 3 like every other "begin" message in this file (the
  // previous level 30 made the message effectively invisible).
  VLOG(3) << s->GetVarHandlePtr()->String() << " begin";

  sendrecv::VariableMessage req;
  req.set_varname(var_name);

  platform::RecordRPCEvent record_event(method);

  auto rpc = s->stub_->AsyncGetMonomerBarrier(s->context_.get(), req, &cq_);
  // The processor pointer is the completion-queue tag.
  rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
  req_count_++;

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

394 395
// Sends the complete message (varname COMPLETE_MESSAGE, tagged with
// trainer_id_) to `ep` via the AsyncSendVariable stub call. The request is
// issued on the calling thread; when profiling is enabled the call also waits
// for completion. Returns the handle tracking the request.
VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
                                           int64_t time_out) {
  const auto channel = GetChannel(ep);

  BatchBarrierProcessor* processor = new BatchBarrierProcessor(channel);
  const std::string rpc_name = kSendCompleteRPC;
  VarHandlePtr handle(
      new VarHandle(ep, rpc_name, COMPLETE_MESSAGE, nullptr, nullptr));
  processor->Prepare(handle, time_out);

  sendrecv::VariableMessage message;
  message.set_trainer_id(trainer_id_);
  message.set_varname(COMPLETE_MESSAGE);

  platform::RecordRPCEvent record_event(rpc_name);

  auto call = processor->stub_->AsyncSendVariable(processor->context_.get(),
                                                  message, &cq_);
  // The processor pointer is the completion-queue tag.
  call->Finish(&processor->reply_, &processor->status_,
               static_cast<void*>(processor));
  req_count_++;

  if (UNLIKELY(platform::IsProfileEnabled())) {
    handle->Wait();
  }

  return handle;
}

420
// Sends a checkpoint notification to `ep` via the AsyncCheckpointNotify stub
// call. `varname` travels in the message's varname field and `dirname` in its
// out_varname field. The request is issued on the calling thread; when
// profiling is enabled the call also waits for completion.
// Returns the handle tracking the request.
VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
                                               const std::string& dirname,
                                               const std::string& varname,
                                               int64_t time_out) {
  const auto channel = GetChannel(ep);

  CheckpointNotifyProcessor* processor = new CheckpointNotifyProcessor(channel);

  const std::string rpc_name = kCheckPointNotifyRPC;

  VarHandlePtr handle(
      new VarHandle(ep, rpc_name, CHECKPOINT_SAVE_MESSAGE, nullptr, nullptr));
  processor->Prepare(handle, time_out);

  sendrecv::VariableMessage message;
  message.set_varname(varname);
  message.set_out_varname(dirname);

  platform::RecordRPCEvent record_event(rpc_name);

  auto call = processor->stub_->AsyncCheckpointNotify(
      processor->context_.get(), message, &cq_);
  // The processor pointer is the completion-queue tag.
  call->Finish(&processor->reply_, &processor->status_,
               static_cast<void*>(processor));
  req_count_++;

  if (UNLIKELY(platform::IsProfileEnabled())) {
    handle->Wait();
  }

  return handle;
}

1
123malin 已提交
451 452 453 454 455 456 457 458 459
// Asynchronously sends variable `var_name` from `scope` to `ep` through the
// DistributeNotify RPC. Serialization and the unary call run inside a
// framework::AsyncIO task with no response callback installed. There is no
// retry loop here; when profiling is enabled the caller's thread waits on the
// handle before returning it.
VarHandlePtr GRPCClient::AsyncDistributeNotify(
    const std::string& ep, const platform::DeviceContext& ctx,
    const framework::Scope& scope, const std::string& var_name,
    int64_t time_out) {
  // Values captured by copy into the async task below.
  const std::string notify_name = var_name;
  const std::string rpc_name = kRequestNotify;
  const platform::DeviceContext* dev_ctx = &ctx;
  const framework::Scope* src_scope = &scope;
  const auto channel = GetChannel(ep);

  SendProcessor* processor = new SendProcessor(channel);
  VarHandlePtr handle(
      new VarHandle(ep, rpc_name, notify_name, dev_ctx, src_scope));
  processor->Prepare(handle, time_out);

  framework::AsyncIO(
      [this, processor, handle, dev_ctx, src_scope, notify_name, rpc_name] {
        auto* variable = src_scope->FindVar(notify_name);

        ::grpc::ByteBuffer request;
        SerializeToByteBuffer(notify_name, variable, *dev_ctx, &request, "",
                              trainer_id_);

        VLOG(3) << processor->GetVarHandlePtr()->String() << " begin";

        // stub context
        processor->response_call_back_ = nullptr;

        platform::RecordRPCEvent record_event(rpc_name);

        auto call = processor->stub_g_.PrepareUnaryCall(
            processor->context_.get(),
            "/sendrecv.SendRecvService/DistributeNotify", request, &cq_);
        call->StartCall();
        // The processor pointer is the completion-queue tag.
        call->Finish(&processor->reply_, &processor->status_,
                     static_cast<void*>(processor));
      });
  req_count_++;

  if (UNLIKELY(platform::IsProfileEnabled())) {
    handle->Wait();
  }

  return handle;
}

494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566
// Asynchronously sends `send_var_name` from `scope` to `ep` through the
// SendAndRecvVariable RPC, naming `recv_var_name` and `table_name` in the
// request; ProcGetRecvResponse is installed as the response callback.
//
// Before serialization the send variable's LoD is reset to empty. Two handles
// are created per attempt: one for the send name (returned to the caller) and
// one for the recv name (handed to the processor via RecvPrepare).
//
// When FLAGS_rpc_retry_times > 0 this function blocks on the send handle and
// re-issues the request while it is flagged should_retry (at most
// FLAGS_rpc_retry_times retries, with a 0-4 ms sleep in between).
VarHandlePtr GRPCClient::AsyncSendAndRecv(const std::string& ep,
                                          const platform::DeviceContext& ctx,
                                          const framework::Scope& scope,
                                          const std::string& send_var_name,
                                          const std::string& recv_var_name,
                                          const std::string& table_name,
                                          int64_t time_out) {
  // Values captured by copy into the async task below.
  const std::string send_name = send_var_name;
  const std::string recv_name = recv_var_name;
  const std::string table = table_name;
  const std::string rpc_name = kSendAndRecvRPC;
  const platform::DeviceContext* dev_ctx = &ctx;
  const framework::Scope* var_scope = &scope;
  const auto channel = GetChannel(ep);
  VLOG(4) << "GRPCClient::SendAndRecv Begin ,Send_var_name: " << send_name
          << " Recv_var_name: " << recv_name;

  int attempts = 0;
  for (;;) {
    SendAndRecvProcessor* processor = new SendAndRecvProcessor(channel);
    VarHandlePtr send_handle(
        new VarHandle(ep, rpc_name, send_name, dev_ctx, var_scope));
    VarHandlePtr recv_handle(
        new VarHandle(ep, rpc_name, recv_name, dev_ctx, var_scope));
    processor->Prepare(send_handle, time_out);
    processor->RecvPrepare(recv_handle);

    framework::AsyncIO([this, processor, send_handle, dev_ctx, var_scope,
                        send_name, recv_name, table, rpc_name] {
      auto* send_var = var_scope->FindVar(send_name);
      send_var->GetMutable<framework::LoDTensor>()->set_lod({});
      ::grpc::ByteBuffer request;
      VLOG(4) << "SerializeToByteBuffer: send_var_name_val: " << send_name
              << " recv_var_name_val: " << recv_name;
      SerializeToByteBuffer(send_name, send_var, *dev_ctx, &request, recv_name,
                            trainer_id_, table);

      VLOG(3) << processor->GetVarHandlePtr()->String() << " begin";

      // stub context
      processor->response_call_back_ = ProcGetRecvResponse;

      platform::RecordRPCEvent record_event(rpc_name);

      auto call = processor->stub_g_.PrepareUnaryCall(
          processor->context_.get(),
          "/sendrecv.SendRecvService/SendAndRecvVariable", request, &cq_);
      call->StartCall();
      // The processor pointer is the completion-queue tag.
      call->Finish(&processor->reply_, &processor->status_,
                   static_cast<void*>(processor));

      if (UNLIKELY(platform::IsProfileEnabled())) {
        send_handle->Wait();
      }
    });
    req_count_++;

    const bool retry_allowed =
        FLAGS_rpc_retry_times > 0 && attempts < FLAGS_rpc_retry_times;
    if (!retry_allowed) {
      return send_handle;
    }

    send_handle->Wait();
    if (!send_handle->should_retry) {
      return send_handle;
    }

    VLOG(3) << "rpc call failed, retry times " << attempts;
    ++attempts;
    std::random_device jitter;
    std::this_thread::sleep_for(std::chrono::milliseconds(jitter() % 5));
  }
}

Y
Yancey1989 已提交
567
bool GRPCClient::Wait() {
W
Wu Yi 已提交
568
  std::unique_lock<std::mutex> lk(sync_mutex_);
Y
Yancey1989 已提交
569 570
  sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); });
  return ok_;
G
gongweibao 已提交
571 572
}

T
tangwei12 已提交
573 574 575 576 577 578 579 580 581 582 583 584
// Decides whether a failed RPC should be flagged for retry: always for the
// prefetch method, otherwise only when the gRPC error code is
// DEADLINE_EXCEEDED.
inline bool ShouldRetry(const std::string& method, int error_code) {
  const bool is_prefetch = (method == kPrefetchRPC);
  const bool deadline_exceeded =
      (error_code == grpc::StatusCode::DEADLINE_EXCEEDED);
  return is_prefetch || deadline_exceeded;
}

G
gongweibao 已提交
585
void GRPCClient::Proceed() {
W
Wu Yi 已提交
586
  void* tag = nullptr;
G
gongweibao 已提交
587 588
  bool ok = false;

M
minqiyang 已提交
589
  VLOG(3) << "GRPCClient Proceed begin";
M
minqiyang 已提交
590
  while (!stopped_ && cq_.Next(&tag, &ok)) {
W
Wu Yi 已提交
591 592 593
    BaseProcessor* c = static_cast<BaseProcessor*>(tag);
    GPR_ASSERT(ok);
    PADDLE_ENFORCE(c);
G
gongweibao 已提交
594

W
Wu Yi 已提交
595
    if (c->status_.ok()) {
M
minqiyang 已提交
596
      VLOG(3) << c->GetVarHandlePtr()->String() << " process";
W
Wu Yi 已提交
597
      c->Process();
T
tangwei12 已提交
598 599 600
    } else if (ShouldRetry(c->GetVarHandlePtr()->method(),
                           c->status_.error_code())) {
      VLOG(0) << c->GetVarHandlePtr()->String()
601 602 603 604 605 606
              << " meets grpc error, error_code:" << c->status_.error_code()
              << " error_message:" << c->status_.error_message()
              << " error_details:" << c->status_.error_details()
              << " should retry!";
      c->GetVarHandlePtr()->should_retry = true;
      c->Finish(false);
W
Wu Yi 已提交
607
    } else {
608 609 610 611 612
      PADDLE_THROW(platform::errors::External(
          "%s meets grpc error, error_code is %d, error message is %s, error "
          "details is %s.",
          c->GetVarHandlePtr()->String(), c->status_.error_code(),
          c->status_.error_message(), c->status_.error_details()));
613
      c->Finish(false);
W
Wu Yi 已提交
614
    }
615

G
gongweibao 已提交
616
    bool notify = false;
W
Wu Yi 已提交
617 618 619
    {
      std::lock_guard<std::mutex> lk(sync_mutex_);
      req_count_--;
G
gongweibao 已提交
620 621 622 623 624 625 626
      notify = (req_count_ <= 0 || !c->status_.ok());
    }

    delete c;

    if (notify) {
      sync_cond_.notify_all();
W
Wu Yi 已提交
627
    }
G
gongweibao 已提交
628
  }
629 630 631 632 633 634 635

  // Last log message
  // Avoid using VLOG() and LOG(): in the destructor of google::LogMessage() a
  // static Mutex log_mutex is used for synchronization, which might have been
  // destructed at this moment.
  if (FLAGS_v >= 3) {
    std::string msg("GRPCClient Proceed end");
636
    fwrite(msg.c_str(), msg.length(), 1, stderr);
637
  }
G
gongweibao 已提交
638
}
W
Wu Yi 已提交
639

G
gongweibao 已提交
640
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
W
Wu Yi 已提交
641
  std::lock_guard<std::mutex> guard(chan_mutex_);
Y
Yancey1989 已提交
642
  auto it = channels_.find(ep);
G
gongweibao 已提交
643 644 645 646
  if (it != channels_.end()) {
    return it->second;
  }

W
Wu Yi 已提交
647
  // Channel configurations:
G
gongweibao 已提交
648
  grpc::ChannelArguments args;
W
Wu Yi 已提交
649
  args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 2000);
650 651 652
  if (FLAGS_rpc_disable_reuse_port) {
    args.SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0);
  }
653
  args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
G
gongweibao 已提交
654 655 656
  args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());

T
typhoonzero 已提交
657 658
  auto ch =
      grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
Y
Yancey1989 已提交
659
  channels_[ep] = ch;
G
gongweibao 已提交
660 661 662
  return ch;
}

663
}  // namespace distributed
G
gongweibao 已提交
664 665
}  // namespace operators
}  // namespace paddle