grpc_client.cc 22.0 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
G
gongweibao 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include <stdlib.h>
Y
Yi Wang 已提交
16 17
#include <limits>

G
gongweibao 已提交
18
#include "glog/logging.h"  // For VLOG
Y
Yi Wang 已提交
19
#include "paddle/fluid/framework/threadpool.h"
W
Wu Yi 已提交
20 21
#include "paddle/fluid/operators/distributed/grpc/grpc_client.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_serde.h"
22
#include "paddle/fluid/operators/distributed/request_handler.h"
P
peizhilin 已提交
23
#include "paddle/fluid/platform/port.h"
X
Xin Pan 已提交
24
#include "paddle/fluid/platform/profiler.h"
25

26 27
DECLARE_bool(rpc_disable_reuse_port);

G
gongweibao 已提交
28 29
namespace paddle {
namespace operators {
30
namespace distributed {
G
gongweibao 已提交
31

32
void GRPCClient::InitImpl() {
W
Wu Yi 已提交
33 34
  // start the client process thread
  // TODO(wuyi): can make this in a threadpool
M
MRXLT 已提交
35 36 37
  PADDLE_ENFORCE_EQ(client_thread_ == nullptr, true,
                    platform::errors::PreconditionNotMet(
                        "please not re init proceed thread"));
G
gongweibao 已提交
38
  client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
W
Wu Yi 已提交
39 40
}

Y
Yancey1989 已提交
41
void GRPCClient::SendComplete() {
Y
Yancey1989 已提交
42 43 44
  std::unique_lock<std::mutex> lk(completed_mutex_);
  if (!completed_) {
    for (auto& it : channels_) {
M
minqiyang 已提交
45
      VLOG(3) << "send complete message to " << it.first;
Y
Yancey1989 已提交
46 47
      this->AsyncSendComplete(it.first);
    }
M
MRXLT 已提交
48 49
    PADDLE_ENFORCE_EQ(this->Wait(), true, platform::errors::PreconditionNotMet(
                                              "internal grpc service error."));
Y
Yancey1989 已提交
50
    completed_ = true;
W
Wu Yi 已提交
51 52 53
  }
}

G
gongweibao 已提交
54
// Tears the client down: flags the Proceed() loop to stop, waits for pending
// requests, shuts down the completion queue, drops every cached channel and
// finally joins the client thread.
GRPCClient::~GRPCClient() {
  // NOTE(review): stopped_ is raised before Wait(); Proceed() checks it on
  // every iteration, so it may leave its loop while req_count_ is still
  // non-zero. Confirm that this ordering cannot leave Wait() blocked.
  stopped_ = true;
  Wait();
  cq_.Shutdown();
  {
    std::lock_guard<std::mutex> guard(chan_mutex_);
    for (auto& it : channels_) {
      it.second.reset();
    }
    channels_.clear();
  }

  // client_thread_ is only created by InitImpl(); a client destroyed without
  // having been initialised has no thread to join, and dereferencing the empty
  // pointer would crash.
  if (client_thread_ != nullptr && client_thread_->joinable()) {
    client_thread_->join();
  }
}

68 69 70 71 72
// Asynchronously sends variable `var_name` found in `scope` to endpoint `ep`.
//
// Serialization and the start of the RPC happen inside a framework::AsyncIO
// task. The processor allocated here is handed to the completion queue as the
// call tag and is deleted by Proceed() when the call completes.
//
// When FLAGS_rpc_retry_times > 0 this function blocks on the handle and
// re-issues the request, after a random pause of 0-4 ms, for as long as the
// handle is flagged should_retry and fewer than FLAGS_rpc_retry_times retries
// were made. In every other case the handle is returned without waiting.
VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
                                      const platform::DeviceContext& ctx,
                                      const framework::Scope& scope,
                                      const std::string& var_name,
                                      int64_t time_out) {
  // The async task captures copies / raw pointers rather than references to
  // the caller's arguments.
  const std::string endpoint = ep;
  const std::string name = var_name;
  const platform::DeviceContext* dev_ctx = &ctx;
  const framework::Scope* src_scope = &scope;
  const std::string method = kSendRPC;
  const auto channel = GetChannel(endpoint);

  for (int attempt = 0;; ++attempt) {
    SendProcessor* processor = new SendProcessor(channel);
    VarHandlePtr handle(new VarHandle(ep, method, name, dev_ctx, src_scope));
    processor->Prepare(handle, time_out);

    framework::AsyncIO(
        [name, src_scope, dev_ctx, processor, method, handle, this] {
          auto* var = src_scope->FindVar(name);

          ::grpc::ByteBuffer request;
          SerializeToByteBuffer(name, var, *dev_ctx, &request, "",
                                trainer_id_);

          VLOG(3) << processor->GetVarHandlePtr()->String() << " begin";

          // A send has no response payload to post-process.
          processor->response_call_back_ = nullptr;

          platform::RecordRPCEvent record_event(method);

          auto call = processor->stub_g_.PrepareUnaryCall(
              processor->context_.get(),
              "/sendrecv.SendRecvService/SendVariable", request, &cq_);
          call->StartCall();
          call->Finish(&processor->reply_, &processor->status_,
                       static_cast<void*>(processor));

          if (UNLIKELY(platform::IsProfileEnabled())) {
            handle->Wait();
          }
        });
    req_count_++;

    const bool retry_allowed =
        FLAGS_rpc_retry_times > 0 && attempt < FLAGS_rpc_retry_times;
    if (!retry_allowed) {
      return handle;
    }

    handle->Wait();
    if (!handle->should_retry) {
      return handle;
    }

    VLOG(3) << "rpc call failed, retry times " << attempt;
    std::random_device rd;
    std::this_thread::sleep_for(std::chrono::milliseconds(rd() % 5));
  }
}

void ProcGetResponse(const VarHandle& var_h,
128
                     const ::grpc::ByteBuffer& ret_msg) {
129
  VLOG(4) << "ProcGetResponse";
T
typhoonzero 已提交
130
  framework::Variable* outvar = nullptr;
W
Wu Yi 已提交
131 132 133 134
  // get response's trainer_id is not used
  int trainer_id;
  DeserializeFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar,
                            &trainer_id);
135 136
}

137 138 139 140 141 142 143 144 145
// Response callback for send-and-receive calls: deserializes `ret_msg` with
// DeserializeRecvFromByteBuffer using the device context and scope stored in
// `var_h`. The reported trainer id is not used.
void ProcGetRecvResponse(const VarHandle& var_h,
                         const ::grpc::ByteBuffer& ret_msg) {
  VLOG(4) << "ProcGetRecvResponse";

  int unused_trainer_id;
  framework::Variable* out_var = nullptr;
  DeserializeRecvFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &out_var,
                                &unused_trainer_id);
}

