/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15 16
#include <limits>

G
gongweibao 已提交
17
#include "glog/logging.h"  // For VLOG
Y
Yi Wang 已提交
18
#include "paddle/fluid/framework/threadpool.h"
G
gongweibao 已提交
19
#include "paddle/fluid/operators/distributed/grpc_client.h"
20
#include "paddle/fluid/operators/distributed/grpc_serde.h"
21
#include "paddle/fluid/operators/distributed/request_handler.h"
P
peizhilin 已提交
22
#include "paddle/fluid/platform/port.h"
X
Xin Pan 已提交
23
#include "paddle/fluid/platform/profiler.h"
24

25 26
DECLARE_bool(rpc_disable_reuse_port);

G
gongweibao 已提交
27 28
namespace paddle {
namespace operators {
29
namespace distributed {
G
gongweibao 已提交
30

31
void GRPCClient::InitImpl() {
W
Wu Yi 已提交
32 33
  // start the client process thread
  // TODO(wuyi): can make this in a threadpool
34 35
  PADDLE_ENFORCE(client_thread_ == nullptr,
                 "please not re init proceed thread");
G
gongweibao 已提交
36
  client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
W
Wu Yi 已提交
37 38
}

Y
Yancey1989 已提交
39
void GRPCClient::SendComplete() {
Y
Yancey1989 已提交
40 41 42
  std::unique_lock<std::mutex> lk(completed_mutex_);
  if (!completed_) {
    for (auto& it : channels_) {
M
minqiyang 已提交
43
      VLOG(3) << "send complete message to " << it.first;
Y
Yancey1989 已提交
44 45 46 47
      this->AsyncSendComplete(it.first);
    }
    PADDLE_ENFORCE(this->Wait(), "internal grpc error");
    completed_ = true;
W
Wu Yi 已提交
48 49 50
  }
}

G
gongweibao 已提交
51
// Tears the client down: flags the worker loop to stop, waits for outstanding
// requests, shuts the completion queue down, drops every cached channel and
// finally joins the worker thread.
GRPCClient::~GRPCClient() {
  // NOTE(review): stopped_ is raised before Wait(); Proceed() re-checks it
  // after every event, so with several requests still in flight the worker may
  // leave its loop before req_count_ reaches zero. Confirm this ordering is
  // intended.
  stopped_ = true;
  Wait();

  // After Shutdown() the queue stops handing out events, which lets the
  // cq_.Next() call in Proceed() return false.
  cq_.Shutdown();

  {
    std::lock_guard<std::mutex> guard(chan_mutex_);
    // Destroying the map entries releases the channel references they hold,
    // so no per-entry reset() is needed before clearing.
    channels_.clear();
  }

  // client_thread_ is only created by InitImpl(); a client that was never
  // initialised must not dereference the null pointer here.
  if (client_thread_ != nullptr && client_thread_->joinable()) {
    client_thread_->join();
  }
}

65 66 67 68 69
// Asynchronously sends variable `var_name` from `scope` to endpoint `ep`.
//
// Serialisation and the RPC start are scheduled through framework::AsyncIO;
// the returned handle can be waited on by the caller. `ctx` and `scope` are
// captured by pointer, so they must remain valid until the scheduled work has
// run.
VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
                                      const platform::DeviceContext& ctx,
                                      const framework::Scope& scope,
                                      const std::string& var_name,
                                      int64_t time_out) {
  const std::string rpc_name = "SendRPC";
  // Own a copy of the name: the lambda below may run after the caller's
  // string has gone away.
  const std::string var_name_copy = var_name;
  const platform::DeviceContext* dev_ctx = &ctx;
  const framework::Scope* local_scope = &scope;

  const auto channel = GetChannel(ep);
  // The processor is handed to the completion queue as the call's tag and is
  // deleted by Proceed() once the matching event has been handled.
  SendProcessor* processor = new SendProcessor(channel);
  VarHandlePtr handle(
      new VarHandle(ep, rpc_name, var_name_copy, dev_ctx, local_scope));
  processor->Prepare(handle, time_out);

  framework::AsyncIO([this, processor, handle, rpc_name, var_name_copy,
                      dev_ctx, local_scope] {
    auto* variable = local_scope->FindVar(var_name_copy);

    ::grpc::ByteBuffer payload;
    SerializeToByteBuffer(var_name_copy, variable, *dev_ctx, &payload, "",
                          trainer_id_);

    VLOG(3) << processor->GetVarHandlePtr()->String() << " begin";

    // A send carries no response payload to post-process.
    processor->response_call_back_ = nullptr;

    platform::RecordRPCEvent record_event(rpc_name, dev_ctx);

    auto rpc_call = processor->stub_g_.PrepareUnaryCall(
        processor->context_.get(), "/sendrecv.SendRecvService/SendVariable",
        payload, &cq_);
    rpc_call->StartCall();
    rpc_call->Finish(&processor->reply_, &processor->status_,
                     static_cast<void*>(processor));

    // When profiling is on, block until the request finishes.
    if (UNLIKELY(platform::IsProfileEnabled())) {
      handle->Wait();
    }
  });
  req_count_++;

  return handle;
}

void ProcGetResponse(const VarHandle& var_h,
108
                     const ::grpc::ByteBuffer& ret_msg) {
109
  VLOG(100) << "ProcGetResponse";
T
typhoonzero 已提交
110
  framework::Variable* outvar = nullptr;
W
Wu Yi 已提交
111 112 113 114
  // get response's trainer_id is not used
  int trainer_id;
  DeserializeFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar,
                            &trainer_id);
115 116 117 118 119
}

template <typename T>
void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
  ::grpc::Slice slice(proto.ByteSizeLong());
Q
qiaolongfei 已提交
120
  proto.SerializeWithCachedSizesToArray(const_cast<uint8_t*>(slice.begin()));
121 122
  ::grpc::ByteBuffer tmp(&slice, 1);
  result->Swap(&tmp);
G
gongweibao 已提交
123 124
}

