grpc_client.cc 12.6 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
G
gongweibao 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include <sys/time.h>
Y
Yi Wang 已提交
16 17
#include <limits>
#include <string>
#include <vector>

G
gongweibao 已提交
18
#include "glog/logging.h"  // For VLOG
Y
Yi Wang 已提交
19
#include "paddle/fluid/framework/threadpool.h"
G
gongweibao 已提交
20
#include "paddle/fluid/operators/distributed/grpc_client.h"
21
#include "paddle/fluid/operators/distributed/grpc_serde.h"
22
#include "paddle/fluid/operators/distributed/request_handler.h"
X
Xin Pan 已提交
23
#include "paddle/fluid/platform/profiler.h"
24

G
gongweibao 已提交
25 26
namespace paddle {
namespace operators {
27
namespace distributed {
G
gongweibao 已提交
28

G
gongweibao 已提交
29
// Client initialization hook: the only setup needed is starting the
// completion-queue event loop.
void GRPCClient::InitImpl() {
  InitEventLoop();
}
Y
Yancey1989 已提交
30

G
gongweibao 已提交
31
void GRPCClient::InitEventLoop() {
W
Wu Yi 已提交
32 33
  // start the client process thread
  // TODO(wuyi): can make this in a threadpool
G
gongweibao 已提交
34
  client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
W
Wu Yi 已提交
35 36
}

Y
Yancey1989 已提交
37
void GRPCClient::SendComplete() {
Y
Yancey1989 已提交
38 39 40
  std::unique_lock<std::mutex> lk(completed_mutex_);
  if (!completed_) {
    for (auto& it : channels_) {
41
      VLOG(30) << "send complete message to " << it.first;
Y
Yancey1989 已提交
42 43 44 45
      this->AsyncSendComplete(it.first);
    }
    PADDLE_ENFORCE(this->Wait(), "internal grpc error");
    completed_ = true;
W
Wu Yi 已提交
46 47 48
  }
}

G
gongweibao 已提交
49
// Tears the client down: waits for in-flight RPCs, stops the event loop,
// releases all channels and joins the event-loop thread.
GRPCClient::~GRPCClient() {
  // Drain outstanding requests *before* raising stopped_. Proceed() re-checks
  // stopped_ before every cq_.Next(); if the flag were set first, the event
  // thread could exit after handling a single completion, leaving req_count_
  // above zero and this Wait() blocked forever.
  Wait();
  stopped_ = true;
  // After Shutdown(), cq_.Next() returns false once the queue is drained,
  // which ends Proceed()'s loop.
  cq_.Shutdown();
  {
    std::lock_guard<std::mutex> guard(chan_mutex_);
    for (auto& it : channels_) {
      it.second.reset();
    }
    channels_.clear();
  }
  // Tolerate a client whose event loop was never started (InitImpl not run).
  if (client_thread_ && client_thread_->joinable()) {
    client_thread_->join();
  }
}

63 64 65 66 67
// Asynchronously serializes `var_name` from `scope` and sends it to endpoint
// `ep` via the SendVariable RPC. Returns a handle the caller can wait on;
// completion is handled on the Proceed() event thread.
VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
                                      const platform::DeviceContext& ctx,
                                      const framework::Scope& scope,
                                      const std::string& var_name,
                                      int64_t time_out) {
  // Copies/pointers captured by the AsyncIO closure below, which may run after
  // this function has returned.
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string var_name_val = var_name;
  const framework::Scope* p_scope = &scope;
  const auto ch = GetChannel(ep_val);
  // Used as the completion tag; Proceed() deletes it once the RPC finishes.
  SendProcessor* s = new SendProcessor(ch);
  const std::string method = "SendRPC";
  VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
  s->Prepare(h, time_out);

  // Count the request before it can possibly complete. Proceed() decrements
  // req_count_ when the RPC finishes; incrementing after dispatch would let a
  // fast completion decrement first and wake Wait() while other requests are
  // still pending.
  req_count_++;

  framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
    auto* var = p_scope->FindVar(var_name_val);

    ::grpc::ByteBuffer req;
    SerializeToByteBuffer(var_name_val, var, *p_ctx, &req, "", trainer_id_);

    VLOG(30) << s->GetVarHandlePtr()->String() << " begin";

    // A send has no payload in its reply to process.
    s->response_call_back_ = nullptr;

    platform::RecordRPCEvent record_event(method, p_ctx);

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
    call->StartCall();
    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

    // When profiling, block until done so the RPC event covers the full call.
    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
  });

