/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <time.h>

#include <atomic>
#include <chrono>              // NOLINT
#include <condition_variable>  // NOLINT
#include <ctime>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>  // NOLINT
#include <sstream>
#include <string>
#include <thread>  // NOLINT
#include <unordered_map>
#include <vector>

#include "grpc++/channel.h"
#include "grpc++/generic/generic_stub.h"
#include "grpc++/grpc++.h"
#include "grpc++/support/byte_buffer.h"
#include "grpc++/support/slice.h"
#include "grpc/support/log.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/platform/macros.h"  // for DISABLE_COPY_AND_ASSIGN
G
gongweibao 已提交
44 45 46

namespace paddle {
namespace operators {
47
namespace distributed {
G
gongweibao 已提交
48 49 50 51 52 53 54 55 56 57 58 59 60 61

// Bundle of fields describing one variable for a request: two strings
// (ep, name) plus raw pointers to a device context and a scope.
struct VarHandle {
  std::string ep;
  const platform::DeviceContext* ctx;
  const framework::Scope* scope;
  std::string name;

  // Renders only the two string fields, as "name:[<name>] ep:[<ep>]".
  // The pointer members are not part of the output.
  std::string String() const {
    std::string text("name:[");
    text.append(name);
    text.append("] ep:[");
    text.append(ep);
    text.append("]");
    return text;
  }
};

62
// Declared here and defined elsewhere. GetProcessor uses this function as the
// default value of its response_call_back_ member, so it receives the
// processor's variable handle and its raw gRPC reply buffer.
void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);
G
gongweibao 已提交
63

64
class BaseProcessor {
G
gongweibao 已提交
65
 public:
T
typhoonzero 已提交
66 67 68
  explicit BaseProcessor(std::shared_ptr<grpc::Channel> ch) {
    context_ = nullptr;
  }
G
gongweibao 已提交
69

70
  virtual ~BaseProcessor() {}
G
gongweibao 已提交
71 72 73 74 75 76 77 78 79 80 81

  virtual void Prepare(const VarHandle& var_info, int64_t time_out) {
    context_.reset(new grpc::ClientContext());
    var_h_ = var_info;

    std::chrono::system_clock::time_point deadline =
        std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);

    context_->set_deadline(deadline);
  }

Y
Yancey 已提交
82 83 84 85 86 87 88 89 90
  virtual void Prepare(int64_t time_out) {
    context_.reset(new grpc::ClientContext());

    std::chrono::system_clock::time_point deadline =
        std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);

    context_->set_deadline(deadline);
  }

G
gongweibao 已提交
91 92 93 94 95 96 97
  virtual void Process() = 0;

  std::unique_ptr<grpc::ClientContext> context_;
  grpc::Status status_;
  VarHandle var_h_;
};

98
typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>
G
gongweibao 已提交
99 100
    RequestSendCallBack;

101
class SendProcessor : public BaseProcessor {
G
gongweibao 已提交
102
 public:
103
  explicit SendProcessor(std::shared_ptr<grpc::Channel> ch)
104
      : BaseProcessor(ch), stub_g_(ch) {}
G
gongweibao 已提交
105 106 107 108 109 110 111 112 113

  virtual ~SendProcessor() {}

  virtual void Process() {
    if (response_call_back_) {
      response_call_back_(var_h_, reply_);
    }
  }

114 115
  ::grpc::GenericStub stub_g_;
  ::grpc::ByteBuffer reply_;
T
typhoonzero 已提交
116
  RequestSendCallBack response_call_back_ = nullptr;
G
gongweibao 已提交
117 118
};

119
typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>
G
gongweibao 已提交
120 121
    RequestGetCallBack;

122
class GetProcessor : public BaseProcessor {
G
gongweibao 已提交
123
 public:
124
  explicit GetProcessor(std::shared_ptr<grpc::Channel> ch)
125
      : BaseProcessor(ch), stub_g_(ch) {}
G
gongweibao 已提交
126 127 128 129 130 131 132 133 134

  virtual ~GetProcessor() {}

  virtual void Process() {
    if (response_call_back_) {
      response_call_back_(var_h_, reply_);
    }
  }

135 136
  ::grpc::ByteBuffer reply_;
  ::grpc::GenericStub stub_g_;
G
gongweibao 已提交
137 138 139
  RequestGetCallBack response_call_back_ = ProcGetResponse;
};

