grpc_client.h 6.5 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <time.h>
Y
Yi Wang 已提交
18

W
Wu Yi 已提交
19 20
#include <atomic>
#include <chrono>              // NOLINT
#include <condition_variable>  // NOLINT
#include <cstdint>
#include <ctime>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>  // NOLINT
#include <string>
#include <thread>  // NOLINT
#include <unordered_map>
#include <vector>

W
Wu Yi 已提交
30
#include "grpc++/channel.h"
Y
Yi Wang 已提交
31 32 33 34 35
#include "grpc++/generic/generic_stub.h"
#include "grpc++/grpc++.h"
#include "grpc++/support/byte_buffer.h"
#include "grpc++/support/slice.h"
#include "grpc/support/log.h"
T
typhoonzero 已提交
36
#include "paddle/fluid/framework/blocking_queue.h"
Y
Yi Wang 已提交
37 38 39 40
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
41 42
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
Y
Yancey1989 已提交
43
#include "paddle/fluid/platform/macros.h"  // for DISABLE_COPY_AND_ASSIGN
G
gongweibao 已提交
44 45 46

namespace paddle {
namespace operators {
47
namespace distributed {
G
gongweibao 已提交
48 49 50 51 52 53 54 55 56 57 58 59 60 61

// Describes one variable taking part in an RPC: the remote endpoint it is
// exchanged with, the local device context and scope it belongs to, and its
// name.
struct VarHandle {
  // Remote endpoint the variable is sent to / fetched from.
  std::string ep;
  // Non-owning. Left unset by default construction; callers must assign it.
  const platform::DeviceContext* ctx;
  // Non-owning. Left unset by default construction; callers must assign it.
  const framework::Scope* scope;
  // Name of the variable.
  std::string name;

  // Returns a short human-readable description: "name:[<name>] ep:[<ep>]".
  // Built by plain concatenation, so no <sstream> dependency is needed.
  std::string String() const { return "name:[" + name + "] ep:[" + ep + "]"; }
};

62
void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);
G
gongweibao 已提交
63

64
class BaseProcessor {
G
gongweibao 已提交
65
 public:
T
typhoonzero 已提交
66 67 68
  explicit BaseProcessor(std::shared_ptr<grpc::Channel> ch) {
    context_ = nullptr;
  }
G
gongweibao 已提交
69

70
  virtual ~BaseProcessor() {}
G
gongweibao 已提交
71 72 73 74

  virtual void Prepare(const VarHandle& var_info, int64_t time_out) {
    context_.reset(new grpc::ClientContext());
    var_h_ = var_info;
W
Wu Yi 已提交
75
    context_->set_wait_for_ready(true);
G
gongweibao 已提交
76 77 78 79 80 81 82

    std::chrono::system_clock::time_point deadline =
        std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);

    context_->set_deadline(deadline);
  }

Y
Yancey 已提交
83 84
  virtual void Prepare(int64_t time_out) {
    context_.reset(new grpc::ClientContext());
W
Wu Yi 已提交
85
    context_->set_wait_for_ready(true);
Y
Yancey 已提交
86 87 88 89 90 91 92

    std::chrono::system_clock::time_point deadline =
        std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);

    context_->set_deadline(deadline);
  }

G
gongweibao 已提交
93 94 95 96 97 98 99
  virtual void Process() = 0;

  std::unique_ptr<grpc::ClientContext> context_;
  grpc::Status status_;
  VarHandle var_h_;
};

100
typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>
G
gongweibao 已提交
101 102
    RequestSendCallBack;

103
class SendProcessor : public BaseProcessor {
G
gongweibao 已提交
104
 public:
105
  explicit SendProcessor(std::shared_ptr<grpc::Channel> ch)
106
      : BaseProcessor(ch), stub_g_(ch) {}
G
gongweibao 已提交
107 108 109 110 111 112 113 114 115

  virtual ~SendProcessor() {}

  virtual void Process() {
    if (response_call_back_) {
      response_call_back_(var_h_, reply_);
    }
  }

116 117
  ::grpc::GenericStub stub_g_;
  ::grpc::ByteBuffer reply_;
T
typhoonzero 已提交
118
  RequestSendCallBack response_call_back_ = nullptr;
G
gongweibao 已提交
119 120
};

121
typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>
G
gongweibao 已提交
122 123
    RequestGetCallBack;

124
class GetProcessor : public BaseProcessor {
G
gongweibao 已提交
125
 public:
126
  explicit GetProcessor(std::shared_ptr<grpc::Channel> ch)
127
      : BaseProcessor(ch), stub_g_(ch) {}
G
gongweibao 已提交
128 129 130 131 132 133 134 135 136

