/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include <stdlib.h>
Y
Yi Wang 已提交
16 17
#include <limits>

G
gongweibao 已提交
18
#include "glog/logging.h"  // For VLOG
Y
Yi Wang 已提交
19
#include "paddle/fluid/framework/threadpool.h"
W
Wu Yi 已提交
20 21
#include "paddle/fluid/operators/distributed/grpc/grpc_client.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_serde.h"
22
#include "paddle/fluid/operators/distributed/request_handler.h"
P
peizhilin 已提交
23
#include "paddle/fluid/platform/port.h"
X
Xin Pan 已提交
24
#include "paddle/fluid/platform/profiler.h"
25

1
123malin 已提交
26
// Number of worker threads GRPCClient::InitImpl() starts; each one runs
// GRPCClient::Proceed() against the client's completion queue.
DEFINE_int32(rpc_client_threads, 2,
             "Number of threads GRPCClient starts to process RPC completions.");
// Defined in another translation unit; when true, GetChannel() sets
// GRPC_ARG_ALLOW_REUSEPORT to 0 on newly created channels.
DECLARE_bool(rpc_disable_reuse_port);

G
gongweibao 已提交
29 30
namespace paddle {
namespace operators {
31
namespace distributed {
G
gongweibao 已提交
32

33
void GRPCClient::InitImpl() {
W
Wu Yi 已提交
34 35
  // start the client process thread
  // TODO(wuyi): can make this in a threadpool
1
123malin 已提交
36 37 38 39 40
  client_threads_.resize(FLAGS_rpc_client_threads);
  for (int i = 0; i < FLAGS_rpc_client_threads; i++) {
    client_threads_[i].reset(
        new std::thread(std::bind(&GRPCClient::Proceed, this)));
  }
W
Wu Yi 已提交
41 42
}

Y
Yancey1989 已提交
43
void GRPCClient::SendComplete() {
Y
Yancey1989 已提交
44 45 46
  std::unique_lock<std::mutex> lk(completed_mutex_);
  if (!completed_) {
    for (auto& it : channels_) {
M
minqiyang 已提交
47
      VLOG(3) << "send complete message to " << it.first;
Y
Yancey1989 已提交
48 49
      this->AsyncSendComplete(it.first);
    }
M
MRXLT 已提交
50 51
    PADDLE_ENFORCE_EQ(this->Wait(), true, platform::errors::PreconditionNotMet(
                                              "internal grpc service error."));
Y
Yancey1989 已提交
52
    completed_ = true;
W
Wu Yi 已提交
53 54 55
  }
}

G
gongweibao 已提交
56
// Tears the client down: flags the workers to stop, waits for outstanding
// requests, shuts the completion queue down, releases every cached channel
// and finally joins the worker threads.
GRPCClient::~GRPCClient() {
  stopped_ = true;
  Wait();
  cq_.Shutdown();

  {
    std::lock_guard<std::mutex> guard(chan_mutex_);
    for (auto& entry : channels_) {
      entry.second.reset();
    }
    channels_.clear();
  }

  for (auto& worker : client_threads_) {
    worker->join();
  }
}

71 72 73 74 75
// Serializes variable `var_name` found in `scope` and sends it to endpoint
// `ep` through the SendVariable RPC. The serialization and the call are
// scheduled with framework::Async. When FLAGS_rpc_retry_times > 0 the function
// waits for the handle and re-issues the request while the handle reports
// should_retry, up to FLAGS_rpc_retry_times times. Returns the handle of the
// last issued request.
VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
                                      const platform::DeviceContext& ctx,
                                      const framework::Scope& scope,
                                      const std::string& var_name,
                                      int64_t time_out) {
  // Local copies / raw pointers that the async task captures by value.
  const platform::DeviceContext* dev_ctx = &ctx;
  const framework::Scope* scope_ptr = &scope;
  const std::string endpoint = ep;
  const std::string name = var_name;
  const auto channel = GetChannel(endpoint);
  const std::string method = kSendRPC;

  int attempts = 0;
  for (;;) {
    SendProcessor* proc = new SendProcessor(channel);
    VarHandlePtr handle(new VarHandle(ep, method, name, dev_ctx, scope_ptr));
    proc->Prepare(handle, time_out);

    framework::Async([name, scope_ptr, dev_ctx, proc, method, handle, this] {
      auto* variable = scope_ptr->FindVar(name);

      ::grpc::ByteBuffer request;
      SerializeToByteBuffer(name, variable, *dev_ctx, &request, "",
                            trainer_id_);

      VLOG(3) << proc->GetVarHandlePtr()->String() << " begin";

      // stub context
      proc->response_call_back_ = nullptr;

      platform::RecordRPCEvent record_event(method);

      auto call = proc->stub_g_.PrepareUnaryCall(
          proc->context_.get(), "/sendrecv.SendRecvService/SendVariable",
          request, &cq_);
      call->StartCall();
      // The processor pointer is the completion-queue tag.
      call->Finish(&proc->reply_, &proc->status_,
                   reinterpret_cast<void*>(proc));

      if (UNLIKELY(platform::IsProfileEnabled())) {
        handle->Wait();
      }
    });
    req_count_++;

    const bool retry_allowed =
        FLAGS_rpc_retry_times > 0 && attempts < FLAGS_rpc_retry_times;
    if (retry_allowed) {
      handle->Wait();
      if (handle->should_retry) {
        VLOG(3) << "rpc call failed, retry times " << attempts;
        attempts++;
        // Sleep 0-4 ms (random) before issuing the request again.
        std::random_device rd;
        std::this_thread::sleep_for(std::chrono::milliseconds(rd() % 5));
        continue;
      }
    }

    return handle;
  }
}

void ProcGetResponse(const VarHandle& var_h,
131
                     const ::grpc::ByteBuffer& ret_msg) {
132
  VLOG(4) << "ProcGetResponse";
T
typhoonzero 已提交
133
  framework::Variable* outvar = nullptr;
W
Wu Yi 已提交
134 135 136 137
  // get response's trainer_id is not used
  int trainer_id;
  DeserializeFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar,
                            &trainer_id);
138 139
}

140 141 142 143 144 145 146 147 148
// Response callback used by the send-and-recv RPC: deserializes `ret_msg`
// with DeserializeRecvFromByteBuffer using the context and scope of `var_h`.
void ProcGetRecvResponse(const VarHandle& var_h,
                         const ::grpc::ByteBuffer& ret_msg) {
  VLOG(4) << "ProcGetRecvResponse";
  // Filled by the deserializer; not read afterwards.
  int unused_trainer_id;
  framework::Variable* out_var = nullptr;
  DeserializeRecvFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &out_var,
                                &unused_trainer_id);
}