  return h;
}

void ProcGetResponse(const VarHandle& var_h,
106
                     const ::grpc::ByteBuffer& ret_msg) {
T
typhoonzero 已提交
107
  framework::Variable* outvar = nullptr;
W
Wu Yi 已提交
108 109 110 111
  // get response's trainer_id is not used
  int trainer_id;
  DeserializeFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar,
                            &trainer_id);
112 113 114 115 116
}

template <typename T>
void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
  ::grpc::Slice slice(proto.ByteSizeLong());
Q
qiaolongfei 已提交
117
  proto.SerializeWithCachedSizesToArray(const_cast<uint8_t*>(slice.begin()));
118 119
  ::grpc::ByteBuffer tmp(&slice, 1);
  result->Swap(&tmp);
G
gongweibao 已提交
120 121
}

122 123 124 125 126
// Asynchronously requests variable `var_name` from endpoint `ep` via the
// GetVariable RPC. The reply is handed to ProcGetResponse on the Proceed()
// event thread. Returns a handle the caller can wait on.
VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
                                     const platform::DeviceContext& ctx,
                                     const framework::Scope& scope,
                                     const std::string& var_name,
                                     int64_t time_out) {
  // Copies/pointers captured by the AsyncIO closure below, which may run after
  // this function has returned.
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string var_name_val = var_name;
  const framework::Scope* p_scope = &scope;
  const auto ch = GetChannel(ep_val);
  // Used as the completion tag; Proceed() deletes it once the RPC finishes.
  GetProcessor* s = new GetProcessor(ch);
  const std::string method = "GetRPC";
  VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
  s->Prepare(h, time_out);

  // Count the request before it can possibly complete, so Proceed() never
  // decrements req_count_ for it first (see AsyncSendVar).
  req_count_++;

  framework::AsyncIO([var_name_val, s, method, p_ctx, h, this] {
    // prepare input
    sendrecv::VariableMessage req;
    req.set_varname(var_name_val);
    req.set_trainer_id(trainer_id_);
    ::grpc::ByteBuffer buf;
    RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);

    VLOG(30) << s->GetVarHandlePtr()->String() << " begin";

    // Presumably invoked by the processor on success to consume the reply —
    // NOTE(review): confirm in GetProcessor::Process().
    s->response_call_back_ = ProcGetResponse;

    platform::RecordRPCEvent record_event(method, p_ctx);

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_);
    call->StartCall();
    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

    // When profiling, block until done so the RPC event covers the full call.
    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
  });

  return h;
}

