grpc_client.cc 12.5 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include <sys/time.h>
Y
Yi Wang 已提交
16 17
#include <limits>

G
gongweibao 已提交
18
#include "glog/logging.h"  // For VLOG
Y
Yi Wang 已提交
19
#include "paddle/fluid/framework/threadpool.h"
G
gongweibao 已提交
20
#include "paddle/fluid/operators/distributed/grpc_client.h"
21
#include "paddle/fluid/operators/distributed/grpc_serde.h"
22
#include "paddle/fluid/operators/distributed/request_handler.h"
X
Xin Pan 已提交
23
#include "paddle/fluid/platform/profiler.h"
24

25 26
DECLARE_bool(rpc_disable_reuse_port);

G
gongweibao 已提交
27 28
namespace paddle {
namespace operators {
29
namespace distributed {
G
gongweibao 已提交
30

G
gongweibao 已提交
31
// Initialization hook of the client: brings up the thread that polls the
// completion queue (see InitEventLoop()).
void GRPCClient::InitImpl() {
  InitEventLoop();
}
Y
Yancey1989 已提交
32

G
gongweibao 已提交
33
void GRPCClient::InitEventLoop() {
W
Wu Yi 已提交
34 35
  // start the client process thread
  // TODO(wuyi): can make this in a threadpool
G
gongweibao 已提交
36
  client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
W
Wu Yi 已提交
37 38
}

Y
Yancey1989 已提交
39
void GRPCClient::SendComplete() {
Y
Yancey1989 已提交
40 41 42
  std::unique_lock<std::mutex> lk(completed_mutex_);
  if (!completed_) {
    for (auto& it : channels_) {
43
      VLOG(30) << "send complete message to " << it.first;
Y
Yancey1989 已提交
44 45 46 47
      this->AsyncSendComplete(it.first);
    }
    PADDLE_ENFORCE(this->Wait(), "internal grpc error");
    completed_ = true;
W
Wu Yi 已提交
48 49 50
  }
}

G
gongweibao 已提交
51
// Drains outstanding requests, shuts down the completion queue, releases all
// cached channels and joins the polling thread.
GRPCClient::~GRPCClient() {
  // Let in-flight requests finish while Proceed() is still polling.
  // Proceed() leaves its loop as soon as it observes stopped_, so raising the
  // flag before this Wait() could strand pending completions and leave
  // req_count_ above zero forever.
  Wait();
  stopped_ = true;
  cq_.Shutdown();
  {
    std::lock_guard<std::mutex> guard(chan_mutex_);
    for (auto& it : channels_) {
      it.second.reset();
    }
    channels_.clear();
  }
  // client_thread_ is only created by InitEventLoop(); skip the join when the
  // event loop was never started.
  if (client_thread_ && client_thread_->joinable()) {
    client_thread_->join();
  }
}

65 66 67 68 69
// Asynchronously sends variable `var_name` of `scope` to endpoint `ep`.
//
// The variable is looked up and serialized on a framework::AsyncIO task and
// issued as a unary call to /sendrecv.SendRecvService/SendVariable on cq_.
// `ctx` and `scope` are captured by pointer, so they must outlive that task.
//
// Returns the VarHandle tracking this request.
VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
                                      const platform::DeviceContext& ctx,
                                      const framework::Scope& scope,
                                      const std::string& var_name,
                                      int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const framework::Scope* p_scope = &scope;
  const std::string var_name_val = var_name;
  const std::string method = "SendRPC";

  const auto ch = GetChannel(ep);
  // The processor travels through the completion queue as the call tag and is
  // deleted by Proceed() once the call has completed.
  SendProcessor* s = new SendProcessor(ch);
  VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
  s->Prepare(h, time_out);

  // Count the request before it is dispatched: Proceed() decrements
  // req_count_ when the completion arrives, so incrementing afterwards would
  // let the decrement win the race and make Wait() see a wrong count.
  req_count_++;

  framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
    auto* var = p_scope->FindVar(var_name_val);

    ::grpc::ByteBuffer req;
    SerializeToByteBuffer(var_name_val, var, *p_ctx, &req, "", trainer_id_);