149 150 151
template <typename T>
void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
  ::grpc::Slice slice(proto.ByteSizeLong());
Q
qiaolongfei 已提交
152
  proto.SerializeWithCachedSizesToArray(const_cast<uint8_t*>(slice.begin()));
153 154
  ::grpc::ByteBuffer tmp(&slice, 1);
  result->Swap(&tmp);
G
gongweibao 已提交
155 156
}

157 158 159 160
// Requests variable `var_name` from endpoint `ep` via the GetVariable RPC.
// Thin wrapper that forwards to _AsyncGetVar with method kGetRPC.
VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
                                     const platform::DeviceContext& ctx,
                                     const framework::Scope& scope,
                                     const std::string& var_name,
                                     const std::string& out_varname,
                                     const std::string& table_name,
                                     int64_t time_out) {
  const char* const rpc_path = "/sendrecv.SendRecvService/GetVariable";
  return _AsyncGetVar(ep, ctx, scope, kGetRPC, var_name, out_varname, rpc_path,
                      table_name, time_out);
}

169 170 171 172 173 174 175 176 177
// Requests a variable via the GetVariableNoBarrier RPC. The requested name is
// `var_name` with WITHOUT_BARRIER_MESSAGE appended; the table name is empty.
VarHandlePtr GRPCClient::AsyncGetVarNoBarrier(
    const std::string& ep, const platform::DeviceContext& ctx,
    const framework::Scope& scope, const std::string& var_name,
    const std::string& out_varname, int64_t time_out) {
  const std::string request_name =
      string::Sprintf("%s%s", var_name, WITHOUT_BARRIER_MESSAGE);
  const char* const rpc_path = "/sendrecv.SendRecvService/GetVariableNoBarrier";

  return _AsyncGetVar(ep, ctx, scope, kGetNoBarrierRPC, request_name,
                      out_varname, rpc_path, "", time_out);
}

