grpc_client.cc 12.7 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
G
gongweibao 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15 16
#include <limits>

G
gongweibao 已提交
17
#include "glog/logging.h"  // For VLOG
Y
Yi Wang 已提交
18
#include "paddle/fluid/framework/threadpool.h"
G
gongweibao 已提交
19
#include "paddle/fluid/operators/distributed/grpc_client.h"
20
#include "paddle/fluid/operators/distributed/grpc_serde.h"
21
#include "paddle/fluid/operators/distributed/request_handler.h"
P
peizhilin 已提交
22
#include "paddle/fluid/platform/port.h"
X
Xin Pan 已提交
23
#include "paddle/fluid/platform/profiler.h"
24

25 26
DECLARE_bool(rpc_disable_reuse_port);

G
gongweibao 已提交
27 28
namespace paddle {
namespace operators {
29
namespace distributed {
G
gongweibao 已提交
30

G
gongweibao 已提交
31
// Initialization hook of the client: its only job is to start the event loop.
void GRPCClient::InitImpl() {
  InitEventLoop();
}
Y
Yancey1989 已提交
32

G
gongweibao 已提交
33
void GRPCClient::InitEventLoop() {
W
Wu Yi 已提交
34 35
  // start the client process thread
  // TODO(wuyi): can make this in a threadpool
G
gongweibao 已提交
36
  client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
W
Wu Yi 已提交
37 38
}

Y
Yancey1989 已提交
39
void GRPCClient::SendComplete() {
Y
Yancey1989 已提交
40 41 42
  std::unique_lock<std::mutex> lk(completed_mutex_);
  if (!completed_) {
    for (auto& it : channels_) {
M
minqiyang 已提交
43
      VLOG(3) << "send complete message to " << it.first;
Y
Yancey1989 已提交
44 45 46 47
      this->AsyncSendComplete(it.first);
    }
    PADDLE_ENFORCE(this->Wait(), "internal grpc error");
    completed_ = true;
W
Wu Yi 已提交
48 49 50
  }
}

G
gongweibao 已提交
51
// Tears the client down: drains outstanding requests, stops the event loop,
// shuts down the completion queue, drops all cached channels and joins the
// event-loop thread.
GRPCClient::~GRPCClient() {
  // Drain BEFORE raising stopped_. Proceed() re-checks stopped_ after every
  // completion it handles, so raising the flag first lets the loop exit after
  // a single event; with two or more requests in flight req_count_ would then
  // never reach zero and Wait() would block forever.
  Wait();
  stopped_ = true;
  cq_.Shutdown();
  {
    std::lock_guard<std::mutex> guard(chan_mutex_);
    // Clearing the map releases every channel reference it holds.
    channels_.clear();
  }
  // The thread only exists if InitEventLoop() ran; do not dereference a null
  // handle when the client is destroyed without having been initialized.
  if (client_thread_ && client_thread_->joinable()) {
    client_thread_->join();
  }
}

65 66 67 68 69
// Asynchronously sends variable `var_name`, looked up in `scope`, to the
// endpoint `ep` through the "/sendrecv.SendRecvService/SendVariable" method.
//
// ep:       endpoint; also the key used to look up / create the channel.
// ctx:      device context used for serialization. Captured by pointer, so it
//           must stay alive until the scheduled task has run.
// scope:    scope searched for the variable. Captured by pointer as well.
// var_name: name of the variable to send.
// time_out: forwarded unchanged to the processor's Prepare().
//
// Returns the handle associated with the request. The serialization and the
// call itself are issued from a task scheduled with framework::AsyncIO.
VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
                                      const platform::DeviceContext& ctx,
                                      const framework::Scope& scope,
                                      const std::string& var_name,
                                      int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const framework::Scope* p_scope = &scope;
  // Owned copies: the task below captures these by value.
  const std::string var_name_val = var_name;
  const std::string method = "SendRPC";

  const auto ch = GetChannel(ep);
  // Raw pointer on purpose: it travels through the completion queue as the
  // tag and is deleted by Proceed() after the completion has been handled.
  SendProcessor* s = new SendProcessor(ch);
  VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
  s->Prepare(h, time_out);

  // Count the request BEFORE scheduling it. Otherwise Proceed() may handle
  // the completion (and decrement) first, which drives the counter to zero or
  // below and can wake a concurrent Wait() while work is still pending.
  req_count_++;

  framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
    auto* var = p_scope->FindVar(var_name_val);

    ::grpc::ByteBuffer req;
    SerializeToByteBuffer(var_name_val, var, *p_ctx, &req, "", trainer_id_);

    VLOG(3) << s->GetVarHandlePtr()->String() << " begin";

