grpc_client.cc 10.8 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
G
gongweibao 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include "paddle/fluid/operators/distributed/grpc_client.h"
Y
Yi Wang 已提交
16

17
#include <sys/time.h>
Y
Yi Wang 已提交
18 19 20

#include <limits>

G
gongweibao 已提交
21
#include "glog/logging.h"  // For VLOG
Y
Yi Wang 已提交
22
#include "paddle/fluid/framework/threadpool.h"
23
#include "paddle/fluid/operators/distributed/grpc_serde.h"
24
#include "paddle/fluid/operators/distributed/request_handler.h"
X
Xin Pan 已提交
25
#include "paddle/fluid/platform/profiler.h"
26

G
gongweibao 已提交
27 28
namespace paddle {
namespace operators {
29
namespace distributed {
G
gongweibao 已提交
30

G
gongweibao 已提交
31
void GRPCClient::InitImpl() { InitEventLoop(); }
Y
Yancey1989 已提交
32

G
gongweibao 已提交
33
void GRPCClient::InitEventLoop() {
W
Wu Yi 已提交
34 35
  // start the client process thread
  // TODO(wuyi): can make this in a threadpool
G
gongweibao 已提交
36
  client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
W
Wu Yi 已提交
37 38
}

Y
Yancey1989 已提交
39
void GRPCClient::SendComplete() {
Y
Yancey1989 已提交
40 41 42 43 44 45 46 47
  std::unique_lock<std::mutex> lk(completed_mutex_);
  if (!completed_) {
    for (auto& it : channels_) {
      VLOG(3) << "send complete message to " << it.first;
      this->AsyncSendComplete(it.first);
    }
    PADDLE_ENFORCE(this->Wait(), "internal grpc error");
    completed_ = true;
W
Wu Yi 已提交
48 49 50
  }
}

G
gongweibao 已提交
51
// Tears the client down: waits for in-flight requests, stops the event loop,
// shuts down the completion queue, releases all cached channels and joins the
// event-loop thread.
GRPCClient::~GRPCClient() {
  // Wait BEFORE raising stopped_. Proceed() re-checks stopped_ after every
  // completion, so raising the flag first lets the loop exit after handling a
  // single event while other requests are still in flight; req_count_ would
  // then never reach zero and this Wait() would block forever.
  Wait();
  stopped_ = true;
  // Shutdown makes a blocked cq_.Next() in Proceed() return false, which ends
  // the loop even if it never observes stopped_.
  cq_.Shutdown();
  {
    std::lock_guard<std::mutex> guard(chan_mutex_);
    // Clearing the map drops every channel reference held by this client.
    channels_.clear();
  }
  // The thread only exists if InitEventLoop() ran; guard against a client
  // that is destroyed without having been initialized.
  if (client_thread_ && client_thread_->joinable()) {
    client_thread_->join();
  }
}

65 66 67 68 69
// Asynchronously sends variable `var_name` (looked up in `scope`) to endpoint
// `ep` through the "/sendrecv.SendRecvService/SendVariable" method.
//
// ep:       endpoint whose channel is obtained from GetChannel().
// ctx:      device context passed to SerializeToByteBuffer(); captured by
//           pointer, so it must stay alive until the async task has run.
// scope:    scope searched for the variable; also captured by pointer.
// var_name: name of the variable to send (copied into the async task).
// time_out: forwarded unchanged to SendProcessor::Prepare().
//
// Returns the VarHandle associated with this request. The SendProcessor
// allocated here is used as the completion-queue tag and is deleted by
// Proceed() once the completion has been handled.
VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
                                      const platform::DeviceContext& ctx,
                                      const framework::Scope& scope,
                                      const std::string& var_name,
                                      int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const framework::Scope* p_scope = &scope;
  const std::string var_name_val = var_name;
  const auto ch = GetChannel(ep);
  SendProcessor* s = new SendProcessor(ch);
  VarHandlePtr h(new VarHandle(ep, "Send", var_name_val, p_ctx, p_scope));
  s->Prepare(h, time_out);

  // Count the request before it is dispatched: the async task may start the
  // call and Proceed() may consume its completion (decrementing req_count_)
  // before this thread gets to run again. Incrementing afterwards would let
  // the counter dip below zero and a concurrent Wait() miss its wake-up.
  req_count_++;