181 182 183 184
// Requests `var_name` via the GetMonomerVariable RPC. The same name is used
// for both the requested and the output variable; the table name is empty.
VarHandlePtr GRPCClient::AsyncGetMonomerVariable(
    const std::string& ep, const platform::DeviceContext& ctx,
    const framework::Scope& scope, const std::string& var_name,
    int64_t time_out) {
  const char* const rpc_path = "/sendrecv.SendRecvService/GetMonomerVariable";
  return _AsyncGetVar(ep, ctx, scope, kGetMonomerRPC, var_name, var_name,
                      rpc_path, "", time_out);
}

190 191 192 193
// Shared implementation of the "get variable" RPCs. Builds a VariableMessage
// (varname, out_varname, trainer id, table name), issues it on `rpc_path` to
// endpoint `ep` from a framework::Async task, and installs ProcGetResponse as
// the response callback. When FLAGS_rpc_retry_times > 0 the function waits
// for the handle and re-issues the request while the handle reports
// should_retry, up to FLAGS_rpc_retry_times times. Returns the handle of the
// last issued request.
VarHandlePtr GRPCClient::_AsyncGetVar(
    const std::string& ep, const platform::DeviceContext& ctx,
    const framework::Scope& scope, const std::string& method,
    const std::string& var_name, const std::string& out_varname,
    const std::string& rpc_path, const std::string& table_name,
    int64_t time_out) {
  // Local copies / raw pointers that the async task captures by value.
  const platform::DeviceContext* dev_ctx = &ctx;
  const framework::Scope* scope_ptr = &scope;
  const std::string endpoint = ep;
  const std::string request_name = var_name;
  const std::string output_name = out_varname;
  const std::string table = table_name;
  const auto channel = GetChannel(endpoint);

  int attempts = 0;
  for (;;) {
    GetProcessor* proc = new GetProcessor(channel);

    VarHandlePtr handle(
        new VarHandle(ep, method, output_name, dev_ctx, scope_ptr));
    proc->Prepare(handle, time_out);

    framework::Async([request_name, output_name, table, proc, method, dev_ctx,
                      handle, rpc_path, this] {
      // prepare input
      sendrecv::VariableMessage message;
      message.set_varname(request_name);
      message.set_out_varname(output_name);
      message.set_trainer_id(trainer_id_);
      message.set_table_name(table);
      ::grpc::ByteBuffer payload;
      RequestToByteBuffer<sendrecv::VariableMessage>(message, &payload);

      VLOG(3) << proc->GetVarHandlePtr()->String() << " begin";

      // stub context
      proc->response_call_back_ = ProcGetResponse;

      platform::RecordRPCEvent record_event(method);

      auto call = proc->stub_g_.PrepareUnaryCall(proc->context_.get(),
                                                 rpc_path, payload, &cq_);
      call->StartCall();
      // The processor pointer is the completion-queue tag.
      call->Finish(&proc->reply_, &proc->status_,
                   reinterpret_cast<void*>(proc));

      if (UNLIKELY(platform::IsProfileEnabled())) {
        handle->Wait();
      }
    });
    req_count_++;

    const bool retry_allowed =
        FLAGS_rpc_retry_times > 0 && attempts < FLAGS_rpc_retry_times;
    if (retry_allowed) {
      handle->Wait();
      if (handle->should_retry) {
        VLOG(3) << "rpc call failed, retry times " << attempts;
        attempts++;
        // Sleep 0-4 ms (random) before issuing the request again.
        std::random_device rd;
        std::this_thread::sleep_for(std::chrono::milliseconds(rd() % 5));
        continue;
      }
    }

    return handle;
  }
}