    VLOG(30) << s->GetVarHandlePtr()->String() << " begin";

    // stub context: a send has no response payload to post-process
    s->response_call_back_ = nullptr;

    platform::RecordRPCEvent record_event(method, p_ctx);

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
    call->StartCall();
    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

    // When profiling, wait on the handle inside the task so record_event
    // stays alive until the request has finished.
    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
  });

  return h;
}

void ProcGetResponse(const VarHandle& var_h,
108
                     const ::grpc::ByteBuffer& ret_msg) {
T
typhoonzero 已提交
109
  framework::Variable* outvar = nullptr;
W
Wu Yi 已提交
110 111 112 113
  // get response's trainer_id is not used
  int trainer_id;
  DeserializeFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar,
                            &trainer_id);
114 115 116 117 118
}

template <typename T>
void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
  ::grpc::Slice slice(proto.ByteSizeLong());
Q
qiaolongfei 已提交
119
  proto.SerializeWithCachedSizesToArray(const_cast<uint8_t*>(slice.begin()));
120 121
  ::grpc::ByteBuffer tmp(&slice, 1);
  result->Swap(&tmp);
G
gongweibao 已提交
122 123
}

124 125 126 127 128
// Asynchronously requests variable `var_name` from endpoint `ep`.
//
// The request message is built on a framework::AsyncIO task and issued as a
// unary call to /sendrecv.SendRecvService/GetVariable on cq_; the response is
// handled by ProcGetResponse. `ctx` and `scope` are stored by pointer in the
// returned handle, so they must outlive the request.
//
// Returns the VarHandle tracking this request.
VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
                                     const platform::DeviceContext& ctx,
                                     const framework::Scope& scope,
                                     const std::string& var_name,
                                     int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const framework::Scope* p_scope = &scope;
  const std::string var_name_val = var_name;
  const std::string method = "GetRPC";

  const auto ch = GetChannel(ep);
  // The processor travels through the completion queue as the call tag and is
  // deleted by Proceed() once the call has completed.
  GetProcessor* s = new GetProcessor(ch);
  VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
  s->Prepare(h, time_out);

  // Count the request before it is dispatched: Proceed() decrements
  // req_count_ when the completion arrives, so incrementing afterwards would
  // let the decrement win the race and make Wait() see a wrong count.
  req_count_++;

  framework::AsyncIO([var_name_val, s, method, p_ctx, h, this] {
    // prepare input
    sendrecv::VariableMessage req;
    req.set_varname(var_name_val);
    req.set_trainer_id(trainer_id_);
    ::grpc::ByteBuffer buf;
    RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);

    VLOG(30) << s->GetVarHandlePtr()->String() << " begin";

    // stub context
    s->response_call_back_ = ProcGetResponse;

    platform::RecordRPCEvent record_event(method, p_ctx);

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_);
    call->StartCall();
    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

    // When profiling, wait on the handle inside the task so record_event
    // stays alive until the request has finished.
    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
  });

  return h;
}

169 170 171 172 173 174
// Asynchronously sends variable `in_var_name` of `scope` to endpoint `ep` as a
// prefetch request whose result is named `out_var_name`.
//
// The input is looked up and serialized on a framework::AsyncIO task and
// issued as a unary call to /sendrecv.SendRecvService/PrefetchVariable on
// cq_; the response is handled by ProcGetResponse. `ctx` and `scope` are
// captured by pointer, so they must outlive the request.
//
// Returns the VarHandle tracking this request.
VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
                                          const platform::DeviceContext& ctx,
                                          const framework::Scope& scope,
                                          const std::string& in_var_name,
                                          const std::string& out_var_name,
                                          int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const framework::Scope* p_scope = &scope;
  const std::string in_var_name_val = in_var_name;
  const std::string out_var_name_val = out_var_name;
  const std::string method = "PrefetchRPC";

  const auto ch = GetChannel(ep);
  // The processor travels through the completion queue as the call tag and is
  // deleted by Proceed() once the call has completed.
  GetProcessor* s = new GetProcessor(ch);
  VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
  s->Prepare(h, time_out);

  // Count the request before it is dispatched: Proceed() decrements
  // req_count_ when the completion arrives, so incrementing afterwards would
  // let the decrement win the race and make Wait() see a wrong count.
  req_count_++;

  framework::AsyncIO([in_var_name_val, out_var_name_val, p_scope, p_ctx, s,
                      method, h, this] {
    auto* var = p_scope->FindVar(in_var_name_val);

    ::grpc::ByteBuffer req;
    SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val);

    VLOG(30) << s->GetVarHandlePtr()->String() << " begin";

    // stub context
    s->response_call_back_ = ProcGetResponse;

    platform::RecordRPCEvent record_event(method, p_ctx);

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
        &cq_);
    call->StartCall();
    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

    // When profiling, wait on the handle inside the task so record_event
    // stays alive until the request has finished.
    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
  });

  return h;
}