  framework::AsyncIO([var_name_val, p_scope, p_ctx, s, this] {
    auto* var = p_scope->FindVar(var_name_val);

    ::grpc::ByteBuffer req;
    SerializeToByteBuffer(var_name_val, var, *p_ctx, &req);

    VLOG(3) << s->GetVarHandlePtr()->String() << " begin";

    // stub context: a send has no response payload to post-process
    s->response_call_back_ = nullptr;

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
    call->StartCall();
    // The processor pointer is the tag that Proceed() casts back.
    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
  });

  return h;
}

void ProcGetResponse(const VarHandle& var_h,
101
                     const ::grpc::ByteBuffer& ret_msg) {
T
typhoonzero 已提交
102
  framework::Variable* outvar = nullptr;
103
  DeserializeFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar);
104 105 106 107 108
}

template <typename T>
void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
  ::grpc::Slice slice(proto.ByteSizeLong());
Q
qiaolongfei 已提交
109
  proto.SerializeWithCachedSizesToArray(const_cast<uint8_t*>(slice.begin()));
110 111
  ::grpc::ByteBuffer tmp(&slice, 1);
  result->Swap(&tmp);
G
gongweibao 已提交
112 113
}

114 115 116 117 118
// Asynchronously requests variable `var_name` from endpoint `ep` through the
// "/sendrecv.SendRecvService/GetVariable" method. The response is handled by
// ProcGetResponse, which deserializes it using the handle's context and scope.
//
// ep:       endpoint whose channel is obtained from GetChannel().
// ctx:      device context stored (by pointer) in the returned VarHandle.
// scope:    scope stored (by pointer) in the returned VarHandle.
// var_name: name placed in the request's `varname` field.
// time_out: forwarded unchanged to GetProcessor::Prepare().
//
// Returns the VarHandle associated with this request. The GetProcessor
// allocated here is used as the completion-queue tag and is deleted by
// Proceed() once the completion has been handled.
VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
                                     const platform::DeviceContext& ctx,
                                     const framework::Scope& scope,
                                     const std::string& var_name,
                                     int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const framework::Scope* p_scope = &scope;
  const std::string var_name_val = var_name;
  const auto ch = GetChannel(ep);
  GetProcessor* s = new GetProcessor(ch);
  VarHandlePtr h(new VarHandle(ep, "Get", var_name_val, p_ctx, p_scope));
  s->Prepare(h, time_out);

  // Count the request before it is dispatched: the async task may start the
  // call and Proceed() may consume its completion (decrementing req_count_)
  // before this thread gets to run again. Incrementing afterwards would let
  // the counter dip below zero and a concurrent Wait() miss its wake-up.
  req_count_++;

  // Only the name, the processor and `this` (for cq_) are needed in the task.
  framework::AsyncIO([var_name_val, s, this] {
    // prepare input
    sendrecv::VariableMessage req;
    req.set_varname(var_name_val);
    ::grpc::ByteBuffer buf;
    RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);

    VLOG(3) << s->GetVarHandlePtr()->String() << " begin";

    // stub context
    s->response_call_back_ = ProcGetResponse;

    auto call = s->stub_g_.PrepareUnaryCall(
        s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_);
    call->StartCall();
    // The processor pointer is the tag that Proceed() casts back.
    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
  });

  return h;
}

151 152 153 154 155 156
// Asynchronously sends variable `in_var_name` (looked up in `scope`) to
// endpoint `ep` through the "/sendrecv.SendRecvService/PrefetchVariable"
// method, passing `out_var_name` along to the serializer. The response is
// handled by ProcGetResponse.
//
// ep:           endpoint whose channel is obtained from GetChannel().
// ctx:          device context used for serialization; captured by pointer,
//               so it must stay alive until the async task has run.
// scope:        scope searched for `in_var_name`; also captured by pointer.
// in_var_name:  name of the variable serialized into the request.
// out_var_name: name recorded in the VarHandle and passed to the serializer.
// time_out:     forwarded unchanged to GetProcessor::Prepare().
//
// Returns the VarHandle associated with this request. The GetProcessor
// allocated here is used as the completion-queue tag and is deleted by
// Proceed() once the completion has been handled.
VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
                                          const platform::DeviceContext& ctx,
                                          const framework::Scope& scope,
                                          const std::string& in_var_name,
                                          const std::string& out_var_name,
                                          int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const framework::Scope* p_scope = &scope;
  const std::string in_var_name_val = in_var_name;
  const std::string out_var_name_val = out_var_name;
  const auto ch = GetChannel(ep);
  GetProcessor* s = new GetProcessor(ch);
  VarHandlePtr h(
      new VarHandle(ep, "Prefetch", out_var_name_val, p_ctx, p_scope));
  s->Prepare(h, time_out);

  // Count the request before it is dispatched: the async task may start the
  // call and Proceed() may consume its completion (decrementing req_count_)
  // before this thread gets to run again. Incrementing afterwards would let
  // the counter dip below zero and a concurrent Wait() miss its wake-up.
  req_count_++;

  framework::AsyncIO(
      [in_var_name_val, out_var_name_val, p_scope, p_ctx, s, this] {
        auto* var = p_scope->FindVar(in_var_name_val);

        ::grpc::ByteBuffer req;
        SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req,
                              out_var_name_val);

        VLOG(3) << s->GetVarHandlePtr()->String() << " begin";

        // stub context
        s->response_call_back_ = ProcGetResponse;

        auto call = s->stub_g_.PrepareUnaryCall(
            s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable",
            req, &cq_);
        call->StartCall();
        // The processor pointer is the tag that Proceed() casts back.
        call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
      });

  return h;
}