256 257 258 259 260
// Serializes `in_var_name` from `scope` and issues the PrefetchVariable RPC to
// endpoint `ep`; ProcGetResponse handles the reply. Unlike the other variable
// RPCs in this file, the request is built and started on the calling thread.
// When FLAGS_rpc_retry_times > 0 the function waits for the handle and
// re-issues the request while the handle reports should_retry, up to
// FLAGS_rpc_retry_times times. Returns the handle of the last issued request.
//
// NOTE(review): `time_out` is not used; the request is prepared with
// kPrefetchTimeout. Looks deliberate -- confirm before changing.
VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
                                          const platform::DeviceContext& ctx,
                                          const framework::Scope& scope,
                                          const std::string& in_var_name,
                                          const std::string& out_var_name,
                                          const std::string& table_name,
                                          int64_t time_out) {
  const platform::DeviceContext* dev_ctx = &ctx;
  const framework::Scope* scope_ptr = &scope;
  const std::string endpoint = ep;
  const std::string input_name = in_var_name;
  const std::string output_name = out_var_name;
  const std::string table = table_name;
  const auto channel = GetChannel(endpoint);

  const std::string method = kPrefetchRPC;
  int attempts = 0;

  for (;;) {
    GetProcessor* proc = new GetProcessor(channel);
    VarHandlePtr handle(
        new VarHandle(ep, method, output_name, dev_ctx, scope_ptr));
    proc->Prepare(handle, kPrefetchTimeout);

    auto* input_var = scope_ptr->FindVar(input_name);

    ::grpc::ByteBuffer request;
    // The trainer id argument is passed as 0 here.
    SerializeToByteBuffer(input_name, input_var, *dev_ctx, &request,
                          output_name, 0, table);

    VLOG(3) << proc->GetVarHandlePtr()->String() << " begin";

    // stub context
    proc->response_call_back_ = ProcGetResponse;

    platform::RecordRPCEvent record_event(method);

    auto call = proc->stub_g_.PrepareUnaryCall(
        proc->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable",
        request, &cq_);
    call->StartCall();
    // The processor pointer is the completion-queue tag.
    call->Finish(&proc->reply_, &proc->status_, static_cast<void*>(proc));

    if (UNLIKELY(platform::IsProfileEnabled())) {
      handle->Wait();
    }

    req_count_++;

    const bool retry_allowed =
        FLAGS_rpc_retry_times > 0 && attempts < FLAGS_rpc_retry_times;
    if (retry_allowed) {
      handle->Wait();
      if (handle->should_retry) {
        VLOG(3) << "rpc call failed, retry times " << attempts;
        attempts++;
        // Sleep 0-4 ms (random) before issuing the request again.
        std::random_device rd;
        std::this_thread::sleep_for(std::chrono::milliseconds(rd() % 5));
        continue;
      }
    }

    return handle;
  }
}

319 320
// Sends a VariableMessage whose varname is BATCH_BARRIER_MESSAGE to endpoint
// `ep` through AsyncSendVariable and returns the handle tracking the request.
VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
                                               int64_t time_out) {
  const auto channel = GetChannel(ep);

  BatchBarrierProcessor* proc = new BatchBarrierProcessor(channel);
  const std::string method = kBatchBarrierRPC;
  VarHandlePtr handle(
      new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr));
  proc->Prepare(handle, time_out);

  sendrecv::VariableMessage message;
  message.set_varname(BATCH_BARRIER_MESSAGE);

  platform::RecordRPCEvent record_event(method);

  auto rpc = proc->stub_->AsyncSendVariable(proc->context_.get(), message,
                                            &cq_);
  // The processor pointer is the completion-queue tag.
  rpc->Finish(&proc->reply_, &proc->status_, reinterpret_cast<void*>(proc));
  req_count_++;

  // When profiling is enabled the call blocks until the handle completes.
  if (UNLIKELY(platform::IsProfileEnabled())) {
    handle->Wait();
  }

  return handle;
}
Y
Yancey 已提交
344

