grpc_client.h 9.9 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
G
gongweibao 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <time.h>
X
Xin Pan 已提交
18
#include <atomic>
W
Wu Yi 已提交
19 20
#include <chrono>              // NOLINT
#include <condition_variable>  // NOLINT
G
gongweibao 已提交
21 22 23 24
#include <ctime>
#include <functional>
#include <iostream>
#include <map>
Q
Qiao Longfei 已提交
25
#include <memory>
Y
Yancey1989 已提交
26
#include <mutex>  // NOLINT
G
gongweibao 已提交
27
#include <string>
W
Wu Yi 已提交
28
#include <thread>  // NOLINT
Q
Qiao Longfei 已提交
29
#include <unordered_map>
G
gongweibao 已提交
30 31
#include <vector>

W
Wu Yi 已提交
32
#include "grpc++/channel.h"
Y
Yi Wang 已提交
33 34 35 36 37
#include "grpc++/generic/generic_stub.h"
#include "grpc++/grpc++.h"
#include "grpc++/support/byte_buffer.h"
#include "grpc++/support/slice.h"
#include "grpc/support/log.h"
T
typhoonzero 已提交
38
#include "paddle/fluid/framework/blocking_queue.h"
Y
Yi Wang 已提交
39 40 41 42
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
W
Wu Yi 已提交
43
#include "paddle/fluid/operators/distributed/distributed_pb.h"
44
#include "paddle/fluid/operators/distributed/request_handler.h"
45 46
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
Y
Yancey1989 已提交
47
#include "paddle/fluid/platform/macros.h"  // for DISABLE_COPY_AND_ASSIGN
G
gongweibao 已提交
48

W
wanghuancoder 已提交
49 50 51 52 53 54 55 56 57 58 59 60
// Forward declaration of the gRPC channel type. In this header it is only
// held through std::shared_ptr<grpc::Channel> (processor constructors and the
// GRPCClient channel cache).
namespace grpc {
class Channel;
}  // namespace grpc
// Forward declarations of the Paddle types that this header only mentions as
// const-reference parameters of the GRPCClient request methods.
namespace paddle {
namespace framework {
class Scope;
}  // namespace framework
namespace platform {
class DeviceContext;
}  // namespace platform
}  // namespace paddle

G
gongweibao 已提交
61 62
namespace paddle {
namespace operators {
63
namespace distributed {
G
gongweibao 已提交
64

65
// Response handler taking the request's VarHandle and the raw reply buffer.
// It is the default response_call_back_ of GetProcessor and
// SendAndRecvProcessor in this header. Only declared here.
void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);
G
gongweibao 已提交
66

67 68
// Response handler with the same signature as ProcGetResponse. Only declared
// here and not referenced elsewhere in this header.
// NOTE(review): presumably used for replies that carry a separately named
// output variable -- confirm against the definition.
void ProcGetRecvResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);

69
class BaseProcessor {
G
gongweibao 已提交
70
 public:
71
  BaseProcessor() { context_ = nullptr; }
G
gongweibao 已提交
72

73
  virtual ~BaseProcessor() {}
G
gongweibao 已提交
74

75 76 77
  virtual void Prepare(VarHandlePtr h, int64_t time_out) {
    var_h_ = h;

G
gongweibao 已提交
78
    context_.reset(new grpc::ClientContext());
W
Wu Yi 已提交
79
    context_->set_wait_for_ready(true);
Y
Yancey1989 已提交
80 81 82 83 84 85
    if (time_out) {
      std::chrono::system_clock::time_point deadline =
          std::chrono::system_clock::now() +
          std::chrono::milliseconds(time_out);
      context_->set_deadline(deadline);
    }
G
gongweibao 已提交
86 87
  }

88 89 90
  void Process() {
    ProcessImpl();
    var_h_->Finish(true);
Y
Yancey 已提交
91 92
  }

93 94 95 96
  VarHandlePtr GetVarHandlePtr() { return var_h_; }
  bool Wait() { return var_h_->Wait(); }
  void Finish(bool ok) { return var_h_->Finish(ok); }
  virtual void ProcessImpl() = 0;
G
gongweibao 已提交
97 98 99

  std::unique_ptr<grpc::ClientContext> context_;
  grpc::Status status_;
100 101 102

 protected:
  VarHandlePtr var_h_;
G
gongweibao 已提交
103 104
};

105
typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>
G
gongweibao 已提交
106 107
    RequestSendCallBack;

108
class SendProcessor : public BaseProcessor {
G
gongweibao 已提交
109
 public:
110
  explicit SendProcessor(std::shared_ptr<grpc::Channel> ch)
111
      : BaseProcessor(), stub_g_(ch) {}
G
gongweibao 已提交
112 113 114

