/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15 16
#include <limits>

G
gongweibao 已提交
17
#include "glog/logging.h"  // For VLOG
Y
Yi Wang 已提交
18
#include "paddle/fluid/framework/threadpool.h"
G
gongweibao 已提交
19
#include "paddle/fluid/operators/distributed/grpc_client.h"
20
#include "paddle/fluid/operators/distributed/grpc_serde.h"
21
#include "paddle/fluid/operators/distributed/request_handler.h"
P
peizhilin 已提交
22
#include "paddle/fluid/platform/port.h"
X
Xin Pan 已提交
23
#include "paddle/fluid/platform/profiler.h"
24

25 26
DECLARE_bool(rpc_disable_reuse_port);

G
gongweibao 已提交
27 28
namespace paddle {
namespace operators {
29
namespace distributed {
G
gongweibao 已提交
30

31
void GRPCClient::InitImpl() {
W
Wu Yi 已提交
32 33
  // start the client process thread
  // TODO(wuyi): can make this in a threadpool
34 35
  PADDLE_ENFORCE(client_thread_ == nullptr,
                 "please not re init proceed thread");
G
gongweibao 已提交
36
  client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
W
Wu Yi 已提交
37 38
}

Y
Yancey1989 已提交
39
void GRPCClient::SendComplete() {
Y
Yancey1989 已提交
40 41 42
  std::unique_lock<std::mutex> lk(completed_mutex_);
  if (!completed_) {
    for (auto& it : channels_) {
M
minqiyang 已提交
43
      VLOG(3) << "send complete message to " << it.first;
Y
Yancey1989 已提交
44 45 46 47
      this->AsyncSendComplete(it.first);
    }
    PADDLE_ENFORCE(this->Wait(), "internal grpc error");
    completed_ = true;
W
Wu Yi 已提交
48 49 50
  }
}

G
gongweibao 已提交
51
// Tears the client down: waits for outstanding requests, stops the worker
// loop, shuts down the completion queue, drops all cached channels and joins
// the worker thread.
GRPCClient::~GRPCClient() {
  // Drain in-flight requests BEFORE raising stopped_. Proceed() re-checks
  // stopped_ after every completion event, so raising it first would let the
  // worker exit after a single event and leave Wait() blocked forever when
  // more than one request is still outstanding.
  Wait();
  stopped_ = true;
  cq_.Shutdown();
  {
    std::lock_guard<std::mutex> guard(chan_mutex_);
    for (auto& it : channels_) {
      it.second.reset();
    }
    channels_.clear();
  }
  // InitImpl() may never have run, in which case there is no thread to join.
  if (client_thread_ != nullptr && client_thread_->joinable()) {
    client_thread_->join();
  }
}

65 66 67 68 69
// Asynchronously sends variable `var_name` from `scope` to endpoint `ep`.
//
// The variable is looked up, serialized and handed to gRPC from a task
// submitted through framework::AsyncIO; this function returns the handle
// after scheduling that task. `ctx` and `scope` are captured by pointer, so
// they must stay alive until the task has run.
//
// Returns the VarHandle associated with the request. The SendProcessor
// allocated here is used as the completion-queue tag and is deleted by
// Proceed() once its completion event has been handled.
VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
                                      const platform::DeviceContext& ctx,
                                      const framework::Scope& scope,
                                      const std::string& var_name,
                                      int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string var_name_val = var_name;
  const framework::Scope* p_scope = &scope;
  const auto ch = GetChannel(ep_val);
  SendProcessor* s = new SendProcessor(ch);
  const std::string method = "SendRPC";
  VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
  s->Prepare(h, time_out);

  framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
    auto* var = p_scope->FindVar(var_name_val);

    ::grpc::ByteBuffer req;
    SerializeToByteBuffer(var_name_val, var, *p_ctx, &req, "", trainer_id_);

    VLOG(3) << s->GetVarHandlePtr()->String() << " begin";

    // stub context
    // A send has no payload to post-process, so no response callback is set.
    s->response_call_back_ = nullptr;

    platform::RecordRPCEvent record_event(method, p_ctx);

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
    call->StartCall();
    // `s` is the tag that Proceed() receives back from the completion queue.
    call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));

    // When profiling, block until the request finishes so the recorded RPC
    // event spans the whole call.
    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
  });
  req_count_++;

  return h;
}

void ProcGetResponse(const VarHandle& var_h,
108
                     const ::grpc::ByteBuffer& ret_msg) {
109
  VLOG(100) << "ProcGetResponse";
T
typhoonzero 已提交
110
  framework::Variable* outvar = nullptr;
W
Wu Yi 已提交
111 112 113 114
  // get response's trainer_id is not used
  int trainer_id;
  DeserializeFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar,
                            &trainer_id);