345 346
// Sends a VariableMessage whose varname is FETCH_BARRIER_MESSAGE to endpoint
// `ep` through AsyncGetVariable and returns the handle tracking the request.
VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
                                               int64_t time_out) {
  const auto channel = GetChannel(ep);
  FetchBarrierProcessor* proc = new FetchBarrierProcessor(channel);
  const std::string method = kFetchBarrierRPC;
  VarHandlePtr handle(
      new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
  proc->Prepare(handle, time_out);

  sendrecv::VariableMessage message;
  message.set_varname(FETCH_BARRIER_MESSAGE);

  platform::RecordRPCEvent record_event(method);

  auto rpc = proc->stub_->AsyncGetVariable(proc->context_.get(), message,
                                           &cq_);
  // The processor pointer is the completion-queue tag.
  rpc->Finish(&proc->reply_, &proc->status_, reinterpret_cast<void*>(proc));
  req_count_++;

  // When profiling is enabled the call blocks until the handle completes.
  if (UNLIKELY(platform::IsProfileEnabled())) {
    handle->Wait();
  }

  return handle;
}

370 371 372 373 374
// Sends a VariableMessage carrying `var_name` to endpoint `ep` through
// AsyncGetMonomerBarrier and returns the handle tracking the request.
VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
                                                const std::string& var_name,
                                                int64_t time_out) {
  const auto channel = GetChannel(ep);
  BatchBarrierProcessor* proc = new BatchBarrierProcessor(channel);
  const std::string method = kSendMonomerFetchBarrierRPC;
  VarHandlePtr handle(new VarHandle(ep, method, var_name, nullptr, nullptr));
  proc->Prepare(handle, time_out);

  VLOG(30) << proc->GetVarHandlePtr()->String() << " begin";

  sendrecv::VariableMessage message;
  message.set_varname(var_name);

  platform::RecordRPCEvent record_event(method);

  auto rpc = proc->stub_->AsyncGetMonomerBarrier(proc->context_.get(), message,
                                                 &cq_);
  // The processor pointer is the completion-queue tag.
  rpc->Finish(&proc->reply_, &proc->status_, reinterpret_cast<void*>(proc));
  req_count_++;

  // When profiling is enabled the call blocks until the handle completes.
  if (UNLIKELY(platform::IsProfileEnabled())) {
    handle->Wait();
  }

  return handle;
}

397 398
// Sends a VariableMessage with this client's trainer id and varname
// COMPLETE_MESSAGE to endpoint `ep` through AsyncSendVariable. Returns the
// handle tracking the request.
VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
                                           int64_t time_out) {
  const auto channel = GetChannel(ep);

  BatchBarrierProcessor* proc = new BatchBarrierProcessor(channel);
  const std::string method = kSendCompleteRPC;
  VarHandlePtr handle(
      new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr));
  proc->Prepare(handle, time_out);

  sendrecv::VariableMessage message;
  message.set_trainer_id(trainer_id_);
  message.set_varname(COMPLETE_MESSAGE);

  platform::RecordRPCEvent record_event(method);

  auto rpc = proc->stub_->AsyncSendVariable(proc->context_.get(), message,
                                            &cq_);
  // The processor pointer is the completion-queue tag.
  rpc->Finish(&proc->reply_, &proc->status_, reinterpret_cast<void*>(proc));
  req_count_++;

  // When profiling is enabled the call blocks until the handle completes.
  if (UNLIKELY(platform::IsProfileEnabled())) {
    handle->Wait();
  }

  return handle;
}