191 192
// Sends a batch-barrier notification to endpoint `ep`: a SendVariable request
// whose `varname` is BATCH_BARRIER_MESSAGE.
//
// ep:       endpoint whose channel is obtained from GetChannel().
// time_out: forwarded unchanged to BatchBarrierProcessor::Prepare().
//
// Returns the VarHandle associated with this request (it carries no device
// context or scope). The processor allocated here is used as the
// completion-queue tag and is deleted by Proceed().
VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
                                               int64_t time_out) {
  const auto ch = GetChannel(ep);

  BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
  VarHandlePtr h(new VarHandle(ep, "BatchBarrier", BATCH_BARRIER_MESSAGE,
                               nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(BATCH_BARRIER_MESSAGE);

  // Count the request before the RPC is issued: once Finish() has registered
  // the tag, Proceed() may consume the completion (decrementing req_count_)
  // at any moment, so incrementing afterwards could let the counter dip
  // below zero and a concurrent Wait() miss its wake-up.
  req_count_++;

  auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
  // The processor pointer is the tag that Proceed() casts back.
  rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
  return h;
}
Y
Yancey 已提交
207

208 209
// Sends a fetch-barrier notification to endpoint `ep`: a GetVariable request
// whose `varname` is FETCH_BARRIER_MESSAGE.
//
// ep:       endpoint whose channel is obtained from GetChannel().
// time_out: forwarded unchanged to FetchBarrierProcessor::Prepare().
//
// Returns the VarHandle associated with this request (it carries no device
// context or scope). The processor allocated here is used as the
// completion-queue tag and is deleted by Proceed().
VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
                                               int64_t time_out) {
  const auto ch = GetChannel(ep);
  FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
  VarHandlePtr h(new VarHandle(ep, "FetchBarrier", FETCH_BARRIER_MESSAGE,
                               nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(FETCH_BARRIER_MESSAGE);

  // Count the request before the RPC is issued: once Finish() has registered
  // the tag, Proceed() may consume the completion (decrementing req_count_)
  // at any moment, so incrementing afterwards could let the counter dip
  // below zero and a concurrent Wait() miss its wake-up.
  req_count_++;

  auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
  // The processor pointer is the tag that Proceed() casts back.
  rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
  return h;
}

224 225
// Sends a completion notification to endpoint `ep`: a SendVariable request
// whose `varname` is COMPLETE_MESSAGE.
//
// ep:       endpoint whose channel is obtained from GetChannel().
// time_out: forwarded unchanged to BatchBarrierProcessor::Prepare().
//
// Returns the VarHandle associated with this request (it carries no device
// context or scope). The processor allocated here is used as the
// completion-queue tag and is deleted by Proceed().
VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
                                           int64_t time_out) {
  const auto ch = GetChannel(ep);

  BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
  VarHandlePtr h(
      new VarHandle(ep, "SendComplete", COMPLETE_MESSAGE, nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(COMPLETE_MESSAGE);

  // Count the request before the RPC is issued: once Finish() has registered
  // the tag, Proceed() may consume the completion (decrementing req_count_)
  // at any moment, so incrementing afterwards could let the counter dip
  // below zero and a concurrent Wait() miss its wake-up.
  req_count_++;

  auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
  // The processor pointer is the tag that Proceed() casts back.
  rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
  return h;
}

241 242 243
// Sends a checkpoint notification to endpoint `ep`: a CheckpointNotify
// request whose `varname` is CHECKPOINT_SAVE_MESSAGE and whose `out_varname`
// carries `dir`.
//
// ep:       endpoint whose channel is obtained from GetChannel().
// dir:      value placed in the request's `out_varname` field.
// time_out: forwarded unchanged to CheckpointNotifyProcessor::Prepare().
//
// Returns the VarHandle associated with this request (it carries no device
// context or scope). The processor allocated here is used as the
// completion-queue tag and is deleted by Proceed().
VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
                                               const std::string& dir,
                                               int64_t time_out) {
  const auto ch = GetChannel(ep);

  CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);
  VarHandlePtr h(new VarHandle(ep, "CheckPointNotify", CHECKPOINT_SAVE_MESSAGE,
                               nullptr, nullptr));
  s->Prepare(h, time_out);

  sendrecv::VariableMessage req;
  req.set_varname(CHECKPOINT_SAVE_MESSAGE);
  req.set_out_varname(dir);

  // Count the request before the RPC is issued: once Finish() has registered
  // the tag, Proceed() may consume the completion (decrementing req_count_)
  // at any moment, so incrementing afterwards could let the counter dip
  // below zero and a concurrent Wait() miss its wake-up.
  req_count_++;

  auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_);
  // The processor pointer is the tag that Proceed() casts back.
  rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
  return h;
}

Y
Yancey1989 已提交
261
bool GRPCClient::Wait() {
W
Wu Yi 已提交
262
  std::unique_lock<std::mutex> lk(sync_mutex_);
Y
Yancey1989 已提交
263 264
  sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); });
  return ok_;