140
class BatchBarrierProcessor : public BaseProcessor {
Y
Yancey 已提交
141 142
 public:
  explicit BatchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
143 144 145
      : BaseProcessor(ch) {
    stub_ = sendrecv::SendRecvService::NewStub(ch);
  }
Y
Yancey 已提交
146 147 148 149 150

  virtual ~BatchBarrierProcessor() {}

  virtual void Process() {}
  sendrecv::VoidMessage reply_;
151
  std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
Y
Yancey 已提交
152 153
};

154 155 156
class FetchBarrierProcessor : public BaseProcessor {
 public:
  explicit FetchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
157 158 159
      : BaseProcessor(ch) {
    stub_ = sendrecv::SendRecvService::NewStub(ch);
  }
160 161 162 163 164

  virtual ~FetchBarrierProcessor() {}

  virtual void Process() {}
  sendrecv::VariableMessage reply_;
165
  std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
166 167
};

G
gongweibao 已提交
168
class GRPCClient : public RPCClient {
G
gongweibao 已提交
169
 public:
G
gongweibao 已提交
170 171
  GRPCClient() {}
  virtual ~GRPCClient();
Y
Yancey1989 已提交
172

G
gongweibao 已提交
173 174 175
  bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
                    const framework::Scope& scope, const std::string& var_name,
                    int64_t time_out = RPCClient::rpc_time_out) override;
Y
Yancey1989 已提交
176

G
gongweibao 已提交
177 178 179
  bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx,
                   const framework::Scope& scope, const std::string& var_name,
                   int64_t time_out = RPCClient::rpc_time_out) override;
G
gongweibao 已提交
180

G
gongweibao 已提交
181
  bool AsyncPrefetchVar(const std::string& ep,
G
gongweibao 已提交
182 183
                        const platform::DeviceContext& ctx,
                        const framework::Scope& scope,
G
gongweibao 已提交
184 185 186
                        const std::string& in_var_name,
                        const std::string& out_var_name,
                        int64_t time_out = RPCClient::rpc_time_out) override;
Y
Yancey 已提交
187

G
gongweibao 已提交
188 189 190
  void AsyncSendBatchBarrier(
      const std::string& ep,
      int64_t time_out = RPCClient::rpc_time_out) override;
Q
Qiao Longfei 已提交
191

G
gongweibao 已提交
192 193 194
  void AsyncSendFetchBarrier(
      const std::string& ep,
      int64_t time_out = RPCClient::rpc_time_out) override;
195

G
gongweibao 已提交
196
  void Wait() override;
Y
Yancey 已提交
197

W
Wu Yi 已提交
198 199
  void SendComplete() override;

G
gongweibao 已提交
200 201 202 203
 protected:
  void InitImpl() override;

 private:
W
Wu Yi 已提交
204 205
  // InitEventLoop should only be called by Init()
  void InitEventLoop();
G
gongweibao 已提交
206

W
Wu Yi 已提交
207
  void Proceed();
G
gongweibao 已提交
208

W
Wu Yi 已提交
209 210 211
  void AsyncSendComplete(const std::string& ep,
                         int64_t time_out = RPCClient::rpc_time_out);

Y
Yancey1989 已提交
212
  std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
G
gongweibao 已提交
213 214 215

 private:
  grpc::CompletionQueue cq_;
W
Wu Yi 已提交
216 217 218 219 220 221
  std::unordered_map<std::string, std::shared_ptr<grpc::Channel>> channels_;
  std::unique_ptr<std::thread> client_thread_;

  // mutex for Wait client sync
  std::mutex sync_mutex_;
  std::condition_variable sync_cond_;
Y
Yancey1989 已提交
222
  std::atomic<int64_t> req_count_{0};
W
Wu Yi 已提交
223 224 225

  // mutex for GetChannel thread safety
  std::mutex chan_mutex_;
G
gongweibao 已提交
226
  DISABLE_COPY_AND_ASSIGN(GRPCClient);
G
gongweibao 已提交
227 228
};

229
}  // namespace distributed
G
gongweibao 已提交
230 231
}  // namespace operators
}  // namespace paddle