423
// Sends a checkpoint notification to endpoint `ep` through
// AsyncCheckpointNotify. The request reuses VariableMessage fields:
// varname <- `varname`, table_name <- decimal string of `mode`,
// out_varname <- `dirname`. Returns the handle tracking the request.
VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
                                               const std::string& dirname,
                                               const std::string& varname,
                                               const int mode,
                                               int64_t time_out) {
  const auto channel = GetChannel(ep);

  CheckpointNotifyProcessor* proc = new CheckpointNotifyProcessor(channel);

  const std::string method = kCheckPointNotifyRPC;

  VarHandlePtr handle(
      new VarHandle(ep, method, CHECKPOINT_SAVE_MESSAGE, nullptr, nullptr));
  proc->Prepare(handle, time_out);

  sendrecv::VariableMessage message;
  message.set_varname(varname);
  message.set_table_name(std::to_string(mode));
  message.set_out_varname(dirname);

  platform::RecordRPCEvent record_event(method);

  auto rpc = proc->stub_->AsyncCheckpointNotify(proc->context_.get(), message,
                                                &cq_);
  // The processor pointer is the completion-queue tag.
  rpc->Finish(&proc->reply_, &proc->status_, reinterpret_cast<void*>(proc));
  req_count_++;

  // When profiling is enabled the call blocks until the handle completes.
  if (UNLIKELY(platform::IsProfileEnabled())) {
    handle->Wait();
  }

  return handle;
}

1
123malin 已提交
456 457 458 459 460 461 462 463 464
// Serializes variable `var_name` found in `scope` and sends it to endpoint
// `ep` through the DistributeNotify RPC. Serialization and the call run in a
// framework::Async task; no retry is attempted. Returns the handle tracking
// the request.
VarHandlePtr GRPCClient::AsyncDistributeNotify(
    const std::string& ep, const platform::DeviceContext& ctx,
    const framework::Scope& scope, const std::string& var_name,
    int64_t time_out) {
  // Local copies / raw pointers that the async task captures by value.
  const platform::DeviceContext* dev_ctx = &ctx;
  const framework::Scope* scope_ptr = &scope;
  const std::string endpoint = ep;
  const std::string name = var_name;
  const auto channel = GetChannel(endpoint);
  const std::string method = kRequestNotify;

  SendProcessor* proc = new SendProcessor(channel);
  VarHandlePtr handle(new VarHandle(ep, method, name, dev_ctx, scope_ptr));
  proc->Prepare(handle, time_out);

  framework::Async([name, scope_ptr, dev_ctx, proc, method, handle, this] {
    auto* variable = scope_ptr->FindVar(name);

    ::grpc::ByteBuffer request;
    SerializeToByteBuffer(name, variable, *dev_ctx, &request, "", trainer_id_);

    VLOG(3) << proc->GetVarHandlePtr()->String() << " begin";

    // stub context
    proc->response_call_back_ = nullptr;

    platform::RecordRPCEvent record_event(method);

    auto call = proc->stub_g_.PrepareUnaryCall(
        proc->context_.get(), "/sendrecv.SendRecvService/DistributeNotify",
        request, &cq_);
    call->StartCall();
    // The processor pointer is the completion-queue tag.
    call->Finish(&proc->reply_, &proc->status_, reinterpret_cast<void*>(proc));
  });
  req_count_++;

  // When profiling is enabled the call blocks until the handle completes.
  if (UNLIKELY(platform::IsProfileEnabled())) {
    handle->Wait();
  }

  return handle;
}