167 168 169 170 171
// Asynchronously sends `in_var_name` (looked up in `scope`) to endpoint `ep`
// via the PrefetchVariable RPC, asking for `out_var_name` from `table_name`.
// The reply is handed to ProcGetResponse on the Proceed() event thread.
VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
                                          const platform::DeviceContext& ctx,
                                          const framework::Scope& scope,
                                          const std::string& in_var_name,
                                          const std::string& out_var_name,
                                          const std::string& table_name,
                                          int64_t time_out) {
  // Copies/pointers captured by the AsyncIO closure below, which may run after
  // this function has returned.
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string in_var_name_val = in_var_name;
  const std::string out_var_name_val = out_var_name;
  const std::string table_name_val = table_name;
  const framework::Scope* p_scope = &scope;
  const auto ch = GetChannel(ep_val);
  // Used as the completion tag; Proceed() deletes it once the RPC finishes.
  GetProcessor* s = new GetProcessor(ch);

  const std::string method = "PrefetchRPC";

  VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
  s->Prepare(h, time_out);

  // Count the request before it can possibly complete, so Proceed() never
  // decrements req_count_ for it first (see AsyncSendVar).
  req_count_++;

  framework::AsyncIO([in_var_name_val, out_var_name_val, p_scope, p_ctx, s,
                      method, h, table_name_val, this] {
    auto* var = p_scope->FindVar(in_var_name_val);

    ::grpc::ByteBuffer req;
    // NOTE(review): trainer id is passed as 0 here, whereas AsyncSendVar passes
    // trainer_id_; confirm this is intended for prefetch.
    SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val,
                          0, table_name_val);

    VLOG(30) << s->GetVarHandlePtr()->String() << " begin";

    // Same reply handling as a plain Get.
    s->response_call_back_ = ProcGetResponse;

    platform::RecordRPCEvent record_event(method, p_ctx);

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
        &cq_);
    call->StartCall();
    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

    // When profiling, block until done so the RPC event covers the full call.
    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
  });

  return h;
}

218 219
// Sends a BATCH_BARRIER_MESSAGE to endpoint `ep` via the SendVariable RPC.
// Returns a handle the caller can wait on.
VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
                                               int64_t time_out) {
  const auto ch = GetChannel(ep);

  // Used as the completion tag; Proceed() deletes it once the RPC finishes.
  BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
  const std::string method = "BatchBarrierRPC";
  VarHandlePtr h(
      new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(BATCH_BARRIER_MESSAGE);

  platform::RecordRPCEvent record_event(method, nullptr);

  auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
  // Count the request before its completion tag is registered, so Proceed()
  // can never decrement req_count_ for it first.
  req_count_++;
  rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

  // When profiling, block until done so the RPC event covers the full call.
  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}
Y
Yancey 已提交
243

244 245
// Sends a FETCH_BARRIER_MESSAGE to endpoint `ep` via the GetVariable RPC.
// Returns a handle the caller can wait on.
VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
                                               int64_t time_out) {
  const auto ch = GetChannel(ep);
  // Used as the completion tag; Proceed() deletes it once the RPC finishes.
  FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
  const std::string method = "FetchBarrierRPC";
  VarHandlePtr h(
      new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(FETCH_BARRIER_MESSAGE);

  platform::RecordRPCEvent record_event(method, nullptr);

  auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
  // Count the request before its completion tag is registered, so Proceed()
  // can never decrement req_count_ for it first.
  req_count_++;
  rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

  // When profiling, block until done so the RPC event covers the full call.
  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

269 270
// Sends a COMPLETE_MESSAGE to endpoint `ep` via the SendVariable RPC, reusing
// BatchBarrierProcessor for the reply. Returns a handle the caller can wait on.
VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
                                           int64_t time_out) {
  const auto ch = GetChannel(ep);

  // Used as the completion tag; Proceed() deletes it once the RPC finishes.
  BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
  const std::string method = "SendCompleteRPC";
  VarHandlePtr h(new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(COMPLETE_MESSAGE);

  platform::RecordRPCEvent record_event(method, nullptr);

  auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
  // Count the request before its completion tag is registered, so Proceed()
  // can never decrement req_count_ for it first.
  req_count_++;
  rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

  // When profiling, block until done so the RPC event covers the full call.
  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

294 295 296
// Sends a CHECKPOINT_SAVE_MESSAGE carrying directory `dir` (as out_varname) to
// endpoint `ep` via the CheckpointNotify RPC. Returns a handle to wait on.
VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
                                               const std::string& dir,
                                               int64_t time_out) {
  const auto ch = GetChannel(ep);

  // Used as the completion tag; Proceed() deletes it once the RPC finishes.
  CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);

  const std::string method = "CheckPointNotifyRPC";

  VarHandlePtr h(
      new VarHandle(ep, method, CHECKPOINT_SAVE_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(CHECKPOINT_SAVE_MESSAGE);
  req.set_out_varname(dir);

  platform::RecordRPCEvent record_event(method, nullptr);

  auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_);
  // Count the request before its completion tag is registered, so Proceed()
  // can never decrement req_count_ for it first.
  req_count_++;
  rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

  // When profiling, block until done so the RPC event covers the full call.
  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

Y
Yancey1989 已提交
324
bool GRPCClient::Wait() {
W
Wu Yi 已提交
325
  std::unique_lock<std::mutex> lk(sync_mutex_);
Y
Yancey1989 已提交
326 327
  sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); });
  return ok_;
G
gongweibao 已提交
328 329
}

