grpc_client.cc 11.9 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
G
gongweibao 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include "paddle/fluid/operators/distributed/grpc_client.h"
Y
Yi Wang 已提交
16

17
#include <sys/time.h>
Y
Yi Wang 已提交
18 19 20

#include <limits>

G
gongweibao 已提交
21
#include "glog/logging.h"  // For VLOG
Y
Yi Wang 已提交
22
#include "paddle/fluid/framework/threadpool.h"
23
#include "paddle/fluid/operators/distributed/grpc_serde.h"
24
#include "paddle/fluid/operators/distributed/request_handler.h"
X
Xin Pan 已提交
25
#include "paddle/fluid/platform/profiler.h"
26

G
gongweibao 已提交
27 28
namespace paddle {
namespace operators {
29
namespace distributed {
G
gongweibao 已提交
30

G
gongweibao 已提交
31
// Client-specific initialization: the only step is starting the event loop
// thread that services the completion queue.
void GRPCClient::InitImpl() {
  InitEventLoop();
}
Y
Yancey1989 已提交
32

G
gongweibao 已提交
33
void GRPCClient::InitEventLoop() {
W
Wu Yi 已提交
34 35
  // start the client process thread
  // TODO(wuyi): can make this in a threadpool
G
gongweibao 已提交
36
  client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
W
Wu Yi 已提交
37 38
}

Y
Yancey1989 已提交
39
void GRPCClient::SendComplete() {
Y
Yancey1989 已提交
40 41 42 43 44 45 46 47
  std::unique_lock<std::mutex> lk(completed_mutex_);
  if (!completed_) {
    for (auto& it : channels_) {
      VLOG(3) << "send complete message to " << it.first;
      this->AsyncSendComplete(it.first);
    }
    PADDLE_ENFORCE(this->Wait(), "internal grpc error");
    completed_ = true;
W
Wu Yi 已提交
48 49 50
  }
}

G
gongweibao 已提交
51
// Waits for in-flight RPCs, stops the event loop, shuts down the completion
// queue, releases all cached channels and joins the polling thread.
GRPCClient::~GRPCClient() {
  // Drain outstanding requests BEFORE raising stopped_. Proceed() leaves its
  // loop as soon as it observes stopped_. If the flag were set first, with
  // two or more requests in flight only one completion would be processed,
  // req_count_ would never reach 0, and Wait() would block forever.
  Wait();
  stopped_ = true;
  cq_.Shutdown();  // Wakes Proceed() if it is blocked in cq_.Next().
  {
    std::lock_guard<std::mutex> guard(chan_mutex_);
    channels_.clear();  // Releases every cached channel.
  }
  // Guard against InitEventLoop() never having run.
  if (client_thread_ != nullptr && client_thread_->joinable()) {
    client_thread_->join();
  }
}

65 66 67 68 69
// Asynchronously sends variable `var_name` (looked up in `scope`) to
// endpoint `ep` through the SendVariable RPC.
//
// Serialization and the RPC start run on a framework::AsyncIO task. The
// completion is handled by Proceed(), which also deletes the processor.
// `time_out` is forwarded to SendProcessor::Prepare (presumably the RPC
// deadline — TODO confirm). Returns the handle tracking this request. When
// profiling is enabled, the task blocks on the handle before finishing.
VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
                                      const platform::DeviceContext& ctx,
                                      const framework::Scope& scope,
                                      const std::string& var_name,
                                      int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string var_name_val = var_name;
  const framework::Scope* p_scope = &scope;
  const auto ch = GetChannel(ep_val);
  SendProcessor* s = new SendProcessor(ch);
  const std::string method = "SendRPC";
  VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
  s->Prepare(h, time_out);

  // Count the request BEFORE it can complete. If the increment happened after
  // dispatch, Proceed() could decrement first (count -1, notify consumed) and
  // Wait() would then miss the count returning to 0. The update is done under
  // sync_mutex_, the same lock Proceed()/Wait() use for this counter.
  {
    std::lock_guard<std::mutex> lk(sync_mutex_);
    req_count_++;
  }

  framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
    auto* var = p_scope->FindVar(var_name_val);

    ::grpc::ByteBuffer req;
    SerializeToByteBuffer(var_name_val, var, *p_ctx, &req);

    VLOG(3) << s->GetVarHandlePtr()->String() << " begin";

    // stub context
    s->response_call_back_ = nullptr;

    platform::RecordEvent record_event(method, p_ctx);

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
    call->StartCall();
    call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));

    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
  });

  return h;
}

void ProcGetResponse(const VarHandle& var_h,
108
                     const ::grpc::ByteBuffer& ret_msg) {
T
typhoonzero 已提交
109
  framework::Variable* outvar = nullptr;
110
  DeserializeFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar);
111 112 113 114 115
}

template <typename T>
void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
  ::grpc::Slice slice(proto.ByteSizeLong());
Q
qiaolongfei 已提交
116
  proto.SerializeWithCachedSizesToArray(const_cast<uint8_t*>(slice.begin()));
117 118
  ::grpc::ByteBuffer tmp(&slice, 1);
  result->Swap(&tmp);
G
gongweibao 已提交
119 120
}

121 122 123 124 125
// Asynchronously requests variable `var_name` from endpoint `ep` through the
// GetVariable RPC.
//
// The response is handed to ProcGetResponse, which deserializes it using the
// handle's context and scope. The completion is dispatched by Proceed(),
// which also deletes the processor. `time_out` is forwarded to
// GetProcessor::Prepare (presumably the RPC deadline — TODO confirm).
// Returns the handle tracking this request.
VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
                                     const platform::DeviceContext& ctx,
                                     const framework::Scope& scope,
                                     const std::string& var_name,
                                     int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string var_name_val = var_name;
  const framework::Scope* p_scope = &scope;
  const auto ch = GetChannel(ep_val);
  GetProcessor* s = new GetProcessor(ch);
  const std::string method = "GetRPC";
  VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
  s->Prepare(h, time_out);

  // Count the request before dispatch so Proceed() can never decrement it
  // first and leave Wait() blocked (lost wakeup). Guarded by sync_mutex_ like
  // the decrement in Proceed().
  {
    std::lock_guard<std::mutex> lk(sync_mutex_);
    req_count_++;
  }

  framework::AsyncIO([var_name_val, s, method, p_ctx, h, this] {
    // prepare input
    sendrecv::VariableMessage req;
    req.set_varname(var_name_val);
    ::grpc::ByteBuffer buf;
    RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);

    VLOG(3) << s->GetVarHandlePtr()->String() << " begin";

    // stub context
    s->response_call_back_ = ProcGetResponse;

    platform::RecordEvent record_event(method, p_ctx);

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_);
    call->StartCall();
    call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));

    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
  });

  return h;
}

165 166 167 168 169 170
// Asynchronously prefetches from endpoint `ep` through the PrefetchVariable
// RPC. The request serializes `in_var_name` (looked up in `scope`) and names
// `out_var_name` as the output variable. The response is deserialized by
// ProcGetResponse. The completion is dispatched by Proceed(), which also
// deletes the processor. `time_out` is forwarded to GetProcessor::Prepare
// (presumably the RPC deadline — TODO confirm). Returns the handle, keyed on
// the output variable name.
VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
                                          const platform::DeviceContext& ctx,
                                          const framework::Scope& scope,
                                          const std::string& in_var_name,
                                          const std::string& out_var_name,
                                          int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string in_var_name_val = in_var_name;
  const std::string out_var_name_val = out_var_name;
  const framework::Scope* p_scope = &scope;
  const auto ch = GetChannel(ep_val);
  GetProcessor* s = new GetProcessor(ch);

  const std::string method = "PrefetchRPC";

  VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
  s->Prepare(h, time_out);

  // Count the request before dispatch so Proceed() can never decrement it
  // first and leave Wait() blocked (lost wakeup). Guarded by sync_mutex_ like
  // the decrement in Proceed().
  {
    std::lock_guard<std::mutex> lk(sync_mutex_);
    req_count_++;
  }

  framework::AsyncIO([in_var_name_val, out_var_name_val, p_scope, p_ctx, s,
                      method, h, this] {
    auto* var = p_scope->FindVar(in_var_name_val);

    ::grpc::ByteBuffer req;
    SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val);

    VLOG(3) << s->GetVarHandlePtr()->String() << " begin";

    // stub context
    s->response_call_back_ = ProcGetResponse;

    platform::RecordEvent record_event(method, p_ctx);

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
        &cq_);
    call->StartCall();
    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

    if (UNLIKELY(platform::IsProfileEnabled())) {
      h->Wait();
    }
  });

  return h;
}