G
gongweibao 已提交
265 266
}

G
gongweibao 已提交
267
void GRPCClient::Proceed() {
W
Wu Yi 已提交
268
  void* tag = nullptr;
G
gongweibao 已提交
269 270
  bool ok = false;

271
  VLOG(3) << "GRPCClient Proceed begin";
M
minqiyang 已提交
272
  while (!stopped_ && cq_.Next(&tag, &ok)) {
W
Wu Yi 已提交
273 274 275 276
    BaseProcessor* c = static_cast<BaseProcessor*>(tag);
    GPR_ASSERT(ok);
    PADDLE_ENFORCE(c);
    if (c->status_.ok()) {
277
      VLOG(3) << c->GetVarHandlePtr()->String() << " process";
W
Wu Yi 已提交
278
      c->Process();
Y
Yancey1989 已提交
279
    } else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) {
280
      LOG(ERROR) << c->GetVarHandlePtr()->String()
Y
Yancey1989 已提交
281 282 283 284 285
                 << " meets grpc error:" << c->status_.error_message();
      {
        std::lock_guard<std::mutex> lk(sync_mutex_);
        ok_ = false;
      }
286
      c->Finish(false);
W
Wu Yi 已提交
287
    } else {
288
      LOG(FATAL) << c->GetVarHandlePtr()->String()
G
gongweibao 已提交
289
                 << " meets grpc error:" << c->status_.error_message();
290
      c->Finish(false);
W
Wu Yi 已提交
291
    }
292

G
gongweibao 已提交
293
    bool notify = false;
W
Wu Yi 已提交
294 295 296
    {
      std::lock_guard<std::mutex> lk(sync_mutex_);
      req_count_--;
G
gongweibao 已提交
297 298 299 300 301 302 303
      notify = (req_count_ <= 0 || !c->status_.ok());
    }

    delete c;

    if (notify) {
      sync_cond_.notify_all();
W
Wu Yi 已提交
304
    }
G
gongweibao 已提交
305
  }
306
  VLOG(3) << "GRPCClient Proceed end";
G
gongweibao 已提交
307
}
W
Wu Yi 已提交
308

G
gongweibao 已提交
309
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
W
Wu Yi 已提交
310
  std::lock_guard<std::mutex> guard(chan_mutex_);
Y
Yancey1989 已提交
311
  auto it = channels_.find(ep);
G
gongweibao 已提交
312 313 314 315
  if (it != channels_.end()) {
    return it->second;
  }

W
Wu Yi 已提交
316
  // Channel configurations:
G
gongweibao 已提交
317
  grpc::ChannelArguments args;
W
Wu Yi 已提交
318
  args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 2000);
319
  args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
G
gongweibao 已提交
320 321 322
  args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
  args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());

T
typhoonzero 已提交
323 324
  auto ch =
      grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
Y
Yancey1989 已提交
325
  channels_[ep] = ch;
G
gongweibao 已提交
326 327 328
  return ch;
}

329
}  // namespace distributed
G
gongweibao 已提交
330 331
}  // namespace operators
}  // namespace paddle