grpc_client.cc 15.3 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
G
gongweibao 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include <stdlib.h>
Y
Yi Wang 已提交
16 17
#include <limits>

G
gongweibao 已提交
18
#include "glog/logging.h"  // For VLOG
Y
Yi Wang 已提交
19
#include "paddle/fluid/framework/threadpool.h"
W
Wu Yi 已提交
20 21
#include "paddle/fluid/operators/distributed/grpc/grpc_client.h"
#include "paddle/fluid/operators/distributed/grpc/grpc_serde.h"
22
#include "paddle/fluid/operators/distributed/request_handler.h"
P
peizhilin 已提交
23
#include "paddle/fluid/platform/port.h"
X
Xin Pan 已提交
24
#include "paddle/fluid/platform/profiler.h"
25

26 27
DECLARE_bool(rpc_disable_reuse_port);

G
gongweibao 已提交
28 29
namespace paddle {
namespace operators {
30
namespace distributed {
G
gongweibao 已提交
31

32
void GRPCClient::InitImpl() {
W
Wu Yi 已提交
33 34
  // start the client process thread
  // TODO(wuyi): can make this in a threadpool
35 36
  PADDLE_ENFORCE(client_thread_ == nullptr,
                 "please not re init proceed thread");
G
gongweibao 已提交
37
  client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
W
Wu Yi 已提交
38 39
}

Y
Yancey1989 已提交
40
void GRPCClient::SendComplete() {
Y
Yancey1989 已提交
41 42 43
  std::unique_lock<std::mutex> lk(completed_mutex_);
  if (!completed_) {
    for (auto& it : channels_) {
M
minqiyang 已提交
44
      VLOG(3) << "send complete message to " << it.first;
Y
Yancey1989 已提交
45 46 47 48
      this->AsyncSendComplete(it.first);
    }
    PADDLE_ENFORCE(this->Wait(), "internal grpc error");
    completed_ = true;
W
Wu Yi 已提交
49 50 51
  }
}

G
gongweibao 已提交
52
// Drains outstanding requests, shuts the completion queue down, releases
// all channels and joins the Proceed() thread.
GRPCClient::~GRPCClient() {
  // Drain first: Proceed() only keeps polling while stopped_ is false, so
  // raising the flag before Wait() could make the polling thread exit after
  // the next completion and leave Wait() blocked on the remaining requests.
  Wait();
  stopped_ = true;

  // After Shutdown() the completion queue's Next() eventually returns false,
  // which ends the loop in Proceed().
  cq_.Shutdown();
  {
    std::lock_guard<std::mutex> guard(chan_mutex_);
    // Clearing the container drops every channel reference it holds.
    channels_.clear();
  }

  // InitImpl() may never have been called, in which case there is no thread
  // to join.
  if (client_thread_ != nullptr && client_thread_->joinable()) {
    client_thread_->join();
  }
}

66 67 68 69 70
// Asynchronously sends the variable `var_name` from `scope` to endpoint `ep`
// through the "/sendrecv.SendRecvService/SendVariable" unary call.
//
// ep:       endpoint whose channel is obtained via GetChannel().
// ctx:      device context handed to SerializeToByteBuffer(); only its
//           address is stored, so it must outlive the request.
// scope:    scope in which `var_name` is looked up; only its address is
//           stored, so it must outlive the request.
// var_name: name of the variable to serialize and send.
// time_out: forwarded to the processor's Prepare().
//
// Returns the VarHandlePtr associated with the request. Serialization and
// the start of the call run inside a task submitted via framework::AsyncIO.
// The SendProcessor is released by Proceed() once its completion tag is
// delivered.
VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
                                      const platform::DeviceContext& ctx,
                                      const framework::Scope& scope,
                                      const std::string& var_name,
                                      int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const framework::Scope* p_scope = &scope;
  const std::string method = kSendRPC;

  SendProcessor* s = new SendProcessor(GetChannel(ep));
  VarHandlePtr h(new VarHandle(ep, method, var_name, p_ctx, p_scope));
  s->Prepare(h, time_out);

  // Register the request before it can possibly complete. Proceed()
  // decrements req_count_ on completion; counting afterwards would let the
  // counter go negative and the later increment back to zero would not be
  // signalled to threads blocked in Wait().
  req_count_++;