  virtual ~GetProcessor() {}

  virtual void Process() {
    if (response_call_back_) {
      response_call_back_(var_h_, reply_);
    }
  }

137 138
  ::grpc::ByteBuffer reply_;
  ::grpc::GenericStub stub_g_;
G
gongweibao 已提交
139 140 141
  RequestGetCallBack response_call_back_ = ProcGetResponse;
};

142
class BatchBarrierProcessor : public BaseProcessor {
Y
Yancey 已提交
143 144
 public:
  explicit BatchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
145 146 147
      : BaseProcessor(ch) {
    stub_ = sendrecv::SendRecvService::NewStub(ch);
  }
Y
Yancey 已提交
148 149 150 151 152

  virtual ~BatchBarrierProcessor() {}

  virtual void Process() {}
  sendrecv::VoidMessage reply_;
153
  std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
Y
Yancey 已提交
154 155
};

156 157 158
class FetchBarrierProcessor : public BaseProcessor {
 public:
  explicit FetchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
159 160 161
      : BaseProcessor(ch) {
    stub_ = sendrecv::SendRecvService::NewStub(ch);
  }
162 163 164 165 166

  virtual ~FetchBarrierProcessor() {}

  virtual void Process() {}
  sendrecv::VariableMessage reply_;
167
  std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
168 169
};

G
gongweibao 已提交
170
class GRPCClient : public RPCClient {
G
gongweibao 已提交
171
 public:
G
gongweibao 已提交
172 173
  GRPCClient() {}
  virtual ~GRPCClient();
Y
Yancey1989 已提交
174

G
gongweibao 已提交
175 176
  bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
                    const framework::Scope& scope, const std::string& var_name,
W
Wu Yi 已提交
177
                    int64_t time_out = FLAGS_grpc_deadline) override;
Y
Yancey1989 已提交
178

G
gongweibao 已提交
179 180
  bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx,
                   const framework::Scope& scope, const std::string& var_name,
W
Wu Yi 已提交
181
                   int64_t time_out = FLAGS_grpc_deadline) override;
G
gongweibao 已提交
182

G
gongweibao 已提交
183
  bool AsyncPrefetchVar(const std::string& ep,
G
gongweibao 已提交
184 185
                        const platform::DeviceContext& ctx,
                        const framework::Scope& scope,
G
gongweibao 已提交
186 187
                        const std::string& in_var_name,
                        const std::string& out_var_name,
W
Wu Yi 已提交
188
                        int64_t time_out = FLAGS_grpc_deadline) override;
Y
Yancey 已提交
189

W
Wu Yi 已提交
190 191
  void AsyncSendBatchBarrier(const std::string& ep,
                             int64_t time_out = FLAGS_grpc_deadline) override;
Q
Qiao Longfei 已提交
192

W
Wu Yi 已提交
193 194
  void AsyncSendFetchBarrier(const std::string& ep,
                             int64_t time_out = FLAGS_grpc_deadline) override;
195

G
gongweibao 已提交
196
  void Wait() override;
Y
Yancey 已提交
197

W
Wu Yi 已提交
198 199
  void SendComplete() override;

G
gongweibao 已提交
200 201 202 203
 protected:
  void InitImpl() override;

 private:
W
Wu Yi 已提交
204 205
  // InitEventLoop should only be called by Init()
  void InitEventLoop();
G
gongweibao 已提交
206

W
Wu Yi 已提交
207
  void Proceed();
G
gongweibao 已提交
208

W
Wu Yi 已提交
209
  void AsyncSendComplete(const std::string& ep,
W
Wu Yi 已提交
210
                         int64_t time_out = FLAGS_grpc_deadline);
W
Wu Yi 已提交
211

Y
Yancey1989 已提交
212
  std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
G
gongweibao 已提交
213 214 215

 private:
  grpc::CompletionQueue cq_;
W
Wu Yi 已提交
216 217 218 219 220 221
  std::unordered_map<std::string, std::shared_ptr<grpc::Channel>> channels_;
  std::unique_ptr<std::thread> client_thread_;

  // mutex for Wait client sync
  std::mutex sync_mutex_;
  std::condition_variable sync_cond_;
Y
Yancey1989 已提交
222
  std::atomic<int64_t> req_count_{0};
W
Wu Yi 已提交
223 224 225

  // mutex for GetChannel thread safety
  std::mutex chan_mutex_;
G
gongweibao 已提交
226
  DISABLE_COPY_AND_ASSIGN(GRPCClient);
G
gongweibao 已提交
227 228
};

229
}  // namespace distributed
G
gongweibao 已提交
230 231
}  // namespace operators
}  // namespace paddle