/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include <stdlib.h>

#include <limits>
#include <string>
#include <vector>

#include "glog/logging.h"  // For VLOG
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/grpc_client.h"
#include "paddle/fluid/operators/distributed/grpc_serde.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/platform/profiler.h"
25

26 27
DECLARE_bool(rpc_disable_reuse_port);

G
gongweibao 已提交
28 29
namespace paddle {
namespace operators {
30
namespace distributed {
G
gongweibao 已提交
31

32
void GRPCClient::InitImpl() {
W
Wu Yi 已提交
33 34
  // start the client process thread
  // TODO(wuyi): can make this in a threadpool
35 36
  PADDLE_ENFORCE(client_thread_ == nullptr,
                 "please not re init proceed thread");
G
gongweibao 已提交
37
  client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
W
Wu Yi 已提交
38 39
}

Y
Yancey1989 已提交
40
void GRPCClient::SendComplete() {
Y
Yancey1989 已提交
41 42 43
  std::unique_lock<std::mutex> lk(completed_mutex_);
  if (!completed_) {
    for (auto& it : channels_) {
M
minqiyang 已提交
44
      VLOG(3) << "send complete message to " << it.first;
Y
Yancey1989 已提交
45 46 47 48
      this->AsyncSendComplete(it.first);
    }
    PADDLE_ENFORCE(this->Wait(), "internal grpc error");
    completed_ = true;
W
Wu Yi 已提交
49 50 51
  }
}

G
gongweibao 已提交
52
// Waits for in-flight requests and shuts down the completion queue. Then
// releases all cached channels and joins the background Proceed() thread.
GRPCClient::~GRPCClient() {
  // Wait *before* raising stopped_. If stopped_ is set first, Proceed() can
  // leave its loop with completions still queued. req_count_ then never
  // reaches zero and Wait() blocks forever.
  Wait();
  stopped_ = true;
  // Shutdown() makes cq_.Next() return false once the queue is drained,
  // which lets the Proceed() thread exit.
  cq_.Shutdown();
  {
    std::lock_guard<std::mutex> guard(chan_mutex_);
    for (auto& it : channels_) {
      it.second.reset();
    }
    channels_.clear();
  }
  // InitImpl() may never have been called, in which case there is no thread.
  if (client_thread_ != nullptr) {
    client_thread_->join();
  }
}

66 67 68 69 70
// Asynchronously sends variable `var_name` (looked up in `scope`) to endpoint
// `ep` over the SendVariable RPC. Serialization and the call are issued on
// the framework::AsyncIO pool. The returned handle tracks the request.
// `time_out` is forwarded to SendProcessor::Prepare.
VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
                                      const platform::DeviceContext& ctx,
                                      const framework::Scope& scope,
                                      const std::string& var_name,
                                      int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string var_name_val = var_name;
  const framework::Scope* p_scope = &scope;
  const auto ch = GetChannel(ep_val);
  // `s` is passed to gRPC as the completion tag; Proceed() deletes it after
  // the completion is handled.
  SendProcessor* s = new SendProcessor(ch);
  const std::string method = "SendRPC";
  VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
  s->Prepare(h, time_out);

  // Count the request before dispatching it. The AsyncIO task can issue the
  // RPC, and Proceed() can decrement req_count_ for its completion, before
  // an increment placed after AsyncIO() would run.
  req_count_++;

  framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
    auto* var = p_scope->FindVar(var_name_val);

    ::grpc::ByteBuffer req;
    SerializeToByteBuffer(var_name_val, var, *p_ctx, &req, "", trainer_id_);

    VLOG(3) << s->GetVarHandlePtr()->String() << " begin";

    // stub context
    s->response_call_back_ = nullptr;

    platform::RecordRPCEvent record_event(method, p_ctx);

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
    call->StartCall();
    call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));

    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
  });

  return h;
}

void ProcGetResponse(const VarHandle& var_h,
109
                     const ::grpc::ByteBuffer& ret_msg) {
110
  VLOG(100) << "ProcGetResponse";
T
typhoonzero 已提交
111
  framework::Variable* outvar = nullptr;
W
Wu Yi 已提交
112 113 114 115
  // get response's trainer_id is not used
  int trainer_id;
  DeserializeFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar,
                            &trainer_id);
116 117 118 119 120
}

template <typename T>
void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
  ::grpc::Slice slice(proto.ByteSizeLong());