213 214
// Sends a BATCH_BARRIER_MESSAGE to endpoint `ep` through the SendVariable
// RPC. The completion is dispatched by Proceed(), which also deletes the
// processor. `time_out` is forwarded to BatchBarrierProcessor::Prepare.
// Returns the handle tracking this request.
VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
                                               int64_t time_out) {
  const auto ch = GetChannel(ep);

  BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
  const std::string method = "BatchBarrierRPC";
  VarHandlePtr h(
      new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(BATCH_BARRIER_MESSAGE);

  platform::RecordEvent record_event(method, nullptr);

  // Count the request before Finish() registers the tag. Otherwise Proceed()
  // could handle the completion and decrement first, and Wait() would miss
  // the wakeup.
  {
    std::lock_guard<std::mutex> lk(sync_mutex_);
    req_count_++;
  }
  auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}
Y
Yancey 已提交
238

239 240
// Sends a FETCH_BARRIER_MESSAGE to endpoint `ep` through the GetVariable
// RPC. The completion is dispatched by Proceed(), which also deletes the
// processor. `time_out` is forwarded to FetchBarrierProcessor::Prepare.
// Returns the handle tracking this request.
VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
                                               int64_t time_out) {
  const auto ch = GetChannel(ep);
  FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
  const std::string method = "FetchBarrierRPC";
  VarHandlePtr h(
      new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(FETCH_BARRIER_MESSAGE);

  platform::RecordEvent record_event(method, nullptr);

  // Count the request before Finish() registers the tag. Otherwise Proceed()
  // could handle the completion and decrement first, and Wait() would miss
  // the wakeup.
  {
    std::lock_guard<std::mutex> lk(sync_mutex_);
    req_count_++;
  }
  auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

264 265
// Sends a COMPLETE_MESSAGE to endpoint `ep` through the SendVariable RPC,
// using a BatchBarrierProcessor. The completion is dispatched by Proceed(),
// which also deletes the processor. `time_out` is forwarded to Prepare().
// Returns the handle tracking this request.
VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
                                           int64_t time_out) {
  const auto ch = GetChannel(ep);

  BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
  const std::string method = "SendCompleteRPC";
  VarHandlePtr h(new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(COMPLETE_MESSAGE);

  platform::RecordEvent record_event(method, nullptr);

  // Count the request before Finish() registers the tag. Otherwise Proceed()
  // could handle the completion and decrement first, and Wait() would miss
  // the wakeup.
  {
    std::lock_guard<std::mutex> lk(sync_mutex_);
    req_count_++;
  }
  auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

289 290 291
// Sends a CHECKPOINT_SAVE_MESSAGE carrying directory `dir` (in the message's
// out_varname field) to endpoint `ep` through the CheckpointNotify RPC. The
// completion is dispatched by Proceed(), which also deletes the processor.
// `time_out` is forwarded to Prepare(). Returns the handle tracking this
// request.
VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
                                               const std::string& dir,
                                               int64_t time_out) {
  const auto ch = GetChannel(ep);

  CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);

  const std::string method = "CheckPointNotifyRPC";

  VarHandlePtr h(
      new VarHandle(ep, method, CHECKPOINT_SAVE_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(CHECKPOINT_SAVE_MESSAGE);
  req.set_out_varname(dir);

  platform::RecordEvent record_event(method, nullptr);

  // Count the request before Finish() registers the tag. Otherwise Proceed()
  // could handle the completion and decrement first, and Wait() would miss
  // the wakeup.
  {
    std::lock_guard<std::mutex> lk(sync_mutex_);
    req_count_++;
  }
  auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_);
  rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));

  if (UNLIKELY(platform::IsProfileEnabled())) {
    h->Wait();
  }

  return h;
}

