/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include <sys/time.h>
Y
Yi Wang 已提交
16 17
#include <limits>

G
gongweibao 已提交
18
#include "glog/logging.h"  // For VLOG
Y
Yi Wang 已提交
19
#include "paddle/fluid/framework/threadpool.h"
G
gongweibao 已提交
20
#include "paddle/fluid/operators/distributed/grpc_client.h"
21
#include "paddle/fluid/operators/distributed/grpc_serde.h"
22
#include "paddle/fluid/operators/distributed/request_handler.h"
X
Xin Pan 已提交
23
#include "paddle/fluid/platform/profiler.h"
24

G
gongweibao 已提交
25 26
namespace paddle {
namespace operators {
27
namespace distributed {
G
gongweibao 已提交
28

G
gongweibao 已提交
29
// Client initialization hook: all setup currently consists of starting the
// completion-queue event loop.
void GRPCClient::InitImpl() {
  InitEventLoop();
}
Y
Yancey1989 已提交
30

G
gongweibao 已提交
31
void GRPCClient::InitEventLoop() {
W
Wu Yi 已提交
32 33
  // start the client process thread
  // TODO(wuyi): can make this in a threadpool
G
gongweibao 已提交
34
  client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
W
Wu Yi 已提交
35 36
}

Y
Yancey1989 已提交
37
void GRPCClient::SendComplete() {
Y
Yancey1989 已提交
38 39 40
  std::unique_lock<std::mutex> lk(completed_mutex_);
  if (!completed_) {
    for (auto& it : channels_) {
41
      VLOG(30) << "send complete message to " << it.first;
Y
Yancey1989 已提交
42 43 44 45
      this->AsyncSendComplete(it.first);
    }
    PADDLE_ENFORCE(this->Wait(), "internal grpc error");
    completed_ = true;
W
Wu Yi 已提交
46 47 48
  }
}

G
gongweibao 已提交
49
// Tears the client down in four steps:
// 1. Wait for in-flight RPCs.
// 2. Stop the event loop and shut down cq_.
// 3. Release every cached channel.
// 4. Join client_thread_.
GRPCClient::~GRPCClient() {
  // Wait for outstanding requests *before* raising stopped_. Proceed() checks
  // stopped_ before every cq_.Next(), so raising it first lets the event loop
  // exit after at most one more completion. Any remaining requests would then
  // never be counted down, and Wait() could block indefinitely (it only
  // returns early if ok_ has been cleared).
  Wait();
  stopped_ = true;
  cq_.Shutdown();
  {
    std::lock_guard<std::mutex> guard(chan_mutex_);
    for (auto& it : channels_) {
      it.second.reset();
    }
    channels_.clear();
  }
  client_thread_->join();
}

63 64 65 66 67
// Asynchronously serializes `var_name` from `scope` and sends it to endpoint
// `ep` through the SendVariable RPC. Serialization and issuing the call
// happen on the framework IO pool. Completions arrive on cq_, which Proceed()
// drains.
//
// `ctx` and `scope` are captured by pointer (in the IO task and in the
// returned VarHandle), so they presumably must outlive the request.
// Returns a handle the caller can wait on.
VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
                                      const platform::DeviceContext& ctx,
                                      const framework::Scope& scope,
                                      const std::string& var_name,
                                      int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string var_name_val = var_name;
  const framework::Scope* p_scope = &scope;
  const auto ch = GetChannel(ep_val);
  SendProcessor* s = new SendProcessor(ch);
  const std::string method = "SendRPC";
  VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
  s->Prepare(h, time_out);

  // Count the request before handing it to the IO pool. If the increment came
  // after AsyncIO(), Proceed() could count the completion down first. Then
  // req_count_ could transiently read zero (or go negative) while requests
  // are still outstanding, and a Wait() on another thread could be woken too
  // early or miss its wake-up.
  req_count_++;

  framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
    auto* var = p_scope->FindVar(var_name_val);