125 126 127 128 129
// Asynchronously fetches `var_name` from endpoint `ep` through the
// GetVariable RPC. Thin wrapper over _AsyncGetVar.
VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
                                     const platform::DeviceContext& ctx,
                                     const framework::Scope& scope,
                                     const std::string& var_name,
                                     int64_t time_out) {
  const std::string rpc_path = "/sendrecv.SendRecvService/GetVariable";
  return _AsyncGetVar(ep, ctx, scope, var_name, rpc_path, time_out);
}

// Asynchronously fetches `var_name` from endpoint `ep` through the
// GetMonomerVariable RPC. Thin wrapper over _AsyncGetVar.
VarHandlePtr GRPCClient::AsyncGetMonomerVariable(
    const std::string& ep, const platform::DeviceContext& ctx,
    const framework::Scope& scope, const std::string& var_name,
    int64_t time_out) {
  const std::string rpc_path =
      "/sendrecv.SendRecvService/GetMonomerVariable";
  return _AsyncGetVar(ep, ctx, scope, var_name, rpc_path, time_out);
}

// Shared implementation of the get-style RPCs. Builds a VariableMessage that
// names `var_name`, issues it on method path `rpc_path`, and registers
// ProcGetResponse to deserialise the reply. The request itself is scheduled
// through framework::AsyncIO; the returned handle can be waited on.
VarHandlePtr GRPCClient::_AsyncGetVar(const std::string& ep,
                                      const platform::DeviceContext& ctx,
                                      const framework::Scope& scope,
                                      const std::string& var_name,
                                      const std::string& rpc_path,
                                      int64_t time_out) {
  const std::string rpc_name = "GetRPC";
  // Copied so the lambda owns its own string.
  const std::string var_name_copy = var_name;
  const platform::DeviceContext* dev_ctx = &ctx;
  const framework::Scope* local_scope = &scope;

  const auto channel = GetChannel(ep);
  // Handed to the completion queue as the call's tag; Proceed() deletes it.
  GetProcessor* processor = new GetProcessor(channel);
  VarHandlePtr handle(
      new VarHandle(ep, rpc_name, var_name_copy, dev_ctx, local_scope));
  processor->Prepare(handle, time_out);

  framework::AsyncIO([this, processor, handle, rpc_name, rpc_path,
                      var_name_copy, dev_ctx] {
    // Build the request: just the variable name plus this trainer's id.
    sendrecv::VariableMessage request;
    request.set_varname(var_name_copy);
    request.set_trainer_id(trainer_id_);

    ::grpc::ByteBuffer payload;
    RequestToByteBuffer<sendrecv::VariableMessage>(request, &payload);

    VLOG(3) << processor->GetVarHandlePtr()->String() << " begin";

    // The reply carries the variable and must be deserialised.
    processor->response_call_back_ = ProcGetResponse;

    platform::RecordRPCEvent record_event(rpc_name, dev_ctx);

    auto rpc_call = processor->stub_g_.PrepareUnaryCall(
        processor->context_.get(), rpc_path, payload, &cq_);
    rpc_call->StartCall();
    rpc_call->Finish(&processor->reply_, &processor->status_,
                     static_cast<void*>(processor));

    // When profiling is on, block until the request finishes.
    if (UNLIKELY(platform::IsProfileEnabled())) {
      handle->Wait();
    }
  });

  req_count_++;

  return handle;
}