115 116 117 118 119
}

template <typename T>
void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
  ::grpc::Slice slice(proto.ByteSizeLong());
Q
qiaolongfei 已提交
120
  proto.SerializeWithCachedSizesToArray(const_cast<uint8_t*>(slice.begin()));
121 122
  ::grpc::ByteBuffer tmp(&slice, 1);
  result->Swap(&tmp);
G
gongweibao 已提交
123 124
}

125 126 127 128 129
// Asynchronously fetches variable `var_name` from endpoint `ep` through the
// GetVariable RPC. Thin wrapper that forwards to _AsyncGetVar().
VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
                                     const platform::DeviceContext& ctx,
                                     const framework::Scope& scope,
                                     const std::string& var_name,
                                     int64_t time_out) {
  return _AsyncGetVar(ep, ctx, scope, var_name,
                      "/sendrecv.SendRecvService/GetVariable", time_out);
}

// Asynchronously fetches variable `var_name` from endpoint `ep` through the
// GetMonomerVariable RPC. Thin wrapper that forwards to _AsyncGetVar().
VarHandlePtr GRPCClient::AsyncGetMonomerVariable(
    const std::string& ep, const platform::DeviceContext& ctx,
    const framework::Scope& scope, const std::string& var_name,
    int64_t time_out) {
  const std::string monomer_rpc_path =
      "/sendrecv.SendRecvService/GetMonomerVariable";
  return _AsyncGetVar(ep, ctx, scope, var_name, monomer_rpc_path, time_out);
}

// Shared implementation of the "get variable" style RPCs.
//
// Builds a VariableMessage naming `var_name` (plus this client's trainer id),
// and issues it as a unary call on `rpc_path` to endpoint `ep` from a task
// submitted through framework::AsyncIO. The response is handled by
// ProcGetResponse. `ctx` is captured by pointer, so it must stay alive until
// the task has run.
//
// Returns the VarHandle associated with the request. The GetProcessor
// allocated here is the completion-queue tag and is deleted by Proceed().
VarHandlePtr GRPCClient::_AsyncGetVar(const std::string& ep,
                                      const platform::DeviceContext& ctx,
                                      const framework::Scope& scope,
                                      const std::string& var_name,
                                      const std::string& rpc_path,
                                      int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string var_name_val = var_name;
  const framework::Scope* p_scope = &scope;
  const auto ch = GetChannel(ep_val);
  GetProcessor* s = new GetProcessor(ch);
  const std::string method = "GetRPC";
  VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
  s->Prepare(h, time_out);

  framework::AsyncIO([var_name_val, s, method, p_ctx, h, rpc_path, this] {
    // prepare input
    sendrecv::VariableMessage req;
    req.set_varname(var_name_val);
    req.set_trainer_id(trainer_id_);
    ::grpc::ByteBuffer buf;
    RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);

    VLOG(3) << s->GetVarHandlePtr()->String() << " begin";

    // stub context
    s->response_call_back_ = ProcGetResponse;

    platform::RecordRPCEvent record_event(method, p_ctx);

    auto call =
        s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_);
    call->StartCall();
    // `s` is the tag that Proceed() receives back from the completion queue.
    call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));

    // When profiling, block until the request finishes so the recorded RPC
    // event spans the whole call.
    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
  });

  req_count_++;

  return h;
}

188 189 190 191 192
// Asynchronously issues a PrefetchVariable RPC to endpoint `ep`.
//
// The input variable `in_var_name` is looked up in `scope` and serialized
// together with `out_var_name` and `table_name`; the response is handled by
// ProcGetResponse. The work runs in a task submitted through
// framework::AsyncIO. `ctx` and `scope` are captured by pointer, so they must
// stay alive until the task has run.
//
// Returns the VarHandle (labelled with `out_var_name`) for the request. The
// GetProcessor allocated here is the completion-queue tag and is deleted by
// Proceed().
VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
                                          const platform::DeviceContext& ctx,
                                          const framework::Scope& scope,
                                          const std::string& in_var_name,
                                          const std::string& out_var_name,
                                          const std::string& table_name,
                                          int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string in_var_name_val = in_var_name;
  const std::string out_var_name_val = out_var_name;
  const std::string table_name_val = table_name;
  const framework::Scope* p_scope = &scope;
  const auto ch = GetChannel(ep_val);
  GetProcessor* s = new GetProcessor(ch);

  const std::string method = "PrefetchRPC";

  VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
  s->Prepare(h, time_out);

  // The endpoint string is not needed inside the task, so it is not captured.
  framework::AsyncIO([in_var_name_val, out_var_name_val, p_scope, p_ctx, s,
                      method, h, table_name_val, this] {
    auto* var = p_scope->FindVar(in_var_name_val);

    ::grpc::ByteBuffer req;
    // NOTE(review): the trainer id argument is a literal 0 here, whereas
    // AsyncSendVar passes trainer_id_ - confirm this is intentional.
    SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val,
                          0, table_name_val);

    VLOG(3) << s->GetVarHandlePtr()->String() << " begin";

    // stub context
    s->response_call_back_ = ProcGetResponse;

    platform::RecordRPCEvent record_event(method, p_ctx);

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
        &cq_);
    call->StartCall();
    // `s` is the tag that Proceed() receives back from the completion queue.
    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

    // When profiling, block until the request finishes so the recorded RPC
    // event spans the whole call.
    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
  });

  req_count_++;
  return h;
}