  // `var_name` is captured by copy so the task does not depend on the
  // caller's string staying alive.
  framework::AsyncIO([var_name, p_scope, p_ctx, s, method, h, this] {
    auto* var = p_scope->FindVar(var_name);

    ::grpc::ByteBuffer req;
    SerializeToByteBuffer(var_name, var, *p_ctx, &req, "", trainer_id_);

    VLOG(3) << s->GetVarHandlePtr()->String() << " begin";

    // stub context: a send has no response payload to post-process.
    s->response_call_back_ = nullptr;

    platform::RecordRPCEvent record_event(method);

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
    call->StartCall();
    // The processor pointer is the completion tag handed back by the queue.
    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
  });

  return h;
}

void ProcGetResponse(const VarHandle& var_h,
109
                     const ::grpc::ByteBuffer& ret_msg) {
110
  VLOG(4) << "ProcGetResponse";
T
typhoonzero 已提交
111
  framework::Variable* outvar = nullptr;
W
Wu Yi 已提交
112 113 114 115
  // get response's trainer_id is not used
  int trainer_id;
  DeserializeFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar,
                            &trainer_id);
116 117 118 119 120
}

template <typename T>
void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
  ::grpc::Slice slice(proto.ByteSizeLong());
Q
qiaolongfei 已提交
121
  proto.SerializeWithCachedSizesToArray(const_cast<uint8_t*>(slice.begin()));
122 123
  ::grpc::ByteBuffer tmp(&slice, 1);
  result->Swap(&tmp);
G
gongweibao 已提交
124 125
}

126 127 128 129
// Requests variable `var_name` from endpoint `ep` via the GetVariable RPC.
// All arguments are forwarded unchanged to _AsyncGetVar().
VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
                                     const platform::DeviceContext& ctx,
                                     const framework::Scope& scope,
                                     const std::string& var_name,
                                     const std::string& out_varname,
                                     int64_t time_out) {
  const std::string rpc_path = "/sendrecv.SendRecvService/GetVariable";
  return _AsyncGetVar(ep, ctx, scope, kGetRPC, var_name, out_varname, rpc_path,
                      time_out);
}

136 137 138 139 140 141 142 143 144 145 146 147
// Requests a variable via the GetVariableNoBarrier RPC. The name sent to the
// server is `var_name` with WITHOUT_BARRIER_MESSAGE appended.
VarHandlePtr GRPCClient::AsyncGetVarNoBarrier(
    const std::string& ep, const platform::DeviceContext& ctx,
    const framework::Scope& scope, const std::string& var_name,
    const std::string& out_varname, int64_t time_out) {
  const std::string rpc_path =
      "/sendrecv.SendRecvService/GetVariableNoBarrier";
  const std::string request_name =
      string::Sprintf("%s%s", var_name, WITHOUT_BARRIER_MESSAGE);

  return _AsyncGetVar(ep, ctx, scope, kGetNoBarrierRPC, request_name,
                      out_varname, rpc_path, time_out);
}

148 149 150 151
// Requests variable `var_name` via the GetMonomerVariable RPC. The same name
// is used for both the requested and the output variable.
VarHandlePtr GRPCClient::AsyncGetMonomerVariable(
    const std::string& ep, const platform::DeviceContext& ctx,
    const framework::Scope& scope, const std::string& var_name,
    int64_t time_out) {
  const std::string rpc_path = "/sendrecv.SendRecvService/GetMonomerVariable";
  return _AsyncGetVar(ep, ctx, scope, kGetMonomerRPC, var_name, var_name,
                      rpc_path, time_out);
}

156 157 158 159 160
// Shared implementation of the "get variable" style requests.
//
// ep:          endpoint whose channel is obtained via GetChannel().
// ctx, scope:  stored by address in the returned handle (and used by
//              ProcGetResponse when the reply arrives), so both must outlive
//              the request.
// method:      method name recorded in the VarHandle and the RPC event.
// var_name:    name placed in the request's `varname` field.
// out_varname: name placed in the request's `out_varname` field and in the
//              VarHandle.
// rpc_path:    full gRPC method path passed to PrepareUnaryCall().
// time_out:    forwarded to the processor's Prepare().
//
// Returns the VarHandlePtr associated with the request. The request is built
// and started inside a task submitted via framework::AsyncIO. The
// GetProcessor is released by Proceed() once its completion tag is delivered.
VarHandlePtr GRPCClient::_AsyncGetVar(
    const std::string& ep, const platform::DeviceContext& ctx,
    const framework::Scope& scope, const std::string& method,
    const std::string& var_name, const std::string& out_varname,
    const std::string& rpc_path, int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const framework::Scope* p_scope = &scope;

  GetProcessor* s = new GetProcessor(GetChannel(ep));
  VarHandlePtr h(new VarHandle(ep, method, out_varname, p_ctx, p_scope));
  s->Prepare(h, time_out);

  // Register the request before it can possibly complete. Proceed()
  // decrements req_count_ on completion; counting afterwards would let the
  // counter go negative and the later increment back to zero would not be
  // signalled to threads blocked in Wait().
  req_count_++;

  // The string parameters are captured by copy so the task does not depend
  // on the caller's strings staying alive.
  framework::AsyncIO([var_name, out_varname, s, method, h, rpc_path, this] {
    // prepare input
    sendrecv::VariableMessage req;
    req.set_varname(var_name);
    req.set_out_varname(out_varname);
    req.set_trainer_id(trainer_id_);
    ::grpc::ByteBuffer buf;
    RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);

    VLOG(3) << s->GetVarHandlePtr()->String() << " begin";

    // stub context: the reply is handled by ProcGetResponse.
    s->response_call_back_ = ProcGetResponse;

    platform::RecordRPCEvent record_event(method);

    auto call =
        s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_);
    call->StartCall();
    // The processor pointer is the completion tag handed back by the queue.
    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
  });

  return h;
}