    ::grpc::ByteBuffer req;
    SerializeToByteBuffer(var_name_val, var, *p_ctx, &req, "", trainer_id_);

    VLOG(30) << s->GetVarHandlePtr()->String() << " begin";

    // stub context
    s->response_call_back_ = nullptr;

    platform::RecordRPCEvent record_event(method, p_ctx);

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
    call->StartCall();
    // The processor pointer is the completion tag that Proceed() receives.
    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
  });

  return h;
}

void ProcGetResponse(const VarHandle& var_h,
106
                     const ::grpc::ByteBuffer& ret_msg) {
T
typhoonzero 已提交
107
  framework::Variable* outvar = nullptr;
W
Wu Yi 已提交
108 109 110 111
  // get response's trainer_id is not used
  int trainer_id;
  DeserializeFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar,
                            &trainer_id);
112 113 114 115 116
}

template <typename T>
void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
  ::grpc::Slice slice(proto.ByteSizeLong());
Q
qiaolongfei 已提交
117
  proto.SerializeWithCachedSizesToArray(const_cast<uint8_t*>(slice.begin()));
118 119
  ::grpc::ByteBuffer tmp(&slice, 1);
  result->Swap(&tmp);
G
gongweibao 已提交
120 121
}

122 123 124 125 126
// Asynchronously requests variable `var_name` from endpoint `ep` through the
// GetVariable RPC. The request is built and issued on the framework IO pool,
// and ProcGetResponse is installed as the processor's response callback.
//
// `ctx` and `scope` are captured by pointer (in the IO task and in the
// returned VarHandle), so they presumably must outlive the request.
// Returns a handle the caller can wait on.
VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
                                     const platform::DeviceContext& ctx,
                                     const framework::Scope& scope,
                                     const std::string& var_name,
                                     int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string var_name_val = var_name;
  const framework::Scope* p_scope = &scope;
  const auto ch = GetChannel(ep_val);
  GetProcessor* s = new GetProcessor(ch);
  const std::string method = "GetRPC";
  VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
  s->Prepare(h, time_out);

  // Count the request before handing it to the IO pool. If the increment came
  // after AsyncIO(), Proceed() could count the completion down first. Then
  // req_count_ could transiently read zero (or go negative) while requests
  // are still outstanding, and a Wait() on another thread could be woken too
  // early or miss its wake-up.
  req_count_++;

  framework::AsyncIO([var_name_val, s, method, p_ctx, h, this] {
    // prepare input
    sendrecv::VariableMessage req;
    req.set_varname(var_name_val);
    req.set_trainer_id(trainer_id_);
    ::grpc::ByteBuffer buf;
    RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);

    VLOG(30) << s->GetVarHandlePtr()->String() << " begin";

    // stub context
    s->response_call_back_ = ProcGetResponse;

    platform::RecordRPCEvent record_event(method, p_ctx);

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_);
    call->StartCall();
    // The processor pointer is the completion tag that Proceed() receives.
    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
  });

  return h;
}

167 168 169 170 171
// Asynchronously sends `in_var_name` (looked up in `scope`) to endpoint `ep`
// through the PrefetchVariable RPC. The request names `out_var_name` and
// `table_name`, and ProcGetResponse is installed as the response callback.
// The returned handle tracks `out_var_name`.
//
// `ctx` and `scope` are captured by pointer, so they presumably must outlive
// the request.
VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
                                          const platform::DeviceContext& ctx,
                                          const framework::Scope& scope,
                                          const std::string& in_var_name,
                                          const std::string& out_var_name,
                                          const std::string& table_name,
                                          int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string in_var_name_val = in_var_name;
  const std::string out_var_name_val = out_var_name;
  const framework::Scope* p_scope = &scope;
  const auto ch = GetChannel(ep_val);
  GetProcessor* s = new GetProcessor(ch);

  const std::string method = "PrefetchRPC";

  VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
  s->Prepare(h, time_out);

  // Count the request before handing it to the IO pool. If the increment came
  // after AsyncIO(), Proceed() could count the completion down first. Then
  // req_count_ could transiently read zero (or go negative) while requests
  // are still outstanding, and a Wait() on another thread could be woken too
  // early or miss its wake-up.
  req_count_++;

  framework::AsyncIO([in_var_name_val, out_var_name_val, p_scope, p_ctx, s,
                      method, h, table_name, this] {
    auto* var = p_scope->FindVar(in_var_name_val);

    ::grpc::ByteBuffer req;
    // NOTE(review): AsyncSendVar/AsyncGetVar pass trainer_id_ here, while
    // prefetch passes 0. Confirm that this is intentional.
    SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val,
                          0, table_name);

    VLOG(30) << s->GetVarHandlePtr()->String() << " begin";

    // stub context
    s->response_call_back_ = ProcGetResponse;

    platform::RecordRPCEvent record_event(method, p_ctx);

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
        &cq_);
    call->StartCall();
    // The processor pointer is the completion tag that Proceed() receives.
    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
  });

  return h;
}