499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526
// Sends `send_var_name` to endpoint `ep` through the SendAndRecvVariable RPC;
// the reply is handled by ProcGetRecvResponse. Two handles are created per
// attempt: one for the send variable (returned) and one for the recv variable
// (handed to the processor via RecvPrepare). Before serialization the send
// variable's LoD is reset to empty. When FLAGS_rpc_retry_times > 0 the
// function waits for the send handle and re-issues the request while it
// reports should_retry, up to FLAGS_rpc_retry_times times.
VarHandlePtr GRPCClient::AsyncSendAndRecv(const std::string& ep,
                                          const platform::DeviceContext& ctx,
                                          const framework::Scope& scope,
                                          const std::string& send_var_name,
                                          const std::string& recv_var_name,
                                          const std::string& table_name,
                                          int64_t time_out) {
  // Local copies / raw pointers that the async task captures by value.
  const platform::DeviceContext* dev_ctx = &ctx;
  const framework::Scope* scope_ptr = &scope;
  const std::string endpoint = ep;
  const std::string send_name = send_var_name;
  const std::string recv_name = recv_var_name;
  const std::string table = table_name;
  const auto channel = GetChannel(endpoint);
  const std::string method = kSendAndRecvRPC;
  VLOG(4) << "GRPCClient::SendAndRecv Begin ,Send_var_name: " << send_name
          << " Recv_var_name: " << recv_name;

  int attempts = 0;
  for (;;) {
    SendAndRecvProcessor* proc = new SendAndRecvProcessor(channel);
    VarHandlePtr handle(
        new VarHandle(ep, method, send_name, dev_ctx, scope_ptr));
    VarHandlePtr recv_handle(
        new VarHandle(ep, method, recv_name, dev_ctx, scope_ptr));
    proc->Prepare(handle, time_out);
    proc->RecvPrepare(recv_handle);

    framework::Async([send_name, recv_name, table, scope_ptr, dev_ctx, proc,
                      method, handle, this] {
      auto* send_var = scope_ptr->FindVar(send_name);
      send_var->GetMutable<framework::LoDTensor>()->set_lod({});
      ::grpc::ByteBuffer payload;
      VLOG(4) << "SerializeToByteBuffer: send_var_name_val: " << send_name
              << " recv_var_name_val: " << recv_name;
      SerializeToByteBuffer(send_name, send_var, *dev_ctx, &payload, recv_name,
                            trainer_id_, table);

      VLOG(3) << proc->GetVarHandlePtr()->String() << " begin";

      // stub context
      proc->response_call_back_ = ProcGetRecvResponse;

      platform::RecordRPCEvent record_event(method);

      auto call = proc->stub_g_.PrepareUnaryCall(
          proc->context_.get(),
          "/sendrecv.SendRecvService/SendAndRecvVariable", payload, &cq_);
      call->StartCall();
      // The processor pointer is the completion-queue tag.
      call->Finish(&proc->reply_, &proc->status_,
                   reinterpret_cast<void*>(proc));

      if (UNLIKELY(platform::IsProfileEnabled())) {
        handle->Wait();
      }
    });
    req_count_++;

    const bool retry_allowed =
        FLAGS_rpc_retry_times > 0 && attempts < FLAGS_rpc_retry_times;
    if (retry_allowed) {
      handle->Wait();
      if (handle->should_retry) {
        VLOG(3) << "rpc call failed, retry times " << attempts;
        attempts++;
        // Sleep 0-4 ms (random) before issuing the request again.
        std::random_device rd;
        std::this_thread::sleep_for(std::chrono::milliseconds(rd() % 5));
        continue;
      }
    }

    return handle;
  }
}

Y
Yancey1989 已提交
572
bool GRPCClient::Wait() {
W
Wu Yi 已提交
573
  std::unique_lock<std::mutex> lk(sync_mutex_);
Y
Yancey1989 已提交
574 575
  sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); });
  return ok_;
G
gongweibao 已提交
576 577
}

T
tangwei12 已提交
578 579 580 581 582 583 584 585 586 587 588 589
// Decides whether a failed RPC is reported as retryable: always for the
// prefetch method, otherwise only when the error is DEADLINE_EXCEEDED.
inline bool ShouldRetry(const std::string& method, int error_code) {
  const bool is_prefetch = (method == kPrefetchRPC);
  return is_prefetch || error_code == grpc::StatusCode::DEADLINE_EXCEEDED;
}