146 147 148
template <typename T>
void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
  ::grpc::Slice slice(proto.ByteSizeLong());
Q
qiaolongfei 已提交
149
  proto.SerializeWithCachedSizesToArray(const_cast<uint8_t*>(slice.begin()));
150 151
  ::grpc::ByteBuffer tmp(&slice, 1);
  result->Swap(&tmp);
G
gongweibao 已提交
152 153
}

154 155 156 157
// Fetches `var_name` from endpoint `ep` through the GetVariable RPC.
// Thin wrapper over _AsyncGetVar using the kGetRPC method name.
VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
                                     const platform::DeviceContext& ctx,
                                     const framework::Scope& scope,
                                     const std::string& var_name,
                                     const std::string& out_varname,
                                     const std::string& table_name,
                                     int64_t time_out) {
  const char* const rpc_path = "/sendrecv.SendRecvService/GetVariable";
  return _AsyncGetVar(ep, ctx, scope, kGetRPC, var_name, out_varname, rpc_path,
                      table_name, time_out);
}

166 167 168 169 170 171 172 173 174
// Fetches a variable through the GetVariableNoBarrier RPC.
//
// The requested name is `var_name` with WITHOUT_BARRIER_MESSAGE appended; no
// table name is sent.
VarHandlePtr GRPCClient::AsyncGetVarNoBarrier(
    const std::string& ep, const platform::DeviceContext& ctx,
    const framework::Scope& scope, const std::string& var_name,
    const std::string& out_varname, int64_t time_out) {
  const char* const rpc_path = "/sendrecv.SendRecvService/GetVariableNoBarrier";
  const std::string request_name =
      string::Sprintf("%s%s", var_name, WITHOUT_BARRIER_MESSAGE);

  return _AsyncGetVar(ep, ctx, scope, kGetNoBarrierRPC, request_name,
                      out_varname, rpc_path, "", time_out);
}

