/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include <sys/time.h>

#include <limits>
#include <string>
#include <vector>

#include "glog/logging.h"  // For VLOG
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/grpc_client.h"
#include "paddle/fluid/operators/distributed/grpc_serde.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/platform/profiler.h"
24

G
gongweibao 已提交
25 26
namespace paddle {
namespace operators {
27
namespace distributed {
G
gongweibao 已提交
28

G
gongweibao 已提交
29
// Client initialization hook. The only work it does is start the
// completion-queue event loop.
void GRPCClient::InitImpl() {
  InitEventLoop();
}
Y
Yancey1989 已提交
30

G
gongweibao 已提交
31
void GRPCClient::InitEventLoop() {
W
Wu Yi 已提交
32 33
  // start the client process thread
  // TODO(wuyi): can make this in a threadpool
G
gongweibao 已提交
34
  client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
W
Wu Yi 已提交
35 36
}

Y
Yancey1989 已提交
37
void GRPCClient::SendComplete() {
Y
Yancey1989 已提交
38 39 40 41 42 43 44 45
  std::unique_lock<std::mutex> lk(completed_mutex_);
  if (!completed_) {
    for (auto& it : channels_) {
      VLOG(3) << "send complete message to " << it.first;
      this->AsyncSendComplete(it.first);
    }
    PADDLE_ENFORCE(this->Wait(), "internal grpc error");
    completed_ = true;
W
Wu Yi 已提交
46 47 48
  }
}

G
gongweibao 已提交
49
// Tears the client down in the following order:
// 1. Mark the client stopped.
// 2. Block in Wait() until no requests are outstanding or an error was
//    flagged.
// 3. Shut down the completion queue.
// 4. Release every cached channel.
// 5. Join the event-loop thread.
GRPCClient::~GRPCClient() {
  stopped_ = true;
  Wait();
  cq_.Shutdown();

  {
    std::lock_guard<std::mutex> chan_guard(chan_mutex_);
    for (auto& entry : channels_) {
      entry.second.reset();  // drop this client's reference to the channel
    }
    channels_.clear();
  }

  client_thread_->join();
}

63 64 65 66 67
// Asynchronously serializes `var_name` from `scope` and sends it to `ep`
// through the generic stub method /sendrecv.SendRecvService/SendVariable.
// Serialization and call start-up run on the framework::AsyncIO pool.
// Completion is dequeued by Proceed(). Returns the VarHandle tracking this
// request. When profiling is enabled, the async task blocks on that handle
// after starting the call.
VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
                                      const platform::DeviceContext& ctx,
                                      const framework::Scope& scope,
                                      const std::string& var_name,
                                      int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string var_name_val = var_name;
  const framework::Scope* p_scope = &scope;
  const auto ch = GetChannel(ep_val);
  SendProcessor* s = new SendProcessor(ch);
  const std::string method = "SendRPC";
  VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
  s->Prepare(h, time_out);

  // Count the request before the call can possibly complete. Proceed()
  // decrements req_count_ on another thread as soon as the completion is
  // dequeued. Incrementing after AsyncIO() lets the counter go negative, and
  // a concurrent Wait() (which waits for exactly 0) can then miss its
  // wake-up.
  req_count_++;

  framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
    auto* var = p_scope->FindVar(var_name_val);

    ::grpc::ByteBuffer req;
    SerializeToByteBuffer(var_name_val, var, *p_ctx, &req, "", trainer_id_);

    VLOG(3) << s->GetVarHandlePtr()->String() << " begin";

    // Sends install no response callback.
    s->response_call_back_ = nullptr;

    platform::RecordRPCEvent record_event(method, p_ctx);

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
    call->StartCall();
    // `s` is the completion tag. Proceed() casts it back and deletes it.
    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
  });

  return h;
}

void ProcGetResponse(const VarHandle& var_h,
106
                     const ::grpc::ByteBuffer& ret_msg) {
T
typhoonzero 已提交
107
  framework::Variable* outvar = nullptr;
W
Wu Yi 已提交
108 109 110 111
  // get response's trainer_id is not used
  int trainer_id;
  DeserializeFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar,
                            &trainer_id);
112 113 114 115 116
}

template <typename T>
void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
  ::grpc::Slice slice(proto.ByteSizeLong());
Q
qiaolongfei 已提交
117
  proto.SerializeWithCachedSizesToArray(const_cast<uint8_t*>(slice.begin()));
118 119
  ::grpc::ByteBuffer tmp(&slice, 1);
  result->Swap(&tmp);
G
gongweibao 已提交
120 121
}

