grpc_client.h 7.2 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <time.h>

#include <atomic>
#include <chrono>              // NOLINT
#include <condition_variable>  // NOLINT
#include <ctime>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>  // NOLINT
#include <string>
#include <thread>  // NOLINT
#include <unordered_map>
#include <vector>

#include "grpc++/channel.h"
#include "grpc++/generic/generic_stub.h"
#include "grpc++/grpc++.h"
#include "grpc++/support/byte_buffer.h"
#include "grpc++/support/slice.h"
#include "grpc/support/log.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/platform/macros.h"  // for DISABLE_COPY_AND_ASSIGN
G
gongweibao 已提交
47 48 49

namespace paddle {
namespace operators {
50
namespace distributed {
G
gongweibao 已提交
51

52
void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);
G
gongweibao 已提交
53

54
class BaseProcessor {
G
gongweibao 已提交
55
 public:
T
typhoonzero 已提交
56 57 58
  explicit BaseProcessor(std::shared_ptr<grpc::Channel> ch) {
    context_ = nullptr;
  }
G
gongweibao 已提交
59

60
  virtual ~BaseProcessor() {}
G
gongweibao 已提交
61 62 63 64

  virtual void Prepare(const VarHandle& var_info, int64_t time_out) {
    context_.reset(new grpc::ClientContext());
    var_h_ = var_info;
W
Wu Yi 已提交
65
    context_->set_wait_for_ready(true);
Y
Yancey1989 已提交
66 67 68 69 70 71
    if (time_out) {
      std::chrono::system_clock::time_point deadline =
          std::chrono::system_clock::now() +
          std::chrono::milliseconds(time_out);
      context_->set_deadline(deadline);
    }
G
gongweibao 已提交
72 73
  }

Y
Yancey 已提交
74 75
  virtual void Prepare(int64_t time_out) {
    context_.reset(new grpc::ClientContext());
W
Wu Yi 已提交
76
    context_->set_wait_for_ready(true);
Y
Yancey 已提交
77 78 79 80 81 82 83

    std::chrono::system_clock::time_point deadline =
        std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);

    context_->set_deadline(deadline);
  }

G
gongweibao 已提交
84 85 86 87 88 89 90
  virtual void Process() = 0;

  std::unique_ptr<grpc::ClientContext> context_;
  grpc::Status status_;
  VarHandle var_h_;
};

91
typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>
G
gongweibao 已提交
92 93
    RequestSendCallBack;

94
class SendProcessor : public BaseProcessor {
G
gongweibao 已提交
95
 public:
96
  explicit SendProcessor(std::shared_ptr<grpc::Channel> ch)
97
      : BaseProcessor(ch), stub_g_(ch) {}
G
gongweibao 已提交
98 99 100 101 102 103 104 105 106

  virtual ~SendProcessor() {}

  virtual void Process() {
    if (response_call_back_) {
      response_call_back_(var_h_, reply_);
    }
  }

107 108
  ::grpc::GenericStub stub_g_;
  ::grpc::ByteBuffer reply_;
T
typhoonzero 已提交
109
  RequestSendCallBack response_call_back_ = nullptr;
G
gongweibao 已提交
110 111
};

112
typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>
G
gongweibao 已提交
113 114
    RequestGetCallBack;

115
class GetProcessor : public BaseProcessor {
G
gongweibao 已提交
116
 public:
117
  explicit GetProcessor(std::shared_ptr<grpc::Channel> ch)
118
      : BaseProcessor(ch), stub_g_(ch) {}
G
gongweibao 已提交
119 120 121 122 123 124 125 126 127

  virtual ~GetProcessor() {}

  virtual void Process() {
    if (response_call_back_) {
      response_call_back_(var_h_, reply_);
    }
  }

128 129
  ::grpc::ByteBuffer reply_;
  ::grpc::GenericStub stub_g_;
G
gongweibao 已提交
130 131 132
  RequestGetCallBack response_call_back_ = ProcGetResponse;
};

133
class BatchBarrierProcessor : public BaseProcessor {
Y
Yancey 已提交
134 135
 public:
  explicit BatchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
136 137 138
      : BaseProcessor(ch) {
    stub_ = sendrecv::SendRecvService::NewStub(ch);
  }
Y
Yancey 已提交
139 140 141 142 143

  virtual ~BatchBarrierProcessor() {}

  virtual void Process() {}
  sendrecv::VoidMessage reply_;
144
  std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
Y
Yancey 已提交
145 146
};

147 148 149
class FetchBarrierProcessor : public BaseProcessor {
 public:
  explicit FetchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
150 151 152
      : BaseProcessor(ch) {
    stub_ = sendrecv::SendRecvService::NewStub(ch);
  }
153 154 155 156 157

  virtual ~FetchBarrierProcessor() {}