178 179 180 181
// Fetches `var_name` through the GetMonomerVariable RPC. The same name is
// used for the request and for the output variable; no table name is sent.
VarHandlePtr GRPCClient::AsyncGetMonomerVariable(
    const std::string& ep, const platform::DeviceContext& ctx,
    const framework::Scope& scope, const std::string& var_name,
    int64_t time_out) {
  const char* const rpc_path = "/sendrecv.SendRecvService/GetMonomerVariable";
  return _AsyncGetVar(ep, ctx, scope, kGetMonomerRPC, var_name, var_name,
                      rpc_path, "", time_out);
}

187 188 189 190
// Shared implementation of the get-style RPCs.
//
// Builds a VariableMessage (varname, out_varname, trainer id, table name),
// sends it to `rpc_path` on endpoint `ep` from a framework::AsyncIO task and
// installs ProcGetResponse as the response callback. The processor allocated
// here is the completion-queue tag and is deleted by Proceed().
//
// Retry behaviour: with FLAGS_rpc_retry_times > 0 the call blocks on the
// handle and re-issues the request, after a random 0-4 ms pause, while the
// handle is flagged should_retry and fewer than FLAGS_rpc_retry_times retries
// were made; otherwise the handle is returned without waiting.
VarHandlePtr GRPCClient::_AsyncGetVar(
    const std::string& ep, const platform::DeviceContext& ctx,
    const framework::Scope& scope, const std::string& method,
    const std::string& var_name, const std::string& out_varname,
    const std::string& rpc_path, const std::string& table_name,
    int64_t time_out) {
  // Owned copies / raw pointers for use inside the async task.
  const std::string endpoint = ep;
  const std::string request_name = var_name;
  const std::string output_name = out_varname;
  const std::string table = table_name;
  const platform::DeviceContext* dev_ctx = &ctx;
  const framework::Scope* dst_scope = &scope;
  const auto channel = GetChannel(endpoint);

  for (int attempt = 0;; ++attempt) {
    GetProcessor* processor = new GetProcessor(channel);
    VarHandlePtr handle(
        new VarHandle(ep, method, output_name, dev_ctx, dst_scope));
    processor->Prepare(handle, time_out);

    framework::AsyncIO([request_name, output_name, table, processor, method,
                        dev_ctx, handle, rpc_path, this] {
      // Describe the requested variable.
      sendrecv::VariableMessage request;
      request.set_varname(request_name);
      request.set_out_varname(output_name);
      request.set_trainer_id(trainer_id_);
      request.set_table_name(table);

      ::grpc::ByteBuffer payload;
      RequestToByteBuffer<sendrecv::VariableMessage>(request, &payload);

      VLOG(3) << processor->GetVarHandlePtr()->String() << " begin";

      // The reply carries the variable and is decoded by ProcGetResponse.
      processor->response_call_back_ = ProcGetResponse;

      platform::RecordRPCEvent record_event(method);

      auto call = processor->stub_g_.PrepareUnaryCall(
          processor->context_.get(), rpc_path, payload, &cq_);
      call->StartCall();
      call->Finish(&processor->reply_, &processor->status_,
                   static_cast<void*>(processor));

      if (UNLIKELY(platform::IsProfileEnabled())) {
        handle->Wait();
      }
    });
    req_count_++;

    const bool retry_allowed =
        FLAGS_rpc_retry_times > 0 && attempt < FLAGS_rpc_retry_times;
    if (!retry_allowed) {
      return handle;
    }

    handle->Wait();
    if (!handle->should_retry) {
      return handle;
    }

    VLOG(3) << "rpc call failed, retry times " << attempt;
    std::random_device rd;
    std::this_thread::sleep_for(std::chrono::milliseconds(rd() % 5));
  }
}