239 240
// Asynchronously sends the batch-barrier message to endpoint `ep`.
//
// The barrier is encoded as a VariableMessage whose varname is
// BATCH_BARRIER_MESSAGE and is issued through the SendVariable stub on the
// calling thread. Returns the VarHandle for the request; the processor
// allocated here is the completion-queue tag and is deleted by Proceed().
VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
                                               int64_t time_out) {
  const auto ch = GetChannel(ep);

  BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
  const std::string method = "BatchBarrierRPC";
  VarHandlePtr h(
      new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(BATCH_BARRIER_MESSAGE);

  platform::RecordRPCEvent record_event(method, nullptr);

  auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
  req_count_++;

  // When profiling, block until the request finishes so the recorded RPC
  // event spans the whole call.
  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}
Y
Yancey 已提交
264

265 266
// Asynchronously sends the fetch-barrier message to endpoint `ep`.
//
// The barrier is encoded as a VariableMessage whose varname is
// FETCH_BARRIER_MESSAGE and is issued through the GetVariable stub on the
// calling thread. Returns the VarHandle for the request; the processor
// allocated here is the completion-queue tag and is deleted by Proceed().
VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
                                               int64_t time_out) {
  const auto ch = GetChannel(ep);
  FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
  const std::string method = "FetchBarrierRPC";
  VarHandlePtr h(
      new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(FETCH_BARRIER_MESSAGE);

  platform::RecordRPCEvent record_event(method, nullptr);

  auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
  req_count_++;

  // When profiling, block until the request finishes so the recorded RPC
  // event spans the whole call.
  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317
// Asynchronously issues the GetMonomerBarrier RPC for `var_name` to endpoint
// `ep` on the calling thread and returns the VarHandle for the request.
//
// NOTE(review): the handle is labelled FETCH_BARRIER_MESSAGE while the request
// itself carries `var_name` - confirm the label is intended.
VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
                                                const std::string& var_name,
                                                int64_t time_out) {
  const auto channel = GetChannel(ep);
  // Used as the completion-queue tag; Proceed() deletes it after handling
  // the completion event.
  auto* processor = new BatchBarrierProcessor(channel);
  const std::string rpc_name = "SendMonomerFetchBarrierRPC";
  VarHandlePtr handle(
      new VarHandle(ep, rpc_name, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
  processor->Prepare(handle, time_out);

  VLOG(30) << processor->GetVarHandlePtr()->String() << " begin";

  sendrecv::VariableMessage request;
  request.set_varname(var_name);

  platform::RecordRPCEvent record_event(rpc_name, nullptr);

  auto pending_call = processor->stub_->AsyncGetMonomerBarrier(
      processor->context_.get(), request, &cq_);
  pending_call->Finish(&processor->reply_, &processor->status_,
                       reinterpret_cast<void*>(processor));
  req_count_++;

  // When profiling, block until the request finishes so the recorded RPC
  // event spans the whole call.
  if (UNLIKELY(platform::IsProfileEnabled())) {
    handle->Wait();
  }

  return handle;
}

318 319
// Asynchronously sends the "complete" message to endpoint `ep`.
//
// The message is a VariableMessage whose varname is COMPLETE_MESSAGE and is
// issued through the SendVariable stub on the calling thread. Returns the
// VarHandle for the request; the processor allocated here is the
// completion-queue tag and is deleted by Proceed().
VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
                                           int64_t time_out) {
  const auto ch = GetChannel(ep);

  BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
  const std::string method = "SendCompleteRPC";
  VarHandlePtr h(new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(COMPLETE_MESSAGE);

  platform::RecordRPCEvent record_event(method, nullptr);

  auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
  req_count_++;

  // When profiling, block until the request finishes so the recorded RPC
  // event spans the whole call.
  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

343 344 345
// Asynchronously sends a checkpoint notification to endpoint `ep`.
//
// The request is a VariableMessage whose varname is CHECKPOINT_SAVE_MESSAGE
// and whose out_varname field carries `dir`. It is issued through the
// CheckpointNotify stub on the calling thread. Returns the VarHandle for the
// request; the processor allocated here is the completion-queue tag and is
// deleted by Proceed().
VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
                                               const std::string& dir,
                                               int64_t time_out) {
  const auto ch = GetChannel(ep);

  CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);

  const std::string method = "CheckPointNotifyRPC";

  VarHandlePtr h(
      new VarHandle(ep, method, CHECKPOINT_SAVE_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(CHECKPOINT_SAVE_MESSAGE);
  req.set_out_varname(dir);

  platform::RecordRPCEvent record_event(method, nullptr);

  auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
  req_count_++;

  // When profiling, block until the request finishes so the recorded RPC
  // event spans the whole call.
  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

Y
Yancey1989 已提交
373
bool GRPCClient::Wait() {
W
Wu Yi 已提交
374
  std::unique_lock<std::mutex> lk(sync_mutex_);
Y
Yancey1989 已提交
375 376
  sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); });
  return ok_;
G
gongweibao 已提交
377 378
}

G
gongweibao 已提交
379
void GRPCClient::Proceed() {
W
Wu Yi 已提交
380
  void* tag = nullptr;
G
gongweibao 已提交
381 382
  bool ok = false;

M
minqiyang 已提交
383
  VLOG(3) << "GRPCClient Proceed begin";
M
minqiyang 已提交
384
  while (!stopped_ && cq_.Next(&tag, &ok)) {
W
Wu Yi 已提交
385 386 387
    BaseProcessor* c = static_cast<BaseProcessor*>(tag);
    GPR_ASSERT(ok);
    PADDLE_ENFORCE(c);
G
gongweibao 已提交
388

W
Wu Yi 已提交
389
    if (c->status_.ok()) {
M
minqiyang 已提交
390
      VLOG(3) << c->GetVarHandlePtr()->String() << " process";
W
Wu Yi 已提交
391
      c->Process();
Y
Yancey1989 已提交
392
    } else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) {
393
      LOG(FATAL) << c->GetVarHandlePtr()->String()
G
gongweibao 已提交
394 395 396
                 << " meets grpc error, error_code:" << c->status_.error_code()
                 << " error_message:" << c->status_.error_message()
                 << " error_details:" << c->status_.error_details();
Y
Yancey1989 已提交
397 398 399 400
      {
        std::lock_guard<std::mutex> lk(sync_mutex_);
        ok_ = false;
      }
401
      c->Finish(false);
W
Wu Yi 已提交
402
    } else {
403
      LOG(FATAL) << c->GetVarHandlePtr()->String()
G
gongweibao 已提交
404 405 406 407
                 << " meets grpc error, error_code:" << c->status_.error_code()
                 << " error_message:" << c->status_.error_message()
                 << " error_details:" << c->status_.error_details();

408
      c->Finish(false);
W
Wu Yi 已提交
409
    }
410

G
gongweibao 已提交
411
    bool notify = false;
W
Wu Yi 已提交
412 413 414
    {
      std::lock_guard<std::mutex> lk(sync_mutex_);
      req_count_--;
G
gongweibao 已提交
415 416 417 418 419 420 421
      notify = (req_count_ <= 0 || !c->status_.ok());
    }

    delete c;

    if (notify) {
      sync_cond_.notify_all();
W
Wu Yi 已提交
422
    }
G
gongweibao 已提交
423
  }
M
minqiyang 已提交
424
  VLOG(3) << "GRPCClient Proceed end";
G
gongweibao 已提交
425
}
W
Wu Yi 已提交
426

G
gongweibao 已提交
427
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
W
Wu Yi 已提交
428
  std::lock_guard<std::mutex> guard(chan_mutex_);
Y
Yancey1989 已提交
429
  auto it = channels_.find(ep);
G
gongweibao 已提交
430 431 432 433
  if (it != channels_.end()) {
    return it->second;
  }

W
Wu Yi 已提交
434
  // Channel configurations:
G
gongweibao 已提交
435
  grpc::ChannelArguments args;
W
Wu Yi 已提交
436
  args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 2000);
437 438 439
  if (FLAGS_rpc_disable_reuse_port) {
    args.SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0);
  }
440
  args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
G
gongweibao 已提交
441 442 443
  args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());

T
typhoonzero 已提交
444 445
  auto ch =
      grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
Y
Yancey1989 已提交
446
  channels_[ep] = ch;
G
gongweibao 已提交
447 448 449
  return ch;
}

450
}  // namespace distributed
G
gongweibao 已提交
451 452
}  // namespace operators
}  // namespace paddle