G
gongweibao 已提交
590
void GRPCClient::Proceed() {
W
Wu Yi 已提交
591
  void* tag = nullptr;
G
gongweibao 已提交
592 593
  bool ok = false;

M
minqiyang 已提交
594
  VLOG(3) << "GRPCClient Proceed begin";
M
minqiyang 已提交
595
  while (!stopped_ && cq_.Next(&tag, &ok)) {
W
Wu Yi 已提交
596 597
    BaseProcessor* c = static_cast<BaseProcessor*>(tag);
    GPR_ASSERT(ok);
M
MRXLT 已提交
598 599
    PADDLE_ENFORCE_NOT_NULL(
        c, platform::errors::PreconditionNotMet("Make BaseProcessor failed."));
G
gongweibao 已提交
600

W
Wu Yi 已提交
601
    if (c->status_.ok()) {
M
minqiyang 已提交
602
      VLOG(3) << c->GetVarHandlePtr()->String() << " process";
W
Wu Yi 已提交
603
      c->Process();
T
tangwei12 已提交
604 605 606
    } else if (ShouldRetry(c->GetVarHandlePtr()->method(),
                           c->status_.error_code())) {
      VLOG(0) << c->GetVarHandlePtr()->String()
607 608 609 610 611 612
              << " meets grpc error, error_code:" << c->status_.error_code()
              << " error_message:" << c->status_.error_message()
              << " error_details:" << c->status_.error_details()
              << " should retry!";
      c->GetVarHandlePtr()->should_retry = true;
      c->Finish(false);
W
Wu Yi 已提交
613
    } else {
614 615 616 617 618
      PADDLE_THROW(platform::errors::External(
          "%s meets grpc error, error_code is %d, error message is %s, error "
          "details is %s.",
          c->GetVarHandlePtr()->String(), c->status_.error_code(),
          c->status_.error_message(), c->status_.error_details()));
619
      c->Finish(false);
W
Wu Yi 已提交
620
    }
621

G
gongweibao 已提交
622
    bool notify = false;
W
Wu Yi 已提交
623 624 625
    {
      std::lock_guard<std::mutex> lk(sync_mutex_);
      req_count_--;
G
gongweibao 已提交
626 627 628 629 630 631 632
      notify = (req_count_ <= 0 || !c->status_.ok());
    }

    delete c;

    if (notify) {
      sync_cond_.notify_all();
W
Wu Yi 已提交
633
    }
G
gongweibao 已提交
634
  }
635 636 637 638 639 640 641

  // Last log message
  // Avoid using VLOG() and LOG(): in the destructor of google::LogMessage() a
  // static Mutex log_mutex is used for synchronization, which might have been
  // destructed at this moment.
  if (FLAGS_v >= 3) {
    std::string msg("GRPCClient Proceed end");
642
    fwrite(msg.c_str(), msg.length(), 1, stderr);
643
  }
G
gongweibao 已提交
644
}
W
Wu Yi 已提交
645

G
gongweibao 已提交
646
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
W
Wu Yi 已提交
647
  std::lock_guard<std::mutex> guard(chan_mutex_);
Y
Yancey1989 已提交
648
  auto it = channels_.find(ep);
G
gongweibao 已提交
649 650 651 652
  if (it != channels_.end()) {
    return it->second;
  }

W
Wu Yi 已提交
653
  // Channel configurations:
G
gongweibao 已提交
654
  grpc::ChannelArguments args;
W
Wu Yi 已提交
655
  args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 2000);
656 657 658
  if (FLAGS_rpc_disable_reuse_port) {
    args.SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0);
  }
659
  args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
G
gongweibao 已提交
660 661 662
  args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());

T
typhoonzero 已提交
663 664
  auto ch =
      grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
Y
Yancey1989 已提交
665
  channels_[ep] = ch;
G
gongweibao 已提交
666 667 668
  return ch;
}

669
}  // namespace distributed
G
gongweibao 已提交
670 671
}  // namespace operators
}  // namespace paddle