253 254 255 256 257
// Sends `in_var_name` to endpoint `ep` through the PrefetchVariable RPC; the
// reply is decoded by ProcGetResponse and the returned handle is named after
// `out_var_name`.
//
// NOTE(review): `time_out` is not used here; the processor is prepared with
// kPrefetchTimeout instead. The serializer also receives the literal 0 where
// the other senders in this file pass trainer_id_. Both look deliberate but
// should be confirmed.
//
// Retry behaviour matches the other variable RPCs: with
// FLAGS_rpc_retry_times > 0 the call blocks on the handle and re-issues the
// request, after a random 0-4 ms pause, while should_retry is set and fewer
// than FLAGS_rpc_retry_times retries were made.
VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
                                          const platform::DeviceContext& ctx,
                                          const framework::Scope& scope,
                                          const std::string& in_var_name,
                                          const std::string& out_var_name,
                                          const std::string& table_name,
                                          int64_t time_out) {
  // Owned copies / raw pointers for use inside the async task.
  const std::string endpoint = ep;
  const std::string in_name = in_var_name;
  const std::string out_name = out_var_name;
  const std::string table = table_name;
  const platform::DeviceContext* dev_ctx = &ctx;
  const framework::Scope* src_scope = &scope;
  const std::string method = kPrefetchRPC;
  const auto channel = GetChannel(endpoint);

  for (int attempt = 0;; ++attempt) {
    GetProcessor* processor = new GetProcessor(channel);
    VarHandlePtr handle(
        new VarHandle(ep, method, out_name, dev_ctx, src_scope));
    processor->Prepare(handle, kPrefetchTimeout);

    framework::AsyncIO([in_name, out_name, endpoint, src_scope, dev_ctx,
                        processor, method, handle, table, this] {
      auto* var = src_scope->FindVar(in_name);

      ::grpc::ByteBuffer request;
      SerializeToByteBuffer(in_name, var, *dev_ctx, &request, out_name, 0,
                            table);

      VLOG(3) << processor->GetVarHandlePtr()->String() << " begin";

      // The reply carries the prefetched variable.
      processor->response_call_back_ = ProcGetResponse;

      platform::RecordRPCEvent record_event(method);

      auto call = processor->stub_g_.PrepareUnaryCall(
          processor->context_.get(),
          "/sendrecv.SendRecvService/PrefetchVariable", request, &cq_);
      call->StartCall();
      call->Finish(&processor->reply_, &processor->status_,
                   static_cast<void*>(processor));

      if (UNLIKELY(platform::IsProfileEnabled())) {
        handle->Wait();
      }
    });
    req_count_++;

    const bool retry_allowed =
        FLAGS_rpc_retry_times > 0 && attempt < FLAGS_rpc_retry_times;
    if (!retry_allowed) {
      return handle;
    }

    handle->Wait();
    if (!handle->should_retry) {
      return handle;
    }

    VLOG(3) << "rpc call failed, retry times " << attempt;
    std::random_device rd;
    std::this_thread::sleep_for(std::chrono::milliseconds(rd() % 5));
  }
}

318 319
// Sends the batch-barrier message to endpoint `ep`: a VariableMessage whose
// varname is BATCH_BARRIER_MESSAGE, issued through AsyncSendVariable.
//
// The call is started on the calling thread. The processor is the
// completion-queue tag and is deleted by Proceed(). The handle is only waited
// on here when profiling is enabled.
VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
                                               int64_t time_out) {
  const std::string method = kBatchBarrierRPC;
  const auto channel = GetChannel(ep);

  sendrecv::VariableMessage barrier_msg;
  barrier_msg.set_varname(BATCH_BARRIER_MESSAGE);

  BatchBarrierProcessor* processor = new BatchBarrierProcessor(channel);
  VarHandlePtr handle(
      new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr));
  processor->Prepare(handle, time_out);

  platform::RecordRPCEvent record_event(method);

  auto rpc = processor->stub_->AsyncSendVariable(processor->context_.get(),
                                                 barrier_msg, &cq_);
  rpc->Finish(&processor->reply_, &processor->status_,
              static_cast<void*>(processor));
  req_count_++;

  if (UNLIKELY(platform::IsProfileEnabled())) {
    handle->Wait();
  }
  return handle;
}
Y
Yancey 已提交
343