204 205 206 207 208
// Asynchronously issues the "/sendrecv.SendRecvService/PrefetchVariable"
// unary call on endpoint `ep`.
//
// ep:           endpoint whose channel is obtained via GetChannel().
// ctx, scope:   stored by address (used for serialization and, through the
//               VarHandle, by ProcGetResponse), so both must outlive the
//               request.
// in_var_name:  variable looked up in `scope` and serialized as the request.
// out_var_name: passed to SerializeToByteBuffer() and recorded in the
//               VarHandle.
// table_name:   passed to SerializeToByteBuffer().
// time_out:     forwarded to the processor's Prepare().
//
// Returns the VarHandlePtr associated with the request. The GetProcessor is
// released by Proceed() once its completion tag is delivered.
VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
                                          const platform::DeviceContext& ctx,
                                          const framework::Scope& scope,
                                          const std::string& in_var_name,
                                          const std::string& out_var_name,
                                          const std::string& table_name,
                                          int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const framework::Scope* p_scope = &scope;
  const std::string method = kPrefetchRPC;

  GetProcessor* s = new GetProcessor(GetChannel(ep));
  VarHandlePtr h(new VarHandle(ep, method, out_var_name, p_ctx, p_scope));
  s->Prepare(h, time_out);

  // Register the request before it can possibly complete. Proceed()
  // decrements req_count_ on completion; counting afterwards would let the
  // counter go negative and the later increment back to zero would not be
  // signalled to threads blocked in Wait().
  req_count_++;

  // The string parameters are captured by copy so the task does not depend
  // on the caller's strings staying alive.
  framework::AsyncIO([in_var_name, out_var_name, p_scope, p_ctx, s, method, h,
                      table_name, this] {
    auto* var = p_scope->FindVar(in_var_name);

    ::grpc::ByteBuffer req;
    // NOTE(review): the trainer id argument is the literal 0 here rather
    // than trainer_id_; kept as in the original - confirm it is intended.
    SerializeToByteBuffer(in_var_name, var, *p_ctx, &req, out_var_name, 0,
                          table_name);

    VLOG(3) << s->GetVarHandlePtr()->String() << " begin";

    // stub context: the reply is handled by ProcGetResponse.
    s->response_call_back_ = ProcGetResponse;

    platform::RecordRPCEvent record_event(method);

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
        &cq_);
    call->StartCall();
    // The processor pointer is the completion tag handed back by the queue.
    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
  });

  return h;
}