Q
qiaolongfei 已提交
121
  proto.SerializeWithCachedSizesToArray(const_cast<uint8_t*>(slice.begin()));
122 123
  ::grpc::ByteBuffer tmp(&slice, 1);
  result->Swap(&tmp);
G
gongweibao 已提交
124 125
}

126 127 128 129 130
// Requests variable `var_name` from endpoint `ep` through the GetVariable
// RPC. Thin wrapper over _AsyncGetVar.
VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
                                     const platform::DeviceContext& ctx,
                                     const framework::Scope& scope,
                                     const std::string& var_name,
                                     int64_t time_out) {
  const std::string rpc_path = "/sendrecv.SendRecvService/GetVariable";
  return _AsyncGetVar(ep, ctx, scope, var_name, rpc_path, time_out);
}

// Requests variable `var_name` from endpoint `ep` through the
// GetMonomerVariable RPC. Thin wrapper over _AsyncGetVar.
VarHandlePtr GRPCClient::AsyncGetMonomerVariable(
    const std::string& ep, const platform::DeviceContext& ctx,
    const framework::Scope& scope, const std::string& var_name,
    int64_t time_out) {
  const std::string rpc_path = "/sendrecv.SendRecvService/GetMonomerVariable";
  return _AsyncGetVar(ep, ctx, scope, var_name, rpc_path, time_out);
}

// Shared implementation of AsyncGetVar / AsyncGetMonomerVariable. It asks
// `ep` for variable `var_name` over the unary RPC at `rpc_path`. The reply
// is handled by ProcGetResponse using `ctx` and `scope`. The request is
// built and issued on the framework::AsyncIO pool.
VarHandlePtr GRPCClient::_AsyncGetVar(const std::string& ep,
                                      const platform::DeviceContext& ctx,
                                      const framework::Scope& scope,
                                      const std::string& var_name,
                                      const std::string& rpc_path,
                                      int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string var_name_val = var_name;
  const framework::Scope* p_scope = &scope;
  const auto ch = GetChannel(ep_val);
  // `s` is passed to gRPC as the completion tag; Proceed() deletes it after
  // the completion is handled.
  GetProcessor* s = new GetProcessor(ch);
  const std::string method = "GetRPC";
  VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
  s->Prepare(h, time_out);

  // Count the request before dispatching it, so Proceed() can never
  // decrement req_count_ for this call before it was incremented.
  req_count_++;

  framework::AsyncIO([var_name_val, s, method, p_ctx, h, rpc_path, this] {
    // prepare input
    sendrecv::VariableMessage req;
    req.set_varname(var_name_val);
    req.set_trainer_id(trainer_id_);
    ::grpc::ByteBuffer buf;
    RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);

    VLOG(3) << s->GetVarHandlePtr()->String() << " begin";

    // stub context
    s->response_call_back_ = ProcGetResponse;

    platform::RecordRPCEvent record_event(method, p_ctx);

    auto call =
        s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_);
    call->StartCall();
    call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));

    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
  });

  return h;
}

189 190 191 192 193
// Asynchronously sends `in_var_name` (looked up in `scope`) to `ep` over the
// PrefetchVariable RPC, tagged with `out_var_name` and `table_name`. The reply
// is handled by ProcGetResponse. Work is issued on the framework::AsyncIO pool.
VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
                                          const platform::DeviceContext& ctx,
                                          const framework::Scope& scope,
                                          const std::string& in_var_name,
                                          const std::string& out_var_name,
                                          const std::string& table_name,
                                          int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string in_var_name_val = in_var_name;
  const std::string out_var_name_val = out_var_name;
  const std::string table_name_val = table_name;
  const framework::Scope* p_scope = &scope;
  const auto ch = GetChannel(ep_val);
  // `s` is passed to gRPC as the completion tag; Proceed() deletes it after
  // the completion is handled.
  GetProcessor* s = new GetProcessor(ch);

  const std::string method = "PrefetchRPC";

  VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
  s->Prepare(h, time_out);

  // Count the request before dispatching it, so Proceed() can never
  // decrement req_count_ for this call before it was incremented.
  req_count_++;

  framework::AsyncIO([in_var_name_val, out_var_name_val, p_scope, p_ctx, s,
                      method, h, table_name_val, this] {
    auto* var = p_scope->FindVar(in_var_name_val);

    ::grpc::ByteBuffer req;
    // NOTE(review): the trainer id is passed as 0 here, whereas AsyncSendVar
    // passes trainer_id_. Confirm this is intended.
    SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val,
                          0, table_name_val);

    VLOG(3) << s->GetVarHandlePtr()->String() << " begin";

    // stub context
    s->response_call_back_ = ProcGetResponse;

    platform::RecordRPCEvent record_event(method, p_ctx);

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
        &cq_);
    call->StartCall();
    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
  });

  return h;
}