344 345
// Sends the fetch-barrier message to endpoint `ep`: a VariableMessage whose
// varname is FETCH_BARRIER_MESSAGE, issued through AsyncGetVariable.
//
// The processor is the completion-queue tag and is deleted by Proceed(). The
// handle is only waited on here when profiling is enabled.
VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
                                               int64_t time_out) {
  const std::string method = kFetchBarrierRPC;
  const auto channel = GetChannel(ep);

  sendrecv::VariableMessage barrier_msg;
  barrier_msg.set_varname(FETCH_BARRIER_MESSAGE);

  FetchBarrierProcessor* processor = new FetchBarrierProcessor(channel);
  VarHandlePtr handle(
      new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
  processor->Prepare(handle, time_out);

  platform::RecordRPCEvent record_event(method);

  auto rpc = processor->stub_->AsyncGetVariable(processor->context_.get(),
                                                barrier_msg, &cq_);
  rpc->Finish(&processor->reply_, &processor->status_,
              static_cast<void*>(processor));
  req_count_++;

  if (UNLIKELY(platform::IsProfileEnabled())) {
    handle->Wait();
  }
  return handle;
}

369 370 371 372 373
// Issues the GetMonomerBarrier RPC for `var_name` on endpoint `ep`.
//
// The processor is the completion-queue tag and is deleted by Proceed(). The
// handle is only waited on here when profiling is enabled.
VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
                                                const std::string& var_name,
                                                int64_t time_out) {
  const std::string method = kSendMonomerFetchBarrierRPC;
  const auto channel = GetChannel(ep);

  BatchBarrierProcessor* processor = new BatchBarrierProcessor(channel);
  VarHandlePtr handle(new VarHandle(ep, method, var_name, nullptr, nullptr));
  processor->Prepare(handle, time_out);

  VLOG(30) << processor->GetVarHandlePtr()->String() << " begin";

  sendrecv::VariableMessage barrier_msg;
  barrier_msg.set_varname(var_name);

  platform::RecordRPCEvent record_event(method);

  auto rpc = processor->stub_->AsyncGetMonomerBarrier(
      processor->context_.get(), barrier_msg, &cq_);
  rpc->Finish(&processor->reply_, &processor->status_,
              static_cast<void*>(processor));
  req_count_++;

  if (UNLIKELY(platform::IsProfileEnabled())) {
    handle->Wait();
  }
  return handle;
}

396 397
// Sends the "complete" message to endpoint `ep`: a VariableMessage carrying
// this client's trainer id and the varname COMPLETE_MESSAGE, issued through
// AsyncSendVariable.
//
// The processor is the completion-queue tag and is deleted by Proceed(). The
// handle is only waited on here when profiling is enabled.
VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
                                           int64_t time_out) {
  const std::string method = kSendCompleteRPC;
  const auto channel = GetChannel(ep);

  sendrecv::VariableMessage complete_msg;
  complete_msg.set_trainer_id(trainer_id_);
  complete_msg.set_varname(COMPLETE_MESSAGE);

  BatchBarrierProcessor* processor = new BatchBarrierProcessor(channel);
  VarHandlePtr handle(
      new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr));
  processor->Prepare(handle, time_out);

  platform::RecordRPCEvent record_event(method);

  auto rpc = processor->stub_->AsyncSendVariable(processor->context_.get(),
                                                 complete_msg, &cq_);
  rpc->Finish(&processor->reply_, &processor->status_,
              static_cast<void*>(processor));
  req_count_++;

  if (UNLIKELY(platform::IsProfileEnabled())) {
    handle->Wait();
  }
  return handle;
}