188 189 190 191 192
// Asynchronously issues a PrefetchVariable RPC to endpoint `ep`: the variable
// `in_var_name` found in `scope` is serialised as the request (tagged with
// `out_var_name` and `table_name`), and the reply is deserialised by
// ProcGetResponse. The returned handle is keyed on `out_var_name`.
VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
                                          const platform::DeviceContext& ctx,
                                          const framework::Scope& scope,
                                          const std::string& in_var_name,
                                          const std::string& out_var_name,
                                          const std::string& table_name,
                                          int64_t time_out) {
  const std::string rpc_name = "PrefetchRPC";
  // Owned copies for the lambda, which may outlive the caller's arguments.
  const std::string endpoint = ep;
  const std::string in_name = in_var_name;
  const std::string out_name = out_var_name;
  const std::string table = table_name;
  const platform::DeviceContext* dev_ctx = &ctx;
  const framework::Scope* local_scope = &scope;

  const auto channel = GetChannel(endpoint);
  // Handed to the completion queue as the call's tag; Proceed() deletes it.
  GetProcessor* processor = new GetProcessor(channel);

  VarHandlePtr handle(
      new VarHandle(ep, rpc_name, out_name, dev_ctx, local_scope));
  processor->Prepare(handle, time_out);

  framework::AsyncIO([this, processor, handle, rpc_name, endpoint, in_name,
                      out_name, table, dev_ctx, local_scope] {
    auto* input_var = local_scope->FindVar(in_name);

    ::grpc::ByteBuffer payload;
    SerializeToByteBuffer(in_name, input_var, *dev_ctx, &payload, out_name, 0,
                          table);

    VLOG(3) << processor->GetVarHandlePtr()->String() << " begin";

    // The reply carries the prefetched variable and must be deserialised.
    processor->response_call_back_ = ProcGetResponse;

    platform::RecordRPCEvent record_event(rpc_name, dev_ctx);

    auto rpc_call = processor->stub_g_.PrepareUnaryCall(
        processor->context_.get(),
        "/sendrecv.SendRecvService/PrefetchVariable", payload, &cq_);
    rpc_call->StartCall();
    rpc_call->Finish(&processor->reply_, &processor->status_,
                     static_cast<void*>(processor));

    // When profiling is on, block until the request finishes.
    if (UNLIKELY(platform::IsProfileEnabled())) {
      handle->Wait();
    }
  });

  req_count_++;
  return handle;
}

239 240
// Sends a batch-barrier marker to endpoint `ep`: a VariableMessage whose
// varname is BATCH_BARRIER_MESSAGE, issued through the SendVariable stub.
// Unlike the variable RPCs, the call is started on the calling thread.
VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
                                               int64_t time_out) {
  const std::string rpc_name = "BatchBarrierRPC";
  const auto channel = GetChannel(ep);

  // Handed to the completion queue as the call's tag; Proceed() deletes it.
  BatchBarrierProcessor* processor = new BatchBarrierProcessor(channel);
  VarHandlePtr handle(
      new VarHandle(ep, rpc_name, BATCH_BARRIER_MESSAGE, nullptr, nullptr));
  processor->Prepare(handle, time_out);

  sendrecv::VariableMessage barrier_msg;
  barrier_msg.set_varname(BATCH_BARRIER_MESSAGE);

  platform::RecordRPCEvent record_event(rpc_name, nullptr);

  auto rpc_call = processor->stub_->AsyncSendVariable(
      processor->context_.get(), barrier_msg, &cq_);
  rpc_call->Finish(&processor->reply_, &processor->status_,
                   static_cast<void*>(processor));
  req_count_++;

  // When profiling is on, block until the request finishes.
  if (UNLIKELY(platform::IsProfileEnabled())) {
    handle->Wait();
  }

  return handle;
}
Y
Yancey 已提交
264