    // stub context
    s->response_call_back_ = nullptr;

    platform::RecordRPCEvent record_event(method, p_ctx);

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
    call->StartCall();
    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

    // When profiling, block inside the task until the request has completed.
    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
  });

  return h;
}

void ProcGetResponse(const VarHandle& var_h,
108
                     const ::grpc::ByteBuffer& ret_msg) {
T
typhoonzero 已提交
109
  framework::Variable* outvar = nullptr;
W
Wu Yi 已提交
110 111 112 113
  // get response's trainer_id is not used
  int trainer_id;
  DeserializeFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar,
                            &trainer_id);
114 115 116 117 118
}

template <typename T>
void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
  ::grpc::Slice slice(proto.ByteSizeLong());
Q
qiaolongfei 已提交
119
  proto.SerializeWithCachedSizesToArray(const_cast<uint8_t*>(slice.begin()));
120 121
  ::grpc::ByteBuffer tmp(&slice, 1);
  result->Swap(&tmp);
G
gongweibao 已提交
122 123
}

124 125 126 127 128
// Asynchronously requests variable `var_name` from the endpoint `ep` through
// the "/sendrecv.SendRecvService/GetVariable" method. The reply is handled by
// ProcGetResponse, which deserializes it with the context and scope stored in
// the returned handle.
//
// ep:       endpoint; also the key used to look up / create the channel.
// ctx:      device context stored in the handle and captured by pointer, so
//           it must stay alive until the request has been processed.
// scope:    scope stored in the handle (by pointer).
// var_name: name of the variable to fetch.
// time_out: forwarded unchanged to the processor's Prepare().
//
// Returns the handle associated with the request.
VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
                                     const platform::DeviceContext& ctx,
                                     const framework::Scope& scope,
                                     const std::string& var_name,
                                     int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const framework::Scope* p_scope = &scope;
  // Owned copies: the task below captures these by value.
  const std::string var_name_val = var_name;
  const std::string method = "GetRPC";

  const auto ch = GetChannel(ep);
  // Raw pointer on purpose: it travels through the completion queue as the
  // tag and is deleted by Proceed() after the completion has been handled.
  GetProcessor* s = new GetProcessor(ch);
  VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
  s->Prepare(h, time_out);

  // Count the request BEFORE scheduling it. Otherwise Proceed() may handle
  // the completion (and decrement) first, which drives the counter to zero or
  // below and can wake a concurrent Wait() while work is still pending.
  req_count_++;

  framework::AsyncIO([var_name_val, s, method, p_ctx, h, this] {
    // prepare input
    sendrecv::VariableMessage req;
    req.set_varname(var_name_val);
    req.set_trainer_id(trainer_id_);
    ::grpc::ByteBuffer buf;
    RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);

    VLOG(3) << s->GetVarHandlePtr()->String() << " begin";

    // stub context
    s->response_call_back_ = ProcGetResponse;

    platform::RecordRPCEvent record_event(method, p_ctx);

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_);
    call->StartCall();
    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

    // When profiling, block inside the task until the request has completed.
    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
  });

  return h;
}

169 170 171 172 173
// Asynchronously sends variable `in_var_name` (looked up in `scope`) to the
// endpoint `ep` through "/sendrecv.SendRecvService/PrefetchVariable". The
// reply is handled by ProcGetResponse; the returned handle is registered
// under `out_var_name`.
//
// ep:           endpoint; also the key used to look up / create the channel.
// ctx:          device context used for serialization. Captured by pointer,
//               so it must stay alive until the request has been processed.
// scope:        scope searched for the input variable. Captured by pointer.
// in_var_name:  name of the variable serialized into the request.
// out_var_name: output variable name placed in the request and the handle.
// table_name:   table name placed in the request.
// time_out:     forwarded unchanged to the processor's Prepare().
//
// Returns the handle associated with the request.
VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
                                          const platform::DeviceContext& ctx,
                                          const framework::Scope& scope,
                                          const std::string& in_var_name,
                                          const std::string& out_var_name,
                                          const std::string& table_name,
                                          int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const framework::Scope* p_scope = &scope;
  // Owned copies: the task below captures these by value.
  const std::string in_var_name_val = in_var_name;
  const std::string out_var_name_val = out_var_name;
  const std::string table_name_val = table_name;
  const std::string method = "PrefetchRPC";

  const auto ch = GetChannel(ep);
  // Raw pointer on purpose: it travels through the completion queue as the
  // tag and is deleted by Proceed() after the completion has been handled.
  GetProcessor* s = new GetProcessor(ch);
  VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
  s->Prepare(h, time_out);

  // Count the request BEFORE scheduling it. Otherwise Proceed() may handle
  // the completion (and decrement) first, which drives the counter to zero or
  // below and can wake a concurrent Wait() while work is still pending.
  req_count_++;

  framework::AsyncIO([in_var_name_val, out_var_name_val, p_scope, p_ctx, s,
                      method, h, table_name_val, this] {
    auto* var = p_scope->FindVar(in_var_name_val);

    ::grpc::ByteBuffer req;
    // NOTE(review): a literal 0 is passed where AsyncSendVar passes
    // trainer_id_ -- confirm this is intentional.
    SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val,
                          0, table_name_val);

    VLOG(3) << s->GetVarHandlePtr()->String() << " begin";

    // stub context
    s->response_call_back_ = ProcGetResponse;

    platform::RecordRPCEvent record_event(method, p_ctx);

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
        &cq_);
    call->StartCall();
    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

    // When profiling, block inside the task until the request has completed.
    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
  });

  return h;
}