422
// Issues the CheckpointNotify RPC on endpoint `ep`.
//
// The request reuses VariableMessage fields: varname holds `varname`,
// table_name holds `mode` as a decimal string and out_varname holds
// `dirname`. The processor is the completion-queue tag and is deleted by
// Proceed(). The handle is only waited on here when profiling is enabled.
VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
                                               const std::string& dirname,
                                               const std::string& varname,
                                               const int mode,
                                               int64_t time_out) {
  const std::string method = kCheckPointNotifyRPC;
  const auto channel = GetChannel(ep);

  sendrecv::VariableMessage notify_msg;
  notify_msg.set_varname(varname);
  notify_msg.set_table_name(std::to_string(mode));
  notify_msg.set_out_varname(dirname);

  CheckpointNotifyProcessor* processor =
      new CheckpointNotifyProcessor(channel);
  VarHandlePtr handle(
      new VarHandle(ep, method, CHECKPOINT_SAVE_MESSAGE, nullptr, nullptr));
  processor->Prepare(handle, time_out);

  platform::RecordRPCEvent record_event(method);

  auto rpc = processor->stub_->AsyncCheckpointNotify(
      processor->context_.get(), notify_msg, &cq_);
  rpc->Finish(&processor->reply_, &processor->status_,
              static_cast<void*>(processor));
  req_count_++;

  if (UNLIKELY(platform::IsProfileEnabled())) {
    handle->Wait();
  }
  return handle;
}

1
123malin 已提交
455 456 457 458 459 460 461 462 463
// Sends variable `var_name` found in `scope` to endpoint `ep` through the
// DistributeNotify RPC.
//
// Serialization and the start of the call run inside a framework::AsyncIO
// task. No retry is attempted. The processor is the completion-queue tag and
// is deleted by Proceed(). The handle is only waited on here when profiling
// is enabled.
VarHandlePtr GRPCClient::AsyncDistributeNotify(
    const std::string& ep, const platform::DeviceContext& ctx,
    const framework::Scope& scope, const std::string& var_name,
    int64_t time_out) {
  // Owned copies / raw pointers for use inside the async task.
  const std::string endpoint = ep;
  const std::string name = var_name;
  const platform::DeviceContext* dev_ctx = &ctx;
  const framework::Scope* src_scope = &scope;
  const std::string method = kRequestNotify;
  const auto channel = GetChannel(endpoint);

  SendProcessor* processor = new SendProcessor(channel);
  VarHandlePtr handle(new VarHandle(ep, method, name, dev_ctx, src_scope));
  processor->Prepare(handle, time_out);

  framework::AsyncIO(
      [name, src_scope, dev_ctx, processor, method, handle, this] {
        auto* var = src_scope->FindVar(name);

        ::grpc::ByteBuffer request;
        SerializeToByteBuffer(name, var, *dev_ctx, &request, "", trainer_id_);

        VLOG(3) << processor->GetVarHandlePtr()->String() << " begin";

        // No response payload to post-process.
        processor->response_call_back_ = nullptr;

        platform::RecordRPCEvent record_event(method);

        auto call = processor->stub_g_.PrepareUnaryCall(
            processor->context_.get(),
            "/sendrecv.SendRecvService/DistributeNotify", request, &cq_);
        call->StartCall();
        call->Finish(&processor->reply_, &processor->status_,
                     static_cast<void*>(processor));
      });
  req_count_++;

  if (UNLIKELY(platform::IsProfileEnabled())) {
    handle->Wait();
  }
  return handle;
}