265 266
// Sends a fetch-barrier marker to endpoint `ep`: a VariableMessage whose
// varname is FETCH_BARRIER_MESSAGE, issued through the GetVariable stub.
VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
                                               int64_t time_out) {
  const std::string rpc_name = "FetchBarrierRPC";
  const auto channel = GetChannel(ep);

  // Handed to the completion queue as the call's tag; Proceed() deletes it.
  FetchBarrierProcessor* processor = new FetchBarrierProcessor(channel);
  VarHandlePtr handle(
      new VarHandle(ep, rpc_name, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
  processor->Prepare(handle, time_out);

  sendrecv::VariableMessage barrier_msg;
  barrier_msg.set_varname(FETCH_BARRIER_MESSAGE);

  platform::RecordRPCEvent record_event(rpc_name, nullptr);

  auto rpc_call = processor->stub_->AsyncGetVariable(
      processor->context_.get(), barrier_msg, &cq_);
  rpc_call->Finish(&processor->reply_, &processor->status_,
                   static_cast<void*>(processor));
  req_count_++;

  // When profiling is on, block until the request finishes.
  if (UNLIKELY(platform::IsProfileEnabled())) {
    handle->Wait();
  }

  return handle;
}

290 291 292 293 294 295
// Issues the GetMonomerBarrier RPC for `var_name` against endpoint `ep`. The
// request is a VariableMessage carrying only the variable name.
VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
                                                const std::string& var_name,
                                                int64_t time_out) {
  const std::string rpc_name = "SendMonomerFetchBarrierRPC";
  const auto channel = GetChannel(ep);

  // Handed to the completion queue as the call's tag; Proceed() deletes it.
  BatchBarrierProcessor* processor = new BatchBarrierProcessor(channel);
  VarHandlePtr handle(new VarHandle(ep, rpc_name, var_name, nullptr, nullptr));
  processor->Prepare(handle, time_out);

  VLOG(30) << processor->GetVarHandlePtr()->String() << " begin";

  sendrecv::VariableMessage barrier_msg;
  barrier_msg.set_varname(var_name);

  platform::RecordRPCEvent record_event(rpc_name, nullptr);

  auto rpc_call = processor->stub_->AsyncGetMonomerBarrier(
      processor->context_.get(), barrier_msg, &cq_);
  rpc_call->Finish(&processor->reply_, &processor->status_,
                   static_cast<void*>(processor));
  req_count_++;

  // When profiling is on, block until the request finishes.
  if (UNLIKELY(platform::IsProfileEnabled())) {
    handle->Wait();
  }

  return handle;
}

317 318
// Sends a completion marker to endpoint `ep`: a VariableMessage whose varname
// is COMPLETE_MESSAGE, issued through the SendVariable stub.
VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
                                           int64_t time_out) {
  const std::string rpc_name = "SendCompleteRPC";
  const auto channel = GetChannel(ep);

  // Handed to the completion queue as the call's tag; Proceed() deletes it.
  BatchBarrierProcessor* processor = new BatchBarrierProcessor(channel);
  VarHandlePtr handle(
      new VarHandle(ep, rpc_name, COMPLETE_MESSAGE, nullptr, nullptr));
  processor->Prepare(handle, time_out);

  sendrecv::VariableMessage complete_msg;
  complete_msg.set_varname(COMPLETE_MESSAGE);

  platform::RecordRPCEvent record_event(rpc_name, nullptr);

  auto rpc_call = processor->stub_->AsyncSendVariable(
      processor->context_.get(), complete_msg, &cq_);
  rpc_call->Finish(&processor->reply_, &processor->status_,
                   static_cast<void*>(processor));
  req_count_++;

  // When profiling is on, block until the request finishes.
  if (UNLIKELY(platform::IsProfileEnabled())) {
    handle->Wait();
  }

  return handle;
}

342 343 344
// Issues the CheckpointNotify RPC against endpoint `ep`. The request's varname
// is CHECKPOINT_SAVE_MESSAGE and `dir` travels in the out_varname field.
VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
                                               const std::string& dir,
                                               int64_t time_out) {
  const std::string rpc_name = "CheckPointNotifyRPC";
  const auto channel = GetChannel(ep);

  // Handed to the completion queue as the call's tag; Proceed() deletes it.
  CheckpointNotifyProcessor* processor =
      new CheckpointNotifyProcessor(channel);
  VarHandlePtr handle(
      new VarHandle(ep, rpc_name, CHECKPOINT_SAVE_MESSAGE, nullptr, nullptr));
  processor->Prepare(handle, time_out);

  sendrecv::VariableMessage notify_msg;
  notify_msg.set_varname(CHECKPOINT_SAVE_MESSAGE);
  notify_msg.set_out_varname(dir);

  platform::RecordRPCEvent record_event(rpc_name, nullptr);

  auto rpc_call = processor->stub_->AsyncCheckpointNotify(
      processor->context_.get(), notify_msg, &cq_);
  rpc_call->Finish(&processor->reply_, &processor->status_,
                   static_cast<void*>(processor));
  req_count_++;

  // When profiling is on, block until the request finishes.
  if (UNLIKELY(platform::IsProfileEnabled())) {
    handle->Wait();
  }

  return handle;
}