217 218
// Sends the BATCH_BARRIER_MESSAGE marker to endpoint `ep` through the
// SendVariable RPC. The completion is delivered on cq_ and handled by
// Proceed(). Returns a handle the caller can wait on.
VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
                                               int64_t time_out) {
  const auto ch = GetChannel(ep);

  BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
  const std::string method = "BatchBarrierRPC";
  VarHandlePtr h(
      new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(BATCH_BARRIER_MESSAGE);

  platform::RecordRPCEvent record_event(method, nullptr);

  // Count the request before its completion tag is registered with cq_.
  // Otherwise Proceed() could count it down first, and req_count_ could
  // transiently read zero (or go negative) while requests are still
  // outstanding, mis-signalling a concurrent Wait().
  req_count_++;
  auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}
Y
Yancey 已提交
242

243 244
// Sends the FETCH_BARRIER_MESSAGE marker to endpoint `ep` through the
// GetVariable RPC. The completion is delivered on cq_ and handled by
// Proceed(). Returns a handle the caller can wait on.
VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
                                               int64_t time_out) {
  const auto ch = GetChannel(ep);
  FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
  const std::string method = "FetchBarrierRPC";
  VarHandlePtr h(
      new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(FETCH_BARRIER_MESSAGE);

  platform::RecordRPCEvent record_event(method, nullptr);

  // Count the request before its completion tag is registered with cq_.
  // Otherwise Proceed() could count it down first, and req_count_ could
  // transiently read zero (or go negative) while requests are still
  // outstanding, mis-signalling a concurrent Wait().
  req_count_++;
  auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

268 269
// Sends the COMPLETE_MESSAGE marker to endpoint `ep` through the SendVariable
// RPC. It reuses BatchBarrierProcessor for the reply. Returns a handle the
// caller can wait on.
VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
                                           int64_t time_out) {
  const auto ch = GetChannel(ep);

  BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
  const std::string method = "SendCompleteRPC";
  VarHandlePtr h(new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(COMPLETE_MESSAGE);

  platform::RecordRPCEvent record_event(method, nullptr);

  // Count the request before its completion tag is registered with cq_.
  // Otherwise Proceed() could count it down first, and req_count_ could
  // transiently read zero (or go negative) while requests are still
  // outstanding, mis-signalling a concurrent Wait().
  req_count_++;
  auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

293 294 295
// Sends a CHECKPOINT_SAVE_MESSAGE to endpoint `ep` through the
// CheckpointNotify RPC. The target directory `dir` travels in the request's
// out_varname field. Returns a handle the caller can wait on.
VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
                                               const std::string& dir,
                                               int64_t time_out) {
  const auto ch = GetChannel(ep);

  CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);

  const std::string method = "CheckPointNotifyRPC";

  VarHandlePtr h(
      new VarHandle(ep, method, CHECKPOINT_SAVE_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(CHECKPOINT_SAVE_MESSAGE);
  req.set_out_varname(dir);

  platform::RecordRPCEvent record_event(method, nullptr);

  // Count the request before its completion tag is registered with cq_.
  // Otherwise Proceed() could count it down first, and req_count_ could
  // transiently read zero (or go negative) while requests are still
  // outstanding, mis-signalling a concurrent Wait().
  req_count_++;
  auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

Y
Yancey1989 已提交
323
bool GRPCClient::Wait() {
W
Wu Yi 已提交
324
  std::unique_lock<std::mutex> lk(sync_mutex_);
Y
Yancey1989 已提交
325 326
  sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); });
  return ok_;
G
gongweibao 已提交
327 328
}

G
gongweibao 已提交
329
void GRPCClient::Proceed() {
W
Wu Yi 已提交
330
  void* tag = nullptr;
G
gongweibao 已提交
331 332
  bool ok = false;

333
  VLOG(30) << "GRPCClient Proceed begin";
M
minqiyang 已提交
334
  while (!stopped_ && cq_.Next(&tag, &ok)) {
W
Wu Yi 已提交
335 336 337
    BaseProcessor* c = static_cast<BaseProcessor*>(tag);
    GPR_ASSERT(ok);
    PADDLE_ENFORCE(c);
G
gongweibao 已提交
338

W
Wu Yi 已提交
339
    if (c->status_.ok()) {
340
      VLOG(30) << c->GetVarHandlePtr()->String() << " process";
W
Wu Yi 已提交
341
      c->Process();
Y
Yancey1989 已提交
342
    } else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) {
G
gongweibao 已提交
343
      // FIXME(gongwb): parse error_details?
344
      LOG(ERROR) << c->GetVarHandlePtr()->String()
G
gongweibao 已提交
345 346 347
                 << " meets grpc error, error_code:" << c->status_.error_code()
                 << " error_message:" << c->status_.error_message()
                 << " error_details:" << c->status_.error_details();
Y
Yancey1989 已提交
348 349 350 351
      {
        std::lock_guard<std::mutex> lk(sync_mutex_);
        ok_ = false;
      }
352
      c->Finish(false);
W
Wu Yi 已提交
353
    } else {
354
      LOG(FATAL) << c->GetVarHandlePtr()->String()
G
gongweibao 已提交
355 356 357 358
                 << " meets grpc error, error_code:" << c->status_.error_code()
                 << " error_message:" << c->status_.error_message()
                 << " error_details:" << c->status_.error_details();

359
      c->Finish(false);
W
Wu Yi 已提交
360
    }
361

G
gongweibao 已提交
362
    bool notify = false;
W
Wu Yi 已提交
363 364 365
    {
      std::lock_guard<std::mutex> lk(sync_mutex_);
      req_count_--;
G
gongweibao 已提交
366 367 368 369 370 371 372
      notify = (req_count_ <= 0 || !c->status_.ok());
    }

    delete c;

    if (notify) {
      sync_cond_.notify_all();
W
Wu Yi 已提交
373
    }
G
gongweibao 已提交
374
  }
375
  VLOG(30) << "GRPCClient Proceed end";
G
gongweibao 已提交
376
}
W
Wu Yi 已提交
377

G
gongweibao 已提交
378
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
W
Wu Yi 已提交
379
  std::lock_guard<std::mutex> guard(chan_mutex_);
Y
Yancey1989 已提交
380
  auto it = channels_.find(ep);
G
gongweibao 已提交
381 382 383 384
  if (it != channels_.end()) {
    return it->second;
  }

W
Wu Yi 已提交
385
  // Channel configurations:
G
gongweibao 已提交
386
  grpc::ChannelArguments args;
W
Wu Yi 已提交
387
  args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 2000);
388
  args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
G
gongweibao 已提交
389 390 391
  args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());

T
typhoonzero 已提交
392 393
  auto ch =
      grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
Y
Yancey1989 已提交
394
  channels_[ep] = ch;
G
gongweibao 已提交
395 396 397
  return ch;
}

398
}  // namespace distributed
G
gongweibao 已提交
399 400
}  // namespace operators
}  // namespace paddle