498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570
// Sends `send_var_name` to endpoint `ep` through the SendAndRecvVariable RPC
// and has the reply decoded by ProcGetRecvResponse.
//
// Two handles are created per attempt: one named after the sent variable,
// which is returned, and one named after `recv_var_name`, which is given to
// the processor through RecvPrepare. The LoD of the sent tensor is reset to
// empty before serialization.
//
// Retry behaviour matches the other variable RPCs: with
// FLAGS_rpc_retry_times > 0 the call blocks on the send handle and re-issues
// the request, after a random 0-4 ms pause, while should_retry is set and
// fewer than FLAGS_rpc_retry_times retries were made.
VarHandlePtr GRPCClient::AsyncSendAndRecv(const std::string& ep,
                                          const platform::DeviceContext& ctx,
                                          const framework::Scope& scope,
                                          const std::string& send_var_name,
                                          const std::string& recv_var_name,
                                          const std::string& table_name,
                                          int64_t time_out) {
  // Owned copies / raw pointers for use inside the async task.
  const std::string endpoint = ep;
  const std::string send_name = send_var_name;
  const std::string recv_name = recv_var_name;
  const std::string table = table_name;
  const platform::DeviceContext* dev_ctx = &ctx;
  const framework::Scope* var_scope = &scope;
  const std::string method = kSendAndRecvRPC;
  const auto channel = GetChannel(endpoint);

  VLOG(4) << "GRPCClient::SendAndRecv Begin ,Send_var_name: " << send_name
          << " Recv_var_name: " << recv_name;

  for (int attempt = 0;; ++attempt) {
    SendAndRecvProcessor* processor = new SendAndRecvProcessor(channel);
    VarHandlePtr handle(
        new VarHandle(ep, method, send_name, dev_ctx, var_scope));
    VarHandlePtr recv_handle(
        new VarHandle(ep, method, recv_name, dev_ctx, var_scope));
    processor->Prepare(handle, time_out);
    processor->RecvPrepare(recv_handle);

    framework::AsyncIO([send_name, recv_name, table, var_scope, dev_ctx,
                        processor, method, handle, this] {
      // NOTE(review): the FindVar result is dereferenced without a null
      // check; confirm callers always create the send variable first.
      auto* send_var = var_scope->FindVar(send_name);
      send_var->GetMutable<framework::LoDTensor>()->set_lod({});

      ::grpc::ByteBuffer payload;
      VLOG(4) << "SerializeToByteBuffer: send_var_name_val: " << send_name
              << " recv_var_name_val: " << recv_name;
      SerializeToByteBuffer(send_name, send_var, *dev_ctx, &payload, recv_name,
                            trainer_id_, table);

      VLOG(3) << processor->GetVarHandlePtr()->String() << " begin";

      // The reply carries the received variable.
      processor->response_call_back_ = ProcGetRecvResponse;

      platform::RecordRPCEvent record_event(method);

      auto call = processor->stub_g_.PrepareUnaryCall(
          processor->context_.get(),
          "/sendrecv.SendRecvService/SendAndRecvVariable", payload, &cq_);
      call->StartCall();
      call->Finish(&processor->reply_, &processor->status_,
                   static_cast<void*>(processor));

      if (UNLIKELY(platform::IsProfileEnabled())) {
        handle->Wait();
      }
    });
    req_count_++;

    const bool retry_allowed =
        FLAGS_rpc_retry_times > 0 && attempt < FLAGS_rpc_retry_times;
    if (!retry_allowed) {
      return handle;
    }

    handle->Wait();
    if (!handle->should_retry) {
      return handle;
    }

    VLOG(3) << "rpc call failed, retry times " << attempt;
    std::random_device rd;
    std::this_thread::sleep_for(std::chrono::milliseconds(rd() % 5));
  }
}

Y
Yancey1989 已提交
571
bool GRPCClient::Wait() {
W
Wu Yi 已提交
572
  std::unique_lock<std::mutex> lk(sync_mutex_);
Y
Yancey1989 已提交
573 574
  sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); });
  return ok_;
G
gongweibao 已提交
575 576
}

T
tangwei12 已提交
577 578 579 580 581 582 583 584 585 586 587 588
// Decides whether a failed call may be retried: prefetch calls always may,
// any other method only when the error is DEADLINE_EXCEEDED.
inline bool ShouldRetry(const std::string& method, int error_code) {
  return method == kPrefetchRPC ||
         error_code == grpc::StatusCode::DEADLINE_EXCEEDED;
}