255 256
// Sends BATCH_BARRIER_MESSAGE to endpoint `ep` through the stub's
// AsyncSendVariable call.
//
// ep:       endpoint whose channel is obtained via GetChannel().
// time_out: forwarded to the processor's Prepare().
//
// Returns the VarHandlePtr associated with the request. The processor is
// released by Proceed() once its completion tag is delivered.
VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
                                               int64_t time_out) {
  const std::string method = kBatchBarrierRPC;

  BatchBarrierProcessor* s = new BatchBarrierProcessor(GetChannel(ep));
  VarHandlePtr h(
      new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(BATCH_BARRIER_MESSAGE);

  platform::RecordRPCEvent record_event(method);

  // Register the request before it is started. Proceed() decrements
  // req_count_ on completion; counting after Finish() would let the counter
  // go negative and the later increment back to zero would not be signalled
  // to threads blocked in Wait().
  req_count_++;

  auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
  // The processor pointer is the completion tag handed back by the queue.
  rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}
Y
Yancey 已提交
280

281 282
// Sends FETCH_BARRIER_MESSAGE to endpoint `ep` through the stub's
// AsyncGetVariable call.
//
// ep:       endpoint whose channel is obtained via GetChannel().
// time_out: forwarded to the processor's Prepare().
//
// Returns the VarHandlePtr associated with the request. The processor is
// released by Proceed() once its completion tag is delivered.
VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
                                               int64_t time_out) {
  const std::string method = kFetchBarrierRPC;

  FetchBarrierProcessor* s = new FetchBarrierProcessor(GetChannel(ep));
  VarHandlePtr h(
      new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(FETCH_BARRIER_MESSAGE);

  platform::RecordRPCEvent record_event(method);

  // Register the request before it is started. Proceed() decrements
  // req_count_ on completion; counting after Finish() would let the counter
  // go negative and the later increment back to zero would not be signalled
  // to threads blocked in Wait().
  req_count_++;

  auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
  // The processor pointer is the completion tag handed back by the queue.
  rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

306 307 308 309 310
// Sends a monomer barrier request carrying `var_name` to endpoint `ep`
// through the stub's AsyncGetMonomerBarrier call.
//
// ep:       endpoint whose channel is obtained via GetChannel().
// var_name: value placed in the request's `varname` field and recorded in
//           the VarHandle.
// time_out: forwarded to the processor's Prepare().
//
// Returns the VarHandlePtr associated with the request. The processor is
// released by Proceed() once its completion tag is delivered.
VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
                                                const std::string& var_name,
                                                int64_t time_out) {
  const std::string method = kSendMonomerFetchBarrierRPC;

  BatchBarrierProcessor* s = new BatchBarrierProcessor(GetChannel(ep));
  VarHandlePtr h(new VarHandle(ep, method, var_name, nullptr, nullptr));
  s->Prepare(h, time_out);

  // NOTE(review): verbosity level 30 differs from the level 3 used by the
  // other "begin" logs in this file; kept as in the original - confirm.
  VLOG(30) << s->GetVarHandlePtr()->String() << " begin";

  sendrecv::VariableMessage req;
  req.set_varname(var_name);

  platform::RecordRPCEvent record_event(method);

  // Register the request before it is started. Proceed() decrements
  // req_count_ on completion; counting after Finish() would let the counter
  // go negative and the later increment back to zero would not be signalled
  // to threads blocked in Wait().
  req_count_++;

  auto rpc = s->stub_->AsyncGetMonomerBarrier(s->context_.get(), req, &cq_);
  // The processor pointer is the completion tag handed back by the queue.
  rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

333 334
// Sends COMPLETE_MESSAGE to endpoint `ep` through the stub's
// AsyncSendVariable call.
//
// ep:       endpoint whose channel is obtained via GetChannel().
// time_out: forwarded to the processor's Prepare().
//
// Returns the VarHandlePtr associated with the request. The processor is
// released by Proceed() once its completion tag is delivered.
VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
                                           int64_t time_out) {
  const std::string method = kSendCompleteRPC;

  BatchBarrierProcessor* s = new BatchBarrierProcessor(GetChannel(ep));
  VarHandlePtr h(new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(COMPLETE_MESSAGE);

  platform::RecordRPCEvent record_event(method);

  // Register the request before it is started. Proceed() decrements
  // req_count_ on completion; counting after Finish() would let the counter
  // go negative and the later increment back to zero would not be signalled
  // to threads blocked in Wait().
  req_count_++;

  auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
  // The processor pointer is the completion tag handed back by the queue.
  rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

358 359 360
// Sends a checkpoint notification to endpoint `ep` through the stub's
// AsyncCheckpointNotify call. The request carries CHECKPOINT_SAVE_MESSAGE as
// `varname` and `dir` as `out_varname`.
//
// ep:       endpoint whose channel is obtained via GetChannel().
// dir:      value placed in the request's `out_varname` field.
// time_out: forwarded to the processor's Prepare().
//
// Returns the VarHandlePtr associated with the request. The processor is
// released by Proceed() once its completion tag is delivered.
VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
                                               const std::string& dir,
                                               int64_t time_out) {
  const std::string method = kCheckPointNotifyRPC;

  CheckpointNotifyProcessor* s =
      new CheckpointNotifyProcessor(GetChannel(ep));
  VarHandlePtr h(
      new VarHandle(ep, method, CHECKPOINT_SAVE_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(CHECKPOINT_SAVE_MESSAGE);
  req.set_out_varname(dir);

  platform::RecordRPCEvent record_event(method);

  // Register the request before it is started. Proceed() decrements
  // req_count_ on completion; counting after Finish() would let the counter
  // go negative and the later increment back to zero would not be signalled
  // to threads blocked in Wait().
  req_count_++;

  auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_);
  // The processor pointer is the completion tag handed back by the queue.
  rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

Y
Yancey1989 已提交
388
bool GRPCClient::Wait() {
W
Wu Yi 已提交
389
  std::unique_lock<std::mutex> lk(sync_mutex_);
Y
Yancey1989 已提交
390 391
  sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); });
  return ok_;
G
gongweibao 已提交
392 393
}

G
gongweibao 已提交
394
void GRPCClient::Proceed() {
W
Wu Yi 已提交
395
  void* tag = nullptr;
G
gongweibao 已提交
396 397
  bool ok = false;

M
minqiyang 已提交
398
  VLOG(3) << "GRPCClient Proceed begin";
M
minqiyang 已提交
399
  while (!stopped_ && cq_.Next(&tag, &ok)) {
W
Wu Yi 已提交
400 401 402
    BaseProcessor* c = static_cast<BaseProcessor*>(tag);
    GPR_ASSERT(ok);
    PADDLE_ENFORCE(c);
G
gongweibao 已提交
403

W
Wu Yi 已提交
404
    if (c->status_.ok()) {
M
minqiyang 已提交
405
      VLOG(3) << c->GetVarHandlePtr()->String() << " process";
W
Wu Yi 已提交
406
      c->Process();
Y
Yancey1989 已提交
407
    } else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) {
408
      LOG(FATAL) << c->GetVarHandlePtr()->String()
G
gongweibao 已提交
409 410 411
                 << " meets grpc error, error_code:" << c->status_.error_code()
                 << " error_message:" << c->status_.error_message()
                 << " error_details:" << c->status_.error_details();
Y
Yancey1989 已提交
412 413 414 415
      {
        std::lock_guard<std::mutex> lk(sync_mutex_);
        ok_ = false;
      }
416
      c->Finish(false);
W
Wu Yi 已提交
417
    } else {
418
      LOG(FATAL) << c->GetVarHandlePtr()->String()
G
gongweibao 已提交
419 420 421 422
                 << " meets grpc error, error_code:" << c->status_.error_code()
                 << " error_message:" << c->status_.error_message()
                 << " error_details:" << c->status_.error_details();

423
      c->Finish(false);
W
Wu Yi 已提交
424
    }
425

G
gongweibao 已提交
426
    bool notify = false;
W
Wu Yi 已提交
427 428 429
    {
      std::lock_guard<std::mutex> lk(sync_mutex_);
      req_count_--;
G
gongweibao 已提交
430 431 432 433 434 435 436
      notify = (req_count_ <= 0 || !c->status_.ok());
    }

    delete c;

    if (notify) {
      sync_cond_.notify_all();
W
Wu Yi 已提交
437
    }
G
gongweibao 已提交
438
  }
439 440 441 442 443 444 445 446 447

  // Last log message
  // Avoid using VLOG() and LOG(): in the destructor of google::LogMessage() a
  // static Mutex log_mutex is used for synchronization, which might have been
  // destructed at this moment.
  if (FLAGS_v >= 3) {
    std::string msg("GRPCClient Proceed end");
    fwrite(msg.c_str(), msg.length(), 1, stdout);
  }
G
gongweibao 已提交
448
}
W
Wu Yi 已提交
449

G
gongweibao 已提交
450
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
W
Wu Yi 已提交
451
  std::lock_guard<std::mutex> guard(chan_mutex_);
Y
Yancey1989 已提交
452
  auto it = channels_.find(ep);
G
gongweibao 已提交
453 454 455 456
  if (it != channels_.end()) {
    return it->second;
  }

W
Wu Yi 已提交
457
  // Channel configurations:
G
gongweibao 已提交
458
  grpc::ChannelArguments args;
W
Wu Yi 已提交
459
  args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 2000);
460 461 462
  if (FLAGS_rpc_disable_reuse_port) {
    args.SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0);
  }
463
  args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
G
gongweibao 已提交
464 465 466
  args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());

T
typhoonzero 已提交
467 468
  auto ch =
      grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
Y
Yancey1989 已提交
469
  channels_[ep] = ch;
G
gongweibao 已提交
470 471 472
  return ch;
}

473
}  // namespace distributed
G
gongweibao 已提交
474 475
}  // namespace operators
}  // namespace paddle