220 221
// Sends a message whose variable name is BATCH_BARRIER_MESSAGE to the
// endpoint `ep` through the stub's AsyncSendVariable.
//
// ep:       endpoint; also the key used to look up / create the channel.
// time_out: forwarded unchanged to the processor's Prepare().
//
// Returns the handle associated with the request. When profiling is enabled
// the call blocks until the request has completed.
VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
                                               int64_t time_out) {
  const auto ch = GetChannel(ep);

  // Raw pointer on purpose: it travels through the completion queue as the
  // tag and is deleted by Proceed() after the completion has been handled.
  BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
  const std::string method = "BatchBarrierRPC";
  VarHandlePtr h(
      new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(BATCH_BARRIER_MESSAGE);

  platform::RecordRPCEvent record_event(method, nullptr);

  // Count the request BEFORE its tag is registered. Otherwise Proceed() may
  // handle the completion (and decrement) first, which drives the counter to
  // zero or below and can wake a concurrent Wait() while work is pending.
  req_count_++;

  auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}
Y
Yancey 已提交
245

246 247
// Sends a message whose variable name is FETCH_BARRIER_MESSAGE to the
// endpoint `ep` through the stub's AsyncGetVariable.
//
// ep:       endpoint; also the key used to look up / create the channel.
// time_out: forwarded unchanged to the processor's Prepare().
//
// Returns the handle associated with the request. When profiling is enabled
// the call blocks until the request has completed.
VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
                                               int64_t time_out) {
  const auto ch = GetChannel(ep);
  // Raw pointer on purpose: it travels through the completion queue as the
  // tag and is deleted by Proceed() after the completion has been handled.
  FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
  const std::string method = "FetchBarrierRPC";
  VarHandlePtr h(
      new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(FETCH_BARRIER_MESSAGE);

  platform::RecordRPCEvent record_event(method, nullptr);

  // Count the request BEFORE its tag is registered. Otherwise Proceed() may
  // handle the completion (and decrement) first, which drives the counter to
  // zero or below and can wake a concurrent Wait() while work is pending.
  req_count_++;

  auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

271 272
// Sends a message whose variable name is COMPLETE_MESSAGE to the endpoint
// `ep` through the stub's AsyncSendVariable. It uses the same processor type
// as the batch barrier request.
//
// ep:       endpoint; also the key used to look up / create the channel.
// time_out: forwarded unchanged to the processor's Prepare().
//
// Returns the handle associated with the request. When profiling is enabled
// the call blocks until the request has completed.
VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
                                           int64_t time_out) {
  const auto ch = GetChannel(ep);

  // Raw pointer on purpose: it travels through the completion queue as the
  // tag and is deleted by Proceed() after the completion has been handled.
  BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
  const std::string method = "SendCompleteRPC";
  VarHandlePtr h(new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(COMPLETE_MESSAGE);

  platform::RecordRPCEvent record_event(method, nullptr);

  // Count the request BEFORE its tag is registered. Otherwise Proceed() may
  // handle the completion (and decrement) first, which drives the counter to
  // zero or below and can wake a concurrent Wait() while work is pending.
  req_count_++;

  auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

296 297 298
// Sends a checkpoint notification to the endpoint `ep` through the stub's
// AsyncCheckpointNotify. The request carries CHECKPOINT_SAVE_MESSAGE as its
// variable name and `dir` in its out_varname field.
//
// ep:       endpoint; also the key used to look up / create the channel.
// dir:      value stored in the request's out_varname field.
// time_out: forwarded unchanged to the processor's Prepare().
//
// Returns the handle associated with the request. When profiling is enabled
// the call blocks until the request has completed.
VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
                                               const std::string& dir,
                                               int64_t time_out) {
  const auto ch = GetChannel(ep);

  // Raw pointer on purpose: it travels through the completion queue as the
  // tag and is deleted by Proceed() after the completion has been handled.
  CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);

  const std::string method = "CheckPointNotifyRPC";

  VarHandlePtr h(
      new VarHandle(ep, method, CHECKPOINT_SAVE_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(CHECKPOINT_SAVE_MESSAGE);
  req.set_out_varname(dir);

  platform::RecordRPCEvent record_event(method, nullptr);

  // Count the request BEFORE its tag is registered. Otherwise Proceed() may
  // handle the completion (and decrement) first, which drives the counter to
  // zero or below and can wake a concurrent Wait() while work is pending.
  req_count_++;

  auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

Y
Yancey1989 已提交
326
bool GRPCClient::Wait() {
W
Wu Yi 已提交
327
  std::unique_lock<std::mutex> lk(sync_mutex_);
Y
Yancey1989 已提交
328 329
  sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); });
  return ok_;
G
gongweibao 已提交
330 331
}

G
gongweibao 已提交
332
void GRPCClient::Proceed() {
W
Wu Yi 已提交
333
  void* tag = nullptr;
G
gongweibao 已提交
334 335
  bool ok = false;

M
minqiyang 已提交
336
  VLOG(3) << "GRPCClient Proceed begin";
M
minqiyang 已提交
337
  while (!stopped_ && cq_.Next(&tag, &ok)) {
W
Wu Yi 已提交
338 339 340
    BaseProcessor* c = static_cast<BaseProcessor*>(tag);
    GPR_ASSERT(ok);
    PADDLE_ENFORCE(c);
G
gongweibao 已提交
341

W
Wu Yi 已提交
342
    if (c->status_.ok()) {
M
minqiyang 已提交
343
      VLOG(3) << c->GetVarHandlePtr()->String() << " process";
W
Wu Yi 已提交
344
      c->Process();
Y
Yancey1989 已提交
345
    } else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) {
G
gongweibao 已提交
346
      // FIXME(gongwb): parse error_details?
347
      LOG(ERROR) << c->GetVarHandlePtr()->String()
G
gongweibao 已提交
348 349 350
                 << " meets grpc error, error_code:" << c->status_.error_code()
                 << " error_message:" << c->status_.error_message()
                 << " error_details:" << c->status_.error_details();
Y
Yancey1989 已提交
351 352 353 354
      {
        std::lock_guard<std::mutex> lk(sync_mutex_);
        ok_ = false;
      }
355
      c->Finish(false);
W
Wu Yi 已提交
356
    } else {
357
      LOG(FATAL) << c->GetVarHandlePtr()->String()
G
gongweibao 已提交
358 359 360 361
                 << " meets grpc error, error_code:" << c->status_.error_code()
                 << " error_message:" << c->status_.error_message()
                 << " error_details:" << c->status_.error_details();

362
      c->Finish(false);
W
Wu Yi 已提交
363
    }
364

G
gongweibao 已提交
365
    bool notify = false;
W
Wu Yi 已提交
366 367 368
    {
      std::lock_guard<std::mutex> lk(sync_mutex_);
      req_count_--;
G
gongweibao 已提交
369 370 371 372 373 374 375
      notify = (req_count_ <= 0 || !c->status_.ok());
    }

    delete c;

    if (notify) {
      sync_cond_.notify_all();
W
Wu Yi 已提交
376
    }
G
gongweibao 已提交
377
  }
M
minqiyang 已提交
378
  VLOG(3) << "GRPCClient Proceed end";
G
gongweibao 已提交
379
}
W
Wu Yi 已提交
380

G
gongweibao 已提交
381
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
W
Wu Yi 已提交
382
  std::lock_guard<std::mutex> guard(chan_mutex_);
Y
Yancey1989 已提交
383
  auto it = channels_.find(ep);
G
gongweibao 已提交
384 385 386 387
  if (it != channels_.end()) {
    return it->second;
  }

W
Wu Yi 已提交
388
  // Channel configurations:
G
gongweibao 已提交
389
  grpc::ChannelArguments args;
W
Wu Yi 已提交
390
  args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 2000);
391 392 393
  if (FLAGS_rpc_disable_reuse_port) {
    args.SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0);
  }
394
  args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
G
gongweibao 已提交
395 396 397
  args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());

T
typhoonzero 已提交
398 399
  auto ch =
      grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
Y
Yancey1989 已提交
400
  channels_[ep] = ch;
G
gongweibao 已提交
401 402 403
  return ch;
}

404
}  // namespace distributed
G
gongweibao 已提交
405 406
}  // namespace operators
}  // namespace paddle