G
gongweibao 已提交
589
void GRPCClient::Proceed() {
W
Wu Yi 已提交
590
  void* tag = nullptr;
G
gongweibao 已提交
591 592
  bool ok = false;

M
minqiyang 已提交
593
  VLOG(3) << "GRPCClient Proceed begin";
M
minqiyang 已提交
594
  while (!stopped_ && cq_.Next(&tag, &ok)) {
W
Wu Yi 已提交
595 596
    BaseProcessor* c = static_cast<BaseProcessor*>(tag);
    GPR_ASSERT(ok);
M
MRXLT 已提交
597 598
    PADDLE_ENFORCE_NOT_NULL(
        c, platform::errors::PreconditionNotMet("Make BaseProcessor failed."));
G
gongweibao 已提交
599

W
Wu Yi 已提交
600
    if (c->status_.ok()) {
M
minqiyang 已提交
601
      VLOG(3) << c->GetVarHandlePtr()->String() << " process";
W
Wu Yi 已提交
602
      c->Process();
T
tangwei12 已提交
603 604 605
    } else if (ShouldRetry(c->GetVarHandlePtr()->method(),
                           c->status_.error_code())) {
      VLOG(0) << c->GetVarHandlePtr()->String()
606 607 608 609 610 611
              << " meets grpc error, error_code:" << c->status_.error_code()
              << " error_message:" << c->status_.error_message()
              << " error_details:" << c->status_.error_details()
              << " should retry!";
      c->GetVarHandlePtr()->should_retry = true;
      c->Finish(false);
W
Wu Yi 已提交
612
    } else {
613 614 615 616 617
      PADDLE_THROW(platform::errors::External(
          "%s meets grpc error, error_code is %d, error message is %s, error "
          "details is %s.",
          c->GetVarHandlePtr()->String(), c->status_.error_code(),
          c->status_.error_message(), c->status_.error_details()));
618
      c->Finish(false);
W
Wu Yi 已提交
619
    }
620

G
gongweibao 已提交
621
    bool notify = false;
W
Wu Yi 已提交
622 623 624
    {
      std::lock_guard<std::mutex> lk(sync_mutex_);
      req_count_--;
G
gongweibao 已提交
625 626 627 628 629 630 631
      notify = (req_count_ <= 0 || !c->status_.ok());
    }

    delete c;

    if (notify) {
      sync_cond_.notify_all();
W
Wu Yi 已提交
632
    }
G
gongweibao 已提交
633
  }
634 635 636 637 638 639 640

  // Last log message
  // Avoid using VLOG() and LOG(): in the destructor of google::LogMessage() a
  // static Mutex log_mutex is used for synchronization, which might have been
  // destructed at this moment.
  if (FLAGS_v >= 3) {
    std::string msg("GRPCClient Proceed end");
641
    fwrite(msg.c_str(), msg.length(), 1, stderr);
642
  }
G
gongweibao 已提交
643
}
W
Wu Yi 已提交
644

G
gongweibao 已提交
645
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
W
Wu Yi 已提交
646
  std::lock_guard<std::mutex> guard(chan_mutex_);
Y
Yancey1989 已提交
647
  auto it = channels_.find(ep);
G
gongweibao 已提交
648 649 650 651
  if (it != channels_.end()) {
    return it->second;
  }

W
Wu Yi 已提交
652
  // Channel configurations:
G
gongweibao 已提交
653
  grpc::ChannelArguments args;
W
Wu Yi 已提交
654
  args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 2000);
655 656 657
  if (FLAGS_rpc_disable_reuse_port) {
    args.SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0);
  }
658
  args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
G
gongweibao 已提交
659 660 661
  args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());

T
typhoonzero 已提交
662 663
  auto ch =
      grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
Y
Yancey1989 已提交
664
  channels_[ep] = ch;
G
gongweibao 已提交
665 666 667
  return ch;
}

668
}  // namespace distributed
G
gongweibao 已提交
669 670
}  // namespace operators
}  // namespace paddle