  virtual ~SendProcessor() {}

115
  void ProcessImpl() override {
G
gongweibao 已提交
116
    if (response_call_back_) {
117
      response_call_back_(*var_h_.get(), reply_);
G
gongweibao 已提交
118 119 120
    }
  }

121 122
  ::grpc::GenericStub stub_g_;
  ::grpc::ByteBuffer reply_;
T
typhoonzero 已提交
123
  RequestSendCallBack response_call_back_ = nullptr;
G
gongweibao 已提交
124 125
};

126
typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>
G
gongweibao 已提交
127 128
    RequestGetCallBack;

129
class GetProcessor : public BaseProcessor {
G
gongweibao 已提交
130
 public:
131
  explicit GetProcessor(std::shared_ptr<grpc::Channel> ch)
132
      : BaseProcessor(), stub_g_(ch) {}
G
gongweibao 已提交
133 134 135

  virtual ~GetProcessor() {}

136
  void ProcessImpl() override {
G
gongweibao 已提交
137
    if (response_call_back_) {
138
      response_call_back_(*var_h_.get(), reply_);
G
gongweibao 已提交
139 140 141
    }
  }

142 143
  ::grpc::ByteBuffer reply_;
  ::grpc::GenericStub stub_g_;
G
gongweibao 已提交
144 145 146
  RequestGetCallBack response_call_back_ = ProcGetResponse;
};

147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168
// Processor for a combined send-and-receive call. Besides the handle bound by
// BaseProcessor::Prepare() it tracks a second handle for the received
// variable, registered through RecvPrepare().
class SendAndRecvProcessor : public BaseProcessor {
 public:
  explicit SendAndRecvProcessor(std::shared_ptr<grpc::Channel> ch)
      : BaseProcessor(), stub_g_(ch) {}

  virtual ~SendAndRecvProcessor() = default;

  // Registers the handle that the reply is delivered to.
  void RecvPrepare(VarHandlePtr h_recv) { var_h_recv_ = h_recv; }

  // Passes the reply to the callback using the receive handle, then marks
  // that handle finished. Both steps are skipped when no callback is set.
  // NOTE(review): the receive handle is dereferenced unconditionally, so
  // RecvPrepare() must have been called first -- confirm all callers do so.
  void ProcessImpl() override {
    if (!response_call_back_) {
      return;
    }
    response_call_back_(*var_h_recv_, reply_);
    var_h_recv_->Finish(true);
  }

  ::grpc::ByteBuffer reply_;
  ::grpc::GenericStub stub_g_;
  RequestGetCallBack response_call_back_ = ProcGetResponse;
  VarHandlePtr var_h_recv_;
};

169
class BatchBarrierProcessor : public BaseProcessor {
Y
Yancey 已提交
170 171
 public:
  explicit BatchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
172
      : BaseProcessor() {
173 174
    stub_ = sendrecv::SendRecvService::NewStub(ch);
  }
Y
Yancey 已提交
175 176 177

  virtual ~BatchBarrierProcessor() {}

178
  void ProcessImpl() override {}
Y
Yancey 已提交
179
  sendrecv::VoidMessage reply_;
180
  std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
Y
Yancey 已提交
181 182
};

183 184 185
class FetchBarrierProcessor : public BaseProcessor {
 public:
  explicit FetchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
186
      : BaseProcessor() {
187 188
    stub_ = sendrecv::SendRecvService::NewStub(ch);
  }
189 190 191

  virtual ~FetchBarrierProcessor() {}

192
  void ProcessImpl() override {}
193
  sendrecv::VariableMessage reply_;
194
  std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
195 196
};

T
tangwei12 已提交
197 198 199
class CheckpointNotifyProcessor : public BaseProcessor {
 public:
  explicit CheckpointNotifyProcessor(std::shared_ptr<grpc::Channel> ch)
200
      : BaseProcessor() {
T
tangwei12 已提交
201 202 203 204 205
    stub_ = sendrecv::SendRecvService::NewStub(ch);
  }

  virtual ~CheckpointNotifyProcessor() {}

206
  void ProcessImpl() override {}
T
tangwei12 已提交
207 208
  sendrecv::VoidMessage reply_;
  std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
T
bug fix  
tangwei12 已提交
209
};
T
tangwei12 已提交
210

G
gongweibao 已提交
211
class GRPCClient : public RPCClient {
G
gongweibao 已提交
212
 public:
M
minqiyang 已提交
213
  GRPCClient() : ok_(true), completed_(false), stopped_(false) {}
G
gongweibao 已提交
214
  virtual ~GRPCClient();
Y
Yancey1989 已提交
215

216 217 218 219 220 221 222 223 224 225
  VarHandlePtr AsyncSendVar(const std::string& ep,
                            const platform::DeviceContext& ctx,
                            const framework::Scope& scope,
                            const std::string& var_name,
                            int64_t time_out = FLAGS_rpc_deadline) override;

