/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <time.h>

#include <atomic>
#include <chrono>              // NOLINT
#include <condition_variable>  // NOLINT
#include <ctime>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>  // NOLINT
#include <sstream>
#include <string>
#include <thread>  // NOLINT
#include <unordered_map>
#include <vector>

#include "grpc++/channel.h"
#include "grpc++/generic/generic_stub.h"
#include "grpc++/grpc++.h"
#include "grpc++/support/byte_buffer.h"
#include "grpc++/support/slice.h"
#include "grpc/support/log.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/rpc_client.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/platform/macros.h"  // for DISABLE_COPY_AND_ASSIGN

namespace paddle {
namespace operators {
namespace detail {

// Describes one variable involved in an RPC: the endpoint string, the
// device context and scope it is associated with, and its name.
struct VarHandle {
  // Endpoint the call targets (presumably "host:port" — confirm with callers).
  std::string ep;
  // Raw, non-const-mutable pointers; VarHandle never frees them.
  const platform::DeviceContext* ctx;
  const framework::Scope* scope;
  // Name of the variable.
  std::string name;

  // Debug representation of the form "name:[<name>] ep:[<ep>]".
  std::string String() const {
    return "name:[" + name + "] ep:[" + ep + "]";
  }
};

// Default response callback of GetProcessor (matches RequestGetCallBack):
// receives the VarHandle of the finished call and the raw reply bytes.
void ProcGetResponse(const VarHandle& var_h,
                     const grpc::ByteBuffer& msg);

class BaseProcessor {
G
gongweibao 已提交
65
 public:
T
typhoonzero 已提交
66 67 68
  explicit BaseProcessor(std::shared_ptr<grpc::Channel> ch) {
    context_ = nullptr;
  }
G
gongweibao 已提交
69

70
  virtual ~BaseProcessor() {}
G
gongweibao 已提交
71 72 73 74 75 76 77 78 79 80 81

  virtual void Prepare(const VarHandle& var_info, int64_t time_out) {
    context_.reset(new grpc::ClientContext());
    var_h_ = var_info;

    std::chrono::system_clock::time_point deadline =
        std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);

    context_->set_deadline(deadline);
  }

Y
Yancey 已提交
82 83 84 85 86 87 88 89 90
  virtual void Prepare(int64_t time_out) {
    context_.reset(new grpc::ClientContext());

    std::chrono::system_clock::time_point deadline =
        std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);

    context_->set_deadline(deadline);
  }

G
gongweibao 已提交
91 92 93 94 95 96 97
  virtual void Process() = 0;

  std::unique_ptr<grpc::ClientContext> context_;
  grpc::Status status_;
  VarHandle var_h_;
};

98
typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>
G
gongweibao 已提交
99 100
    RequestSendCallBack;

101
class SendProcessor : public BaseProcessor {
G
gongweibao 已提交
102
 public:
103
  explicit SendProcessor(std::shared_ptr<grpc::Channel> ch)
104
      : BaseProcessor(ch), stub_g_(ch) {}
G
gongweibao 已提交
105 106 107 108 109 110 111 112 113

  virtual ~SendProcessor() {}

  virtual void Process() {
    if (response_call_back_) {
      response_call_back_(var_h_, reply_);
    }
  }

114 115
  ::grpc::GenericStub stub_g_;
  ::grpc::ByteBuffer reply_;
T
typhoonzero 已提交
116
  RequestSendCallBack response_call_back_ = nullptr;
G
gongweibao 已提交
117 118
};

119
typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>
G
gongweibao 已提交
120 121
    RequestGetCallBack;

122
class GetProcessor : public BaseProcessor {
G
gongweibao 已提交
123
 public:
124
  explicit GetProcessor(std::shared_ptr<grpc::Channel> ch)
125
      : BaseProcessor(ch), stub_g_(ch) {}
G
gongweibao 已提交
126 127 128 129 130 131 132 133 134

  virtual ~GetProcessor() {}

  virtual void Process() {
    if (response_call_back_) {
      response_call_back_(var_h_, reply_);
    }
  }

135 136
  ::grpc::ByteBuffer reply_;
  ::grpc::GenericStub stub_g_;
G
gongweibao 已提交
137 138 139
  RequestGetCallBack response_call_back_ = ProcGetResponse;
};

140
class BatchBarrierProcessor : public BaseProcessor {
Y
Yancey 已提交
141 142
 public:
  explicit BatchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
143 144 145
      : BaseProcessor(ch) {
    stub_ = sendrecv::SendRecvService::NewStub(ch);
  }
Y
Yancey 已提交
146 147 148 149 150