  virtual void Process() {}
  sendrecv::VariableMessage reply_;
158
  std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
159 160
};

T
tangwei12 已提交
161 162 163 164 165 166 167 168 169 170 171 172
class CheckpointNotifyProcessor : public BaseProcessor {
 public:
  explicit CheckpointNotifyProcessor(std::shared_ptr<grpc::Channel> ch)
      : BaseProcessor(ch) {
    stub_ = sendrecv::SendRecvService::NewStub(ch);
  }

  virtual ~CheckpointNotifyProcessor() {}

  virtual void Process() {}
  sendrecv::VoidMessage reply_;
  std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
T
bug fix  
tangwei12 已提交
173
};
T
tangwei12 已提交
174

G
gongweibao 已提交
175
class GRPCClient : public RPCClient {
G
gongweibao 已提交
176
 public:
M
minqiyang 已提交
177
  GRPCClient() : ok_(true), completed_(false), stopped_(false) {}
G
gongweibao 已提交
178
  virtual ~GRPCClient();
Y
Yancey1989 已提交
179

G
gongweibao 已提交
180 181
  bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
                    const framework::Scope& scope, const std::string& var_name,
G
gongweibao 已提交
182
                    int64_t time_out = FLAGS_rpc_deadline) override;
Y
Yancey1989 已提交
183

G
gongweibao 已提交
184 185
  bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx,
                   const framework::Scope& scope, const std::string& var_name,
G
gongweibao 已提交
186
                   int64_t time_out = FLAGS_rpc_deadline) override;
G
gongweibao 已提交
187

G
gongweibao 已提交
188
  bool AsyncPrefetchVar(const std::string& ep,
G
gongweibao 已提交
189 190
                        const platform::DeviceContext& ctx,
                        const framework::Scope& scope,
G
gongweibao 已提交
191 192
                        const std::string& in_var_name,
                        const std::string& out_var_name,
G
gongweibao 已提交
193
                        int64_t time_out = FLAGS_rpc_deadline) override;
Y
Yancey 已提交
194

W
Wu Yi 已提交
195
  void AsyncSendBatchBarrier(const std::string& ep,
G
gongweibao 已提交
196
                             int64_t time_out = FLAGS_rpc_deadline) override;
Q
Qiao Longfei 已提交
197

W
Wu Yi 已提交
198
  void AsyncSendFetchBarrier(const std::string& ep,
G
gongweibao 已提交
199
                             int64_t time_out = FLAGS_rpc_deadline) override;
200

T
renae  
tangwei12 已提交
201
  void AsyncCheckpointNotify(const std::string& ep, const std::string& dir,
T
tangwei12 已提交
202
                             int64_t time_out = FLAGS_rpc_deadline) override;
T
tangwei12 已提交
203

Y
Yancey1989 已提交
204 205
  void AsyncSendComplete(const std::string& ep,
                         int64_t time_out = FLAGS_rpc_deadline) override;
Y
Yancey1989 已提交
206

Y
Yancey1989 已提交
207
  bool Wait() override;
Y
Yancey 已提交
208

Y
Yancey1989 已提交
209
  void SendComplete() override;
W
Wu Yi 已提交
210

G
gongweibao 已提交
211 212 213 214
 protected:
  void InitImpl() override;

 private:
W
Wu Yi 已提交
215 216
  // InitEventLoop should only be called by Init()
  void InitEventLoop();
G
gongweibao 已提交
217

W
Wu Yi 已提交
218
  void Proceed();
G
gongweibao 已提交
219

Y
Yancey1989 已提交
220
  std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
G
gongweibao 已提交
221 222 223

 private:
  grpc::CompletionQueue cq_;
W
Wu Yi 已提交
224 225 226 227 228 229
  std::unordered_map<std::string, std::shared_ptr<grpc::Channel>> channels_;
  std::unique_ptr<std::thread> client_thread_;

  // mutex for Wait client sync
  std::mutex sync_mutex_;
  std::condition_variable sync_cond_;
Y
Yancey1989 已提交
230
  std::atomic<int64_t> req_count_{0};
Y
Yancey1989 已提交
231
  bool ok_;
W
Wu Yi 已提交
232 233 234

  // mutex for GetChannel thread safety
  std::mutex chan_mutex_;
G
gongweibao 已提交
235
  DISABLE_COPY_AND_ASSIGN(GRPCClient);
Y
Yancey1989 已提交
236 237 238 239

  // mutex for sending complete message only once
  std::mutex completed_mutex_;
  bool completed_;
M
minqiyang 已提交
240

M
minqiyang 已提交
241
  volatile bool stopped_;
G
gongweibao 已提交
242 243
};

244
}  // namespace distributed
G
gongweibao 已提交
245 246
}  // namespace operators
}  // namespace paddle