217 218
// Sends a message whose varname is BATCH_BARRIER_MESSAGE to endpoint `ep`
// through the AsyncSendVariable stub call.
//
// Returns the VarHandle tracking this request. When profiling is enabled the
// call waits on that handle before returning.
VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
                                               int64_t time_out) {
  const auto ch = GetChannel(ep);

  // The processor travels through the completion queue as the call tag and is
  // deleted by Proceed() once the call has completed.
  BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
  const std::string method = "BatchBarrierRPC";
  VarHandlePtr h(
      new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(BATCH_BARRIER_MESSAGE);

  platform::RecordRPCEvent record_event(method, nullptr);

  // Count the request before it is issued: Proceed() decrements req_count_
  // when the completion arrives, so incrementing after Finish() would let the
  // decrement win the race and make Wait() see a wrong count.
  req_count_++;
  auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}
Y
Yancey 已提交
242

243 244
// Sends a message whose varname is FETCH_BARRIER_MESSAGE to endpoint `ep`
// through the AsyncGetVariable stub call.
//
// Returns the VarHandle tracking this request. When profiling is enabled the
// call waits on that handle before returning.
VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
                                               int64_t time_out) {
  const auto ch = GetChannel(ep);
  // The processor travels through the completion queue as the call tag and is
  // deleted by Proceed() once the call has completed.
  FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
  const std::string method = "FetchBarrierRPC";
  VarHandlePtr h(
      new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(FETCH_BARRIER_MESSAGE);

  platform::RecordRPCEvent record_event(method, nullptr);

  // Count the request before it is issued: Proceed() decrements req_count_
  // when the completion arrives, so incrementing after Finish() would let the
  // decrement win the race and make Wait() see a wrong count.
  req_count_++;
  auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

268 269
// Sends a message whose varname is COMPLETE_MESSAGE to endpoint `ep` through
// the AsyncSendVariable stub call.
//
// Returns the VarHandle tracking this request. When profiling is enabled the
// call waits on that handle before returning.
VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
                                           int64_t time_out) {
  const auto ch = GetChannel(ep);

  // The processor travels through the completion queue as the call tag and is
  // deleted by Proceed() once the call has completed.
  BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
  const std::string method = "SendCompleteRPC";
  VarHandlePtr h(new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(COMPLETE_MESSAGE);

  platform::RecordRPCEvent record_event(method, nullptr);

  // Count the request before it is issued: Proceed() decrements req_count_
  // when the completion arrives, so incrementing after Finish() would let the
  // decrement win the race and make Wait() see a wrong count.
  req_count_++;
  auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

293 294 295
// Sends a checkpoint notification to endpoint `ep`: the message carries
// CHECKPOINT_SAVE_MESSAGE as varname and `dir` as out_varname, and is issued
// through the AsyncCheckpointNotify stub call.
//
// Returns the VarHandle tracking this request. When profiling is enabled the
// call waits on that handle before returning.
VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
                                               const std::string& dir,
                                               int64_t time_out) {
  const auto ch = GetChannel(ep);

  // The processor travels through the completion queue as the call tag and is
  // deleted by Proceed() once the call has completed.
  CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);

  const std::string method = "CheckPointNotifyRPC";

  VarHandlePtr h(
      new VarHandle(ep, method, CHECKPOINT_SAVE_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(CHECKPOINT_SAVE_MESSAGE);
  req.set_out_varname(dir);

  platform::RecordRPCEvent record_event(method, nullptr);

  // Count the request before it is issued: Proceed() decrements req_count_
  // when the completion arrives, so incrementing after Finish() would let the
  // decrement win the race and make Wait() see a wrong count.
  req_count_++;
  auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

Y
Yancey1989 已提交
323
bool GRPCClient::Wait() {
W
Wu Yi 已提交
324
  std::unique_lock<std::mutex> lk(sync_mutex_);
Y
Yancey1989 已提交
325 326
  sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); });
  return ok_;
G
gongweibao 已提交
327 328
}

G
gongweibao 已提交
329
void GRPCClient::Proceed() {
W
Wu Yi 已提交
330
  void* tag = nullptr;
G
gongweibao 已提交
331 332
  bool ok = false;

333
  VLOG(30) << "GRPCClient Proceed begin";
M
minqiyang 已提交
334
  while (!stopped_ && cq_.Next(&tag, &ok)) {
W
Wu Yi 已提交
335 336 337
    BaseProcessor* c = static_cast<BaseProcessor*>(tag);
    GPR_ASSERT(ok);
    PADDLE_ENFORCE(c);
G
gongweibao 已提交
338

W
Wu Yi 已提交
339
    if (c->status_.ok()) {
340
      VLOG(30) << c->GetVarHandlePtr()->String() << " process";
W
Wu Yi 已提交
341
      c->Process();
Y
Yancey1989 已提交
342
    } else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) {
G
gongweibao 已提交
343
      // FIXME(gongwb): parse error_details?
344
      LOG(ERROR) << c->GetVarHandlePtr()->String()
G
gongweibao 已提交
345 346 347
                 << " meets grpc error, error_code:" << c->status_.error_code()
                 << " error_message:" << c->status_.error_message()
                 << " error_details:" << c->status_.error_details();
Y
Yancey1989 已提交
348 349 350 351
      {
        std::lock_guard<std::mutex> lk(sync_mutex_);
        ok_ = false;
      }
352
      c->Finish(false);
W
Wu Yi 已提交
353
    } else {
354
      LOG(FATAL) << c->GetVarHandlePtr()->String()
G
gongweibao 已提交
355 356 357 358
                 << " meets grpc error, error_code:" << c->status_.error_code()
                 << " error_message:" << c->status_.error_message()
                 << " error_details:" << c->status_.error_details();

359
      c->Finish(false);
W
Wu Yi 已提交
360
    }
361

G
gongweibao 已提交
362
    bool notify = false;
W
Wu Yi 已提交
363 364 365
    {
      std::lock_guard<std::mutex> lk(sync_mutex_);
      req_count_--;
G
gongweibao 已提交
366 367 368 369 370 371 372
      notify = (req_count_ <= 0 || !c->status_.ok());
    }

    delete c;

    if (notify) {
      sync_cond_.notify_all();
W
Wu Yi 已提交
373
    }
G
gongweibao 已提交
374
  }
375
  VLOG(30) << "GRPCClient Proceed end";
G
gongweibao 已提交
376
}
W
Wu Yi 已提交
377

G
gongweibao 已提交
378
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
W
Wu Yi 已提交
379
  std::lock_guard<std::mutex> guard(chan_mutex_);
Y
Yancey1989 已提交
380
  auto it = channels_.find(ep);
G
gongweibao 已提交
381 382 383 384
  if (it != channels_.end()) {
    return it->second;
  }

W
Wu Yi 已提交
385
  // Channel configurations:
G
gongweibao 已提交
386
  grpc::ChannelArguments args;
W
Wu Yi 已提交
387
  args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 2000);
388 389 390
  if (FLAGS_rpc_disable_reuse_port) {
    args.SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0);
  }
391
  args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
G
gongweibao 已提交
392 393 394
  args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());

T
typhoonzero 已提交
395 396
  auto ch =
      grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
Y
Yancey1989 已提交
397
  channels_[ep] = ch;
G
gongweibao 已提交
398 399 400
  return ch;
}

401
}  // namespace distributed
G
gongweibao 已提交
402 403
}  // namespace operators
}  // namespace paddle