  virtual ~BatchBarrierProcessor() {}

  virtual void Process() {}
  sendrecv::VoidMessage reply_;
151
  std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
Y
Yancey 已提交
152 153
};

154 155 156
class FetchBarrierProcessor : public BaseProcessor {
 public:
  explicit FetchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
157 158 159
      : BaseProcessor(ch) {
    stub_ = sendrecv::SendRecvService::NewStub(ch);
  }
160 161 162 163 164

  virtual ~FetchBarrierProcessor() {}

  virtual void Process() {}
  sendrecv::VariableMessage reply_;
165
  std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
166 167
};

T
tangwei12 已提交
168 169 170 171 172 173 174 175 176 177 178 179 180 181
class CheckpointNotifyProcessor : public BaseProcessor {
 public:
  explicit CheckpointNotifyProcessor(std::shared_ptr<grpc::Channel> ch)
      : BaseProcessor(ch) {
    stub_ = sendrecv::SendRecvService::NewStub(ch);
  }

  virtual ~CheckpointNotifyProcessor() {}

  virtual void Process() {}
  sendrecv::VoidMessage reply_;
  std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
}

G
gongweibao 已提交
182
class GRPCClient : public RPCClient {
G
gongweibao 已提交
183
 public:
G
gongweibao 已提交
184 185
  GRPCClient() {}
  virtual ~GRPCClient();
Y
Yancey1989 已提交
186

G
gongweibao 已提交
187 188 189
  bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
                    const framework::Scope& scope, const std::string& var_name,
                    int64_t time_out = RPCClient::rpc_time_out) override;
Y
Yancey1989 已提交
190

G
gongweibao 已提交
191 192 193
  bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx,
                   const framework::Scope& scope, const std::string& var_name,
                   int64_t time_out = RPCClient::rpc_time_out) override;
G
gongweibao 已提交
194

G
gongweibao 已提交
195
  bool AsyncPrefetchVar(const std::string& ep,
G
gongweibao 已提交
196 197
                        const platform::DeviceContext& ctx,
                        const framework::Scope& scope,
G
gongweibao 已提交
198 199 200
                        const std::string& in_var_name,
                        const std::string& out_var_name,
                        int64_t time_out = RPCClient::rpc_time_out) override;
Y
Yancey 已提交
201

G
gongweibao 已提交
202 203 204
  void AsyncSendBatchBarrier(
      const std::string& ep,
      int64_t time_out = RPCClient::rpc_time_out) override;
Q
Qiao Longfei 已提交
205

G
gongweibao 已提交
206 207 208
  void AsyncSendFetchBarrier(
      const std::string& ep,
      int64_t time_out = RPCClient::rpc_time_out) override;
209

T
tangwei12 已提交
210 211 212 213
  void AsyncCheckpointNotify(
      const std::string& ep, const std::string& dir,
      int64_t time_out = RPCClient::rpc_time_out) override;

G
gongweibao 已提交
214
  void Wait() override;
Y
Yancey 已提交
215

W
Wu Yi 已提交
216 217
  void SendComplete() override;

G
gongweibao 已提交
218 219 220 221
 protected:
  void InitImpl() override;

 private:
W
Wu Yi 已提交
222 223
  // InitEventLoop should only be called by Init()
  void InitEventLoop();
G
gongweibao 已提交
224

W
Wu Yi 已提交
225
  void Proceed();
G
gongweibao 已提交
226

W
Wu Yi 已提交
227 228 229
  void AsyncSendComplete(const std::string& ep,
                         int64_t time_out = RPCClient::rpc_time_out);

Y
Yancey1989 已提交
230
  std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
G
gongweibao 已提交
231 232 233

 private:
  grpc::CompletionQueue cq_;
W
Wu Yi 已提交
234 235 236 237 238 239
  std::unordered_map<std::string, std::shared_ptr<grpc::Channel>> channels_;
  std::unique_ptr<std::thread> client_thread_;

  // mutex for Wait client sync
  std::mutex sync_mutex_;
  std::condition_variable sync_cond_;
Y
Yancey1989 已提交
240
  std::atomic<int64_t> req_count_{0};
W
Wu Yi 已提交
241 242 243

  // mutex for GetChannel thread safety
  std::mutex chan_mutex_;
G
gongweibao 已提交
244
  DISABLE_COPY_AND_ASSIGN(GRPCClient);
G
gongweibao 已提交
245 246 247 248 249
};

}  // namespace detail
}  // namespace operators
}  // namespace paddle