Y
Yancey1989 已提交
319
bool GRPCClient::Wait() {
W
Wu Yi 已提交
320
  std::unique_lock<std::mutex> lk(sync_mutex_);
Y
Yancey1989 已提交
321 322
  sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); });
  return ok_;
G
gongweibao 已提交
323 324
}

G
gongweibao 已提交
325
void GRPCClient::Proceed() {
W
Wu Yi 已提交
326
  void* tag = nullptr;
G
gongweibao 已提交
327 328
  bool ok = false;

329
  VLOG(3) << "GRPCClient Proceed begin";
M
minqiyang 已提交
330
  while (!stopped_ && cq_.Next(&tag, &ok)) {
W
Wu Yi 已提交
331 332 333
    BaseProcessor* c = static_cast<BaseProcessor*>(tag);
    GPR_ASSERT(ok);
    PADDLE_ENFORCE(c);
G
gongweibao 已提交
334

W
Wu Yi 已提交
335
    if (c->status_.ok()) {
336
      VLOG(3) << c->GetVarHandlePtr()->String() << " process";
W
Wu Yi 已提交
337
      c->Process();
Y
Yancey1989 已提交
338
    } else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) {
339
      LOG(ERROR) << c->GetVarHandlePtr()->String()
Y
Yancey1989 已提交
340 341 342 343 344
                 << " meets grpc error:" << c->status_.error_message();
      {
        std::lock_guard<std::mutex> lk(sync_mutex_);
        ok_ = false;
      }
345
      c->Finish(false);
W
Wu Yi 已提交
346
    } else {
347
      LOG(FATAL) << c->GetVarHandlePtr()->String()
G
gongweibao 已提交
348
                 << " meets grpc error:" << c->status_.error_message();
349
      c->Finish(false);
W
Wu Yi 已提交
350
    }
351

G
gongweibao 已提交
352
    bool notify = false;
W
Wu Yi 已提交
353 354 355
    {
      std::lock_guard<std::mutex> lk(sync_mutex_);
      req_count_--;
G
gongweibao 已提交
356 357 358 359 360 361 362
      notify = (req_count_ <= 0 || !c->status_.ok());
    }

    delete c;

    if (notify) {
      sync_cond_.notify_all();
W
Wu Yi 已提交
363
    }
G
gongweibao 已提交
364
  }
365
  VLOG(3) << "GRPCClient Proceed end";
G
gongweibao 已提交
366
}
W
Wu Yi 已提交
367

G
gongweibao 已提交
368
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
W
Wu Yi 已提交
369
  std::lock_guard<std::mutex> guard(chan_mutex_);
Y
Yancey1989 已提交
370
  auto it = channels_.find(ep);
G
gongweibao 已提交
371 372 373 374
  if (it != channels_.end()) {
    return it->second;
  }

W
Wu Yi 已提交
375
  // Channel configurations:
G
gongweibao 已提交
376
  grpc::ChannelArguments args;
W
Wu Yi 已提交
377
  args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 2000);
378
  args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
G
gongweibao 已提交
379 380 381
  args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());

T
typhoonzero 已提交
382 383
  auto ch =
      grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
Y
Yancey1989 已提交
384
  channels_[ep] = ch;
G
gongweibao 已提交
385 386 387
  return ch;
}

388
}  // namespace distributed
G
gongweibao 已提交
389 390
}  // namespace operators
}  // namespace paddle