122 123 124 125 126
// Asynchronously requests variable `var_name` from `ep` through the generic
// stub method /sendrecv.SendRecvService/GetVariable. The request message is
// built and the call started on the framework::AsyncIO pool.
// ProcGetResponse is installed as the response callback; it deserializes a
// reply into the handle's scope. Returns the VarHandle tracking this request.
VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
                                     const platform::DeviceContext& ctx,
                                     const framework::Scope& scope,
                                     const std::string& var_name,
                                     int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string var_name_val = var_name;
  const framework::Scope* p_scope = &scope;
  const auto ch = GetChannel(ep_val);
  GetProcessor* s = new GetProcessor(ch);
  const std::string method = "GetRPC";
  VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
  s->Prepare(h, time_out);

  // Count the request before the call can possibly complete. Proceed()
  // decrements req_count_ on another thread as soon as the completion is
  // dequeued. Incrementing after AsyncIO() lets the counter go negative, and
  // a concurrent Wait() (which waits for exactly 0) can then miss its
  // wake-up.
  req_count_++;

  framework::AsyncIO([var_name_val, s, method, p_ctx, h, this] {
    // prepare input
    sendrecv::VariableMessage req;
    req.set_varname(var_name_val);
    req.set_trainer_id(trainer_id_);
    ::grpc::ByteBuffer buf;
    RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);

    VLOG(3) << s->GetVarHandlePtr()->String() << " begin";

    // The reply carries the variable payload.
    s->response_call_back_ = ProcGetResponse;

    platform::RecordRPCEvent record_event(method, p_ctx);

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_);
    call->StartCall();
    // `s` is the completion tag. Proceed() casts it back and deletes it.
    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
  });

  return h;
}

167 168 169 170 171 172
// Asynchronously sends `in_var_name` (looked up in `scope`) to `ep` through
// /sendrecv.SendRecvService/PrefetchVariable, naming `out_var_name` as the
// output variable. ProcGetResponse is installed as the response callback; it
// deserializes a reply into the handle's scope. Returns the VarHandle
// tracking this request, which is keyed on `out_var_name`.
VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
                                          const platform::DeviceContext& ctx,
                                          const framework::Scope& scope,
                                          const std::string& in_var_name,
                                          const std::string& out_var_name,
                                          int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string in_var_name_val = in_var_name;
  const std::string out_var_name_val = out_var_name;
  const framework::Scope* p_scope = &scope;
  const auto ch = GetChannel(ep_val);
  GetProcessor* s = new GetProcessor(ch);

  const std::string method = "PrefetchRPC";

  VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
  s->Prepare(h, time_out);

  // Count the request before the call can possibly complete. Proceed()
  // decrements req_count_ on another thread as soon as the completion is
  // dequeued. Incrementing after AsyncIO() lets the counter go negative, and
  // a concurrent Wait() (which waits for exactly 0) can then miss its
  // wake-up.
  req_count_++;

  framework::AsyncIO([in_var_name_val, out_var_name_val, p_scope, p_ctx, s,
                      method, h, this] {
    auto* var = p_scope->FindVar(in_var_name_val);

    ::grpc::ByteBuffer req;
    SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val);

    VLOG(3) << s->GetVarHandlePtr()->String() << " begin";

    // The reply carries the prefetched variable payload.
    s->response_call_back_ = ProcGetResponse;

    platform::RecordRPCEvent record_event(method, p_ctx);

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
        &cq_);
    call->StartCall();
    // `s` is the completion tag. Proceed() casts it back and deletes it.
    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
  });

  return h;
}