240 241
// Sends BATCH_BARRIER_MESSAGE to `ep` via the SendVariable RPC and returns a
// handle tracking the request.
VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
                                               int64_t time_out) {
  const auto ch = GetChannel(ep);

  // `s` is passed to gRPC as the completion tag; Proceed() deletes it after
  // the completion is handled.
  BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
  const std::string method = "BatchBarrierRPC";
  VarHandlePtr h(
      new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(BATCH_BARRIER_MESSAGE);

  platform::RecordRPCEvent record_event(method, nullptr);

  // Count the request before starting the RPC. Once Finish() registers the
  // tag, Proceed() may dequeue the completion and decrement req_count_.
  req_count_++;
  auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}
Y
Yancey 已提交
265

266 267
// Sends FETCH_BARRIER_MESSAGE to `ep` via the GetVariable RPC and returns a
// handle tracking the request.
VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
                                               int64_t time_out) {
  const auto ch = GetChannel(ep);
  // `s` is passed to gRPC as the completion tag; Proceed() deletes it after
  // the completion is handled.
  FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
  const std::string method = "FetchBarrierRPC";
  VarHandlePtr h(
      new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(FETCH_BARRIER_MESSAGE);

  platform::RecordRPCEvent record_event(method, nullptr);

  // Count the request before starting the RPC. Once Finish() registers the
  // tag, Proceed() may dequeue the completion and decrement req_count_.
  req_count_++;
  auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

291 292 293 294 295 296
// Issues the GetMonomerBarrier RPC for `var_name` to `ep` and returns a
// handle tracking the request.
VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
                                                const std::string& var_name,
                                                int64_t time_out) {
  const auto ch = GetChannel(ep);
  // `s` is passed to gRPC as the completion tag; Proceed() deletes it after
  // the completion is handled.
  BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
  const std::string method = "SendMonomerFetchBarrierRPC";
  VarHandlePtr h(new VarHandle(ep, method, var_name, nullptr, nullptr));
  s->Prepare(h, time_out);

  VLOG(30) << s->GetVarHandlePtr()->String() << " begin";

  sendrecv::VariableMessage req;
  req.set_varname(var_name);

  platform::RecordRPCEvent record_event(method, nullptr);

  // Count the request before starting the RPC. Once Finish() registers the
  // tag, Proceed() may dequeue the completion and decrement req_count_.
  req_count_++;
  auto rpc = s->stub_->AsyncGetMonomerBarrier(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

318 319
// Sends COMPLETE_MESSAGE to `ep` via the SendVariable RPC and returns a
// handle tracking the request.
VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
                                           int64_t time_out) {
  const auto ch = GetChannel(ep);

  // `s` is passed to gRPC as the completion tag; Proceed() deletes it after
  // the completion is handled.
  BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
  const std::string method = "SendCompleteRPC";
  VarHandlePtr h(new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(COMPLETE_MESSAGE);

  platform::RecordRPCEvent record_event(method, nullptr);

  // Count the request before starting the RPC. Once Finish() registers the
  // tag, Proceed() may dequeue the completion and decrement req_count_.
  req_count_++;
  auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

343 344 345
// Sends CHECKPOINT_SAVE_MESSAGE to `ep` via the CheckpointNotify RPC, with
// `dir` as the out_varname field, and returns a handle tracking the request.
VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
                                               const std::string& dir,
                                               int64_t time_out) {
  const auto ch = GetChannel(ep);

  // `s` is passed to gRPC as the completion tag; Proceed() deletes it after
  // the completion is handled.
  CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);

  const std::string method = "CheckPointNotifyRPC";

  VarHandlePtr h(
      new VarHandle(ep, method, CHECKPOINT_SAVE_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(CHECKPOINT_SAVE_MESSAGE);
  req.set_out_varname(dir);

  platform::RecordRPCEvent record_event(method, nullptr);

  // Count the request before starting the RPC. Once Finish() registers the
  // tag, Proceed() may dequeue the completion and decrement req_count_.
  req_count_++;
  auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

Y
Yancey1989 已提交
373
bool GRPCClient::Wait() {
W
Wu Yi 已提交
374
  std::unique_lock<std::mutex> lk(sync_mutex_);
Y
Yancey1989 已提交
375 376
  sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); });
  return ok_;
G
gongweibao 已提交
377 378
}

G
gongweibao 已提交
379
void GRPCClient::Proceed() {
W
Wu Yi 已提交
380
  void* tag = nullptr;
G
gongweibao 已提交
381 382
  bool ok = false;

M
minqiyang 已提交
383
  VLOG(3) << "GRPCClient Proceed begin";
M
minqiyang 已提交
384
  while (!stopped_ && cq_.Next(&tag, &ok)) {
W
Wu Yi 已提交
385 386 387
    BaseProcessor* c = static_cast<BaseProcessor*>(tag);
    GPR_ASSERT(ok);
    PADDLE_ENFORCE(c);
G
gongweibao 已提交
388

W
Wu Yi 已提交
389
    if (c->status_.ok()) {
M
minqiyang 已提交
390
      VLOG(3) << c->GetVarHandlePtr()->String() << " process";
W
Wu Yi 已提交
391
      c->Process();
Y
Yancey1989 已提交
392
    } else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) {
393
      LOG(FATAL) << c->GetVarHandlePtr()->String()
G
gongweibao 已提交
394 395 396
                 << " meets grpc error, error_code:" << c->status_.error_code()
                 << " error_message:" << c->status_.error_message()
                 << " error_details:" << c->status_.error_details();
Y
Yancey1989 已提交
397 398 399 400
      {
        std::lock_guard<std::mutex> lk(sync_mutex_);
        ok_ = false;
      }
401
      c->Finish(false);
W
Wu Yi 已提交
402
    } else {
403
      LOG(FATAL) << c->GetVarHandlePtr()->String()
G
gongweibao 已提交
404 405 406 407
                 << " meets grpc error, error_code:" << c->status_.error_code()
                 << " error_message:" << c->status_.error_message()
                 << " error_details:" << c->status_.error_details();

408
      c->Finish(false);
W
Wu Yi 已提交
409
    }
410

G
gongweibao 已提交
411
    bool notify = false;
W
Wu Yi 已提交
412 413 414
    {
      std::lock_guard<std::mutex> lk(sync_mutex_);
      req_count_--;
G
gongweibao 已提交
415 416 417 418 419 420 421
      notify = (req_count_ <= 0 || !c->status_.ok());
    }

    delete c;

    if (notify) {
      sync_cond_.notify_all();
W
Wu Yi 已提交
422
    }
G
gongweibao 已提交
423
  }
424 425 426 427 428 429 430 431 432

  // Last log message
  // Avoid using VLOG() and LOG(): in the destructor of google::LogMessage() a
  // static Mutex log_mutex is used for synchronization, which might have been
  // destructed at this moment.
  if (FLAGS_v >= 3) {
    std::string msg("GRPCClient Proceed end");
    fwrite(msg.c_str(), msg.length(), 1, stdout);
  }
G
gongweibao 已提交
433
}
W
Wu Yi 已提交
434

G
gongweibao 已提交
435
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
W
Wu Yi 已提交
436
  std::lock_guard<std::mutex> guard(chan_mutex_);
Y
Yancey1989 已提交
437
  auto it = channels_.find(ep);
G
gongweibao 已提交
438 439 440 441
  if (it != channels_.end()) {
    return it->second;
  }

W
Wu Yi 已提交
442
  // Channel configurations:
G
gongweibao 已提交
443
  grpc::ChannelArguments args;
W
Wu Yi 已提交
444
  args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 2000);
445 446 447
  if (FLAGS_rpc_disable_reuse_port) {
    args.SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0);
  }
448
  args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
G
gongweibao 已提交
449 450 451
  args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());

T
typhoonzero 已提交
452 453
  auto ch =
      grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
Y
Yancey1989 已提交
454
  channels_[ep] = ch;
G
gongweibao 已提交
455 456 457
  return ch;
}

458
}  // namespace distributed
G
gongweibao 已提交
459 460
}  // namespace operators
}  // namespace paddle