  VarHandlePtr AsyncGetVar(const std::string& ep,
                           const platform::DeviceContext& ctx,
                           const framework::Scope& scope,
                           const std::string& var_name,
226
                           const std::string& out_varname,
Q
Qiao Longfei 已提交
227
                           const std::string& table_name = "",
228 229
                           int64_t time_out = FLAGS_rpc_deadline) override;

230 231 232 233 234 235
  VarHandlePtr AsyncGetVarNoBarrier(
      const std::string& ep, const platform::DeviceContext& ctx,
      const framework::Scope& scope, const std::string& var_name,
      const std::string& out_varname,
      int64_t time_out = FLAGS_rpc_deadline) override;

236 237 238 239 240
  VarHandlePtr AsyncGetMonomerVariable(
      const std::string& ep, const platform::DeviceContext& ctx,
      const framework::Scope& scope, const std::string& var_name,
      int64_t time_out = FLAGS_rpc_deadline) override;

241 242 243 244 245
  VarHandlePtr AsyncPrefetchVar(const std::string& ep,
                                const platform::DeviceContext& ctx,
                                const framework::Scope& scope,
                                const std::string& in_var_name,
                                const std::string& out_var_name,
Q
Qiao Longfei 已提交
246
                                const std::string& table_name = "",
247 248 249 250 251
                                int64_t time_out = FLAGS_rpc_deadline) override;

  VarHandlePtr AsyncSendBatchBarrier(
      const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;

252 253 254 255 256 257
  VarHandlePtr AsyncSendFetchBarrier(const std::string& ep,
                                     int64_t time_out) override;

  VarHandlePtr AsyncGetMonomerBarrier(
      const std::string& ep, const std::string& var_name,
      int64_t time_out = FLAGS_rpc_deadline) override;
258 259

  VarHandlePtr AsyncCheckpointNotify(
260
      const std::string& ep, const std::string& dirname,
261
      const std::string& varname, const int mode,
262 263
      int64_t time_out = FLAGS_rpc_deadline) override;

264
  VarHandlePtr AsyncDistributeNotify(
1
123malin 已提交
265 266
      const std::string& ep, const platform::DeviceContext& ctx,
      const framework::Scope& scope, const std::string& var_name,
267 268
      int64_t time_out = FLAGS_rpc_deadline) override;

269 270 271 272 273 274 275 276
  VarHandlePtr AsyncSendAndRecv(const std::string& ep,
                                const platform::DeviceContext& ctx,
                                const framework::Scope& scope,
                                const std::string& send_var_name,
                                const std::string& recv_var_name,
                                const std::string& table_name = "",
                                int64_t time_out = FLAGS_rpc_deadline) override;

277 278
  VarHandlePtr AsyncSendComplete(
      const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
Y
Yancey1989 已提交
279

Y
Yancey1989 已提交
280
  bool Wait() override;
Y
Yancey 已提交
281

Y
Yancey1989 已提交
282
  void SendComplete() override;
W
Wu Yi 已提交
283

G
gongweibao 已提交
284 285 286
  void InitImpl() override;

 private:
W
Wu Yi 已提交
287
  void Proceed();
G
gongweibao 已提交
288

Y
Yancey1989 已提交
289
  std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
290 291 292 293
  VarHandlePtr _AsyncGetVar(
      const std::string& ep, const platform::DeviceContext& ctx,
      const framework::Scope& scope, const std::string& method,
      const std::string& var_name, const std::string& out_varname,
Q
Qiao Longfei 已提交
294 295
      const std::string& rpc_path, const std::string& table_name = "",
      int64_t time_out = FLAGS_rpc_deadline);
G
gongweibao 已提交
296 297 298

 private:
  grpc::CompletionQueue cq_;
W
Wu Yi 已提交
299
  std::unordered_map<std::string, std::shared_ptr<grpc::Channel>> channels_;
300
  std::unique_ptr<std::thread> client_thread_{nullptr};
W
Wu Yi 已提交
301 302 303 304

  // mutex for Wait client sync
  std::mutex sync_mutex_;
  std::condition_variable sync_cond_;
Y
Yancey1989 已提交
305
  std::atomic<int64_t> req_count_{0};
Y
Yancey1989 已提交
306
  bool ok_;
W
Wu Yi 已提交
307 308 309

  // mutex for GetChannel thread safety
  std::mutex chan_mutex_;
G
gongweibao 已提交
310
  DISABLE_COPY_AND_ASSIGN(GRPCClient);
Y
Yancey1989 已提交
311 312 313 314

  // mutex for sending complete message only once
  std::mutex completed_mutex_;
  bool completed_;
M
minqiyang 已提交
315

M
minqiyang 已提交
316
  volatile bool stopped_;
G
gongweibao 已提交
317 318
};

319
}  // namespace distributed
G
gongweibao 已提交
320 321
}  // namespace operators
}  // namespace paddle