215 216
// Sends the BATCH_BARRIER_MESSAGE marker to `ep` via the generated
// AsyncSendVariable stub call. Returns the VarHandle tracking this request.
// When profiling is enabled, blocks on that handle before returning.
VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
                                               int64_t time_out) {
  const auto ch = GetChannel(ep);

  BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
  const std::string method = "BatchBarrierRPC";
  VarHandlePtr h(
      new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(BATCH_BARRIER_MESSAGE);

  platform::RecordRPCEvent record_event(method, nullptr);

  // Count the request before Finish() registers the completion tag.
  // Proceed() runs on another thread and may dequeue the completion, and
  // decrement req_count_, before this function would otherwise increment it.
  req_count_++;
  auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}
Y
Yancey 已提交
240

241 242
// Sends the FETCH_BARRIER_MESSAGE marker to `ep` via the generated
// AsyncGetVariable stub call. Returns the VarHandle tracking this request.
// When profiling is enabled, blocks on that handle before returning.
VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
                                               int64_t time_out) {
  const auto ch = GetChannel(ep);
  FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
  const std::string method = "FetchBarrierRPC";
  VarHandlePtr h(
      new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(FETCH_BARRIER_MESSAGE);

  platform::RecordRPCEvent record_event(method, nullptr);

  // Count the request before Finish() registers the completion tag.
  // Proceed() runs on another thread and may dequeue the completion, and
  // decrement req_count_, before this function would otherwise increment it.
  req_count_++;
  auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

266 267
// Sends the COMPLETE_MESSAGE marker to `ep` via the generated
// AsyncSendVariable stub call, using a BatchBarrierProcessor. Returns the
// VarHandle tracking this request. When profiling is enabled, blocks on that
// handle before returning.
VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
                                           int64_t time_out) {
  const auto ch = GetChannel(ep);

  BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
  const std::string method = "SendCompleteRPC";
  VarHandlePtr h(new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(COMPLETE_MESSAGE);

  platform::RecordRPCEvent record_event(method, nullptr);

  // Count the request before Finish() registers the completion tag.
  // Proceed() runs on another thread and may dequeue the completion, and
  // decrement req_count_, before this function would otherwise increment it.
  req_count_++;
  auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

291 292 293
// Sends the CHECKPOINT_SAVE_MESSAGE marker to `ep` via the generated
// AsyncCheckpointNotify stub call, passing `dir` as the message's
// out_varname. Returns the VarHandle tracking this request. When profiling
// is enabled, blocks on that handle before returning.
VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
                                               const std::string& dir,
                                               int64_t time_out) {
  const auto ch = GetChannel(ep);

  CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);

  const std::string method = "CheckPointNotifyRPC";

  VarHandlePtr h(
      new VarHandle(ep, method, CHECKPOINT_SAVE_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(CHECKPOINT_SAVE_MESSAGE);
  req.set_out_varname(dir);

  platform::RecordRPCEvent record_event(method, nullptr);

  // Count the request before Finish() registers the completion tag.
  // Proceed() runs on another thread and may dequeue the completion, and
  // decrement req_count_, before this function would otherwise increment it.
  req_count_++;
  auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

Y
Yancey1989 已提交
321
bool GRPCClient::Wait() {
W
Wu Yi 已提交
322
  std::unique_lock<std::mutex> lk(sync_mutex_);
Y
Yancey1989 已提交
323 324
  sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); });
  return ok_;
G
gongweibao 已提交
325 326
}

G
gongweibao 已提交
327
void GRPCClient::Proceed() {
W
Wu Yi 已提交
328
  void* tag = nullptr;
G
gongweibao 已提交
329 330
  bool ok = false;

331
  VLOG(3) << "GRPCClient Proceed begin";
M
minqiyang 已提交
332
  while (!stopped_ && cq_.Next(&tag, &ok)) {
W
Wu Yi 已提交
333 334 335
    BaseProcessor* c = static_cast<BaseProcessor*>(tag);
    GPR_ASSERT(ok);
    PADDLE_ENFORCE(c);
G
gongweibao 已提交
336

W
Wu Yi 已提交
337
    if (c->status_.ok()) {
338
      VLOG(3) << c->GetVarHandlePtr()->String() << " process";
W
Wu Yi 已提交
339
      c->Process();
Y
Yancey1989 已提交
340
    } else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) {
G
gongweibao 已提交
341
      // FIXME(gongwb): parse error_details?
342
      LOG(ERROR) << c->GetVarHandlePtr()->String()
G
gongweibao 已提交
343 344 345
                 << " meets grpc error, error_code:" << c->status_.error_code()
                 << " error_message:" << c->status_.error_message()
                 << " error_details:" << c->status_.error_details();
Y
Yancey1989 已提交
346 347 348 349
      {
        std::lock_guard<std::mutex> lk(sync_mutex_);
        ok_ = false;
      }
350
      c->Finish(false);
W
Wu Yi 已提交
351
    } else {
352
      LOG(FATAL) << c->GetVarHandlePtr()->String()
G
gongweibao 已提交
353 354 355 356
                 << " meets grpc error, error_code:" << c->status_.error_code()
                 << " error_message:" << c->status_.error_message()
                 << " error_details:" << c->status_.error_details();

357
      c->Finish(false);
W
Wu Yi 已提交
358
    }
359

G
gongweibao 已提交
360
    bool notify = false;
W
Wu Yi 已提交
361 362 363
    {
      std::lock_guard<std::mutex> lk(sync_mutex_);
      req_count_--;
G
gongweibao 已提交
364 365 366 367 368 369 370
      notify = (req_count_ <= 0 || !c->status_.ok());
    }

    delete c;

    if (notify) {
      sync_cond_.notify_all();
W
Wu Yi 已提交
371
    }
G
gongweibao 已提交
372
  }
373
  VLOG(3) << "GRPCClient Proceed end";
G
gongweibao 已提交
374
}
W
Wu Yi 已提交
375

G
gongweibao 已提交
376
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
W
Wu Yi 已提交
377
  std::lock_guard<std::mutex> guard(chan_mutex_);
Y
Yancey1989 已提交
378
  auto it = channels_.find(ep);
G
gongweibao 已提交
379 380 381 382
  if (it != channels_.end()) {
    return it->second;
  }

W
Wu Yi 已提交
383
  // Channel configurations:
G
gongweibao 已提交
384
  grpc::ChannelArguments args;
W
Wu Yi 已提交
385
  args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 2000);
386
  args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
G
gongweibao 已提交
387 388 389
  args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());

T
typhoonzero 已提交
390 391
  auto ch =
      grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
Y
Yancey1989 已提交
392
  channels_[ep] = ch;
G
gongweibao 已提交
393 394 395
  return ch;
}

396
}  // namespace distributed
G
gongweibao 已提交
397 398
}  // namespace operators
}  // namespace paddle