G
gongweibao 已提交
330
void GRPCClient::Proceed() {
W
Wu Yi 已提交
331
  void* tag = nullptr;
G
gongweibao 已提交
332 333
  bool ok = false;

334
  VLOG(30) << "GRPCClient Proceed begin";
M
minqiyang 已提交
335
  while (!stopped_ && cq_.Next(&tag, &ok)) {
W
Wu Yi 已提交
336 337 338
    BaseProcessor* c = static_cast<BaseProcessor*>(tag);
    GPR_ASSERT(ok);
    PADDLE_ENFORCE(c);
G
gongweibao 已提交
339

W
Wu Yi 已提交
340
    if (c->status_.ok()) {
341
      VLOG(30) << c->GetVarHandlePtr()->String() << " process";
W
Wu Yi 已提交
342
      c->Process();
Y
Yancey1989 已提交
343
    } else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) {
G
gongweibao 已提交
344
      // FIXME(gongwb): parse error_details?
345
      LOG(ERROR) << c->GetVarHandlePtr()->String()
G
gongweibao 已提交
346 347 348
                 << " meets grpc error, error_code:" << c->status_.error_code()
                 << " error_message:" << c->status_.error_message()
                 << " error_details:" << c->status_.error_details();
Y
Yancey1989 已提交
349 350 351 352
      {
        std::lock_guard<std::mutex> lk(sync_mutex_);
        ok_ = false;
      }
353
      c->Finish(false);
W
Wu Yi 已提交
354
    } else {
355
      LOG(FATAL) << c->GetVarHandlePtr()->String()
G
gongweibao 已提交
356 357 358 359
                 << " meets grpc error, error_code:" << c->status_.error_code()
                 << " error_message:" << c->status_.error_message()
                 << " error_details:" << c->status_.error_details();

360
      c->Finish(false);
W
Wu Yi 已提交
361
    }
362

G
gongweibao 已提交
363
    bool notify = false;
W
Wu Yi 已提交
364 365 366
    {
      std::lock_guard<std::mutex> lk(sync_mutex_);
      req_count_--;
G
gongweibao 已提交
367 368 369 370 371 372 373
      notify = (req_count_ <= 0 || !c->status_.ok());
    }

    delete c;

    if (notify) {
      sync_cond_.notify_all();
W
Wu Yi 已提交
374
    }
G
gongweibao 已提交
375
  }
376
  VLOG(30) << "GRPCClient Proceed end";
G
gongweibao 已提交
377
}
W
Wu Yi 已提交
378

G
gongweibao 已提交
379
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
W
Wu Yi 已提交
380
  std::lock_guard<std::mutex> guard(chan_mutex_);
Y
Yancey1989 已提交
381
  auto it = channels_.find(ep);
G
gongweibao 已提交
382 383 384 385
  if (it != channels_.end()) {
    return it->second;
  }

W
Wu Yi 已提交
386
  // Channel configurations:
G
gongweibao 已提交
387
  grpc::ChannelArguments args;
W
Wu Yi 已提交
388
  args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 2000);
389
  args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
G
gongweibao 已提交
390 391 392
  args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());

T
typhoonzero 已提交
393 394
  auto ch =
      grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
Y
Yancey1989 已提交
395
  channels_[ep] = ch;
G
gongweibao 已提交
396 397 398
  return ch;
}

399
}  // namespace distributed
G
gongweibao 已提交
400 401
}  // namespace operators
}  // namespace paddle