Y
Yancey1989 已提交
372
bool GRPCClient::Wait() {
W
Wu Yi 已提交
373
  std::unique_lock<std::mutex> lk(sync_mutex_);
Y
Yancey1989 已提交
374 375
  sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); });
  return ok_;
G
gongweibao 已提交
376 377
}

G
gongweibao 已提交
378
void GRPCClient::Proceed() {
W
Wu Yi 已提交
379
  void* tag = nullptr;
G
gongweibao 已提交
380 381
  bool ok = false;

M
minqiyang 已提交
382
  VLOG(3) << "GRPCClient Proceed begin";
M
minqiyang 已提交
383
  while (!stopped_ && cq_.Next(&tag, &ok)) {
W
Wu Yi 已提交
384 385 386
    BaseProcessor* c = static_cast<BaseProcessor*>(tag);
    GPR_ASSERT(ok);
    PADDLE_ENFORCE(c);
G
gongweibao 已提交
387

W
Wu Yi 已提交
388
    if (c->status_.ok()) {
M
minqiyang 已提交
389
      VLOG(3) << c->GetVarHandlePtr()->String() << " process";
W
Wu Yi 已提交
390
      c->Process();
Y
Yancey1989 已提交
391
    } else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) {
392
      LOG(FATAL) << c->GetVarHandlePtr()->String()
G
gongweibao 已提交
393 394 395
                 << " meets grpc error, error_code:" << c->status_.error_code()
                 << " error_message:" << c->status_.error_message()
                 << " error_details:" << c->status_.error_details();
Y
Yancey1989 已提交
396 397 398 399
      {
        std::lock_guard<std::mutex> lk(sync_mutex_);
        ok_ = false;
      }
400
      c->Finish(false);
W
Wu Yi 已提交
401
    } else {
402
      LOG(FATAL) << c->GetVarHandlePtr()->String()
G
gongweibao 已提交
403 404 405 406
                 << " meets grpc error, error_code:" << c->status_.error_code()
                 << " error_message:" << c->status_.error_message()
                 << " error_details:" << c->status_.error_details();

407
      c->Finish(false);
W
Wu Yi 已提交
408
    }
409

G
gongweibao 已提交
410
    bool notify = false;
W
Wu Yi 已提交
411 412 413
    {
      std::lock_guard<std::mutex> lk(sync_mutex_);
      req_count_--;
G
gongweibao 已提交
414 415 416 417 418 419 420
      notify = (req_count_ <= 0 || !c->status_.ok());
    }

    delete c;

    if (notify) {
      sync_cond_.notify_all();
W
Wu Yi 已提交
421
    }
G
gongweibao 已提交
422
  }
M
minqiyang 已提交
423
  VLOG(3) << "GRPCClient Proceed end";
G
gongweibao 已提交
424
}
W
Wu Yi 已提交
425

G
gongweibao 已提交
426
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
W
Wu Yi 已提交
427
  std::lock_guard<std::mutex> guard(chan_mutex_);
Y
Yancey1989 已提交
428
  auto it = channels_.find(ep);
G
gongweibao 已提交
429 430 431 432
  if (it != channels_.end()) {
    return it->second;
  }

W
Wu Yi 已提交
433
  // Channel configurations:
G
gongweibao 已提交
434
  grpc::ChannelArguments args;
W
Wu Yi 已提交
435
  args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 2000);
436 437 438
  if (FLAGS_rpc_disable_reuse_port) {
    args.SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0);
  }
439
  args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
G
gongweibao 已提交
440 441 442
  args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());

T
typhoonzero 已提交
443 444
  auto ch =
      grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
Y
Yancey1989 已提交
445
  channels_[ep] = ch;
G
gongweibao 已提交
446 447 448
  return ch;
}

449
}  // namespace distributed
G
gongweibao 已提交
450 451
}  // namespace operators
}  // namespace paddle