brpc_ps_server.h 8.9 KB
Newer Older
T
tangwei12 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <future>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
20 21
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include "paddle/fluid/distributed/ps/service/server.h"
T
tangwei12 已提交
22

23 24 25 26 27 28 29 30 31 32
// Third-party RPC types that this header only names through pointers
// (the service() entry point and the request-handler signatures).
// brpc::Controller is also made complete by "brpc/controller.h"; the
// forward declaration is redundant but harmless.
namespace google {
namespace protobuf {
class RpcController;
class Closure;
}  // namespace protobuf
}  // namespace google

namespace brpc {
class Controller;
}  // namespace brpc

T
tangwei12 已提交
33 34 35
namespace paddle {
namespace distributed {

36 37 38 39
// Types this header refers to only by pointer or reference; their full
// definitions live in other translation units / headers.
class Table;
class PsResponseMessage;
class PsRequestMessage;

T
tangwei12 已提交
40 41 42 43
class BrpcPsServer : public PSServer {
 public:
  BrpcPsServer() {}
  virtual ~BrpcPsServer() {}
Z
zhaocaibei123 已提交
44 45
  virtual uint64_t Start(const std::string &ip, uint32_t port);
  virtual int32_t Stop() {
T
tangwei12 已提交
46 47 48 49 50 51 52 53
    std::unique_lock<std::mutex> lock(mutex_);
    stoped_ = true;
    cv_.notify_all();

    _server.Stop(1000);
    _server.Join();
    return 0;
  }
Z
zhaocaibei123 已提交
54
  int32_t Port();
T
tangwei12 已提交
55

Z
zhaocaibei123 已提交
56 57
  int32_t StartS2S() override;
  ::std::future<int32_t> SendPServer2PServerMsg(
Z
zhaocaibei123 已提交
58
      int msg_type, int to_pserver_id, const std::string &msg) override;
Z
zhaocaibei123 已提交
59 60 61
  int32_t ReceiveFromPServer(int msg_type,
                             int pserver_id,
                             const std::string &msg) override;
Z
zhaocaibei123 已提交
62

T
tangwei12 已提交
63
 private:
Z
zhaocaibei123 已提交
64
  virtual int32_t Initialize();
T
tangwei12 已提交
65 66 67 68 69 70 71 72
  mutable std::mutex mutex_;
  std::condition_variable cv_;
  bool stoped_ = false;
  brpc::Server _server;
  std::shared_ptr<PsBaseService> _service;
  std::vector<std::shared_ptr<brpc::Channel>> _pserver_channels;
};

T
tangwei12 已提交
73
class BrpcPsService;
T
tangwei12 已提交
74

T
tangwei12 已提交
75
typedef int32_t (BrpcPsService::*serviceHandlerFunc)(
76 77
    Table *table,
    const PsRequestMessage &request,
Z
zhaocaibei123 已提交
78
    PsResponseMessage &response,  // NOLINT
T
tangwei12 已提交
79 80
    brpc::Controller *cntl);

T
tangwei12 已提交
81
class BrpcPsService : public PsBaseService {
T
tangwei12 已提交
82
 public:
Z
zhaocaibei123 已提交
83
  int32_t Initialize() override;
T
tangwei12 已提交
84

Z
zhaocaibei123 已提交
85 86 87 88
  void service(::google::protobuf::RpcController *controller,
               const PsRequestMessage *request,
               PsResponseMessage *response,
               ::google::protobuf::Closure *done) override;
T
tangwei12 已提交
89 90

 private:
Z
zhaocaibei123 已提交
91
  int32_t InitializeShardInfo();
92 93
  int32_t PullDense(Table *table,
                    const PsRequestMessage &request,
Z
zhaocaibei123 已提交
94
                    PsResponseMessage &response,  // NOLINT
95 96 97
                    brpc::Controller *cntl);
  int32_t PushDense(Table *table,
                    const PsRequestMessage &request,
Z
zhaocaibei123 已提交
98
                    PsResponseMessage &response,  // NOLINT
99 100 101
                    brpc::Controller *cntl);
  int32_t PushDenseParam(Table *table,
                         const PsRequestMessage &request,
Z
zhaocaibei123 已提交
102
                         PsResponseMessage &response,  // NOLINT
103 104 105
                         brpc::Controller *cntl);
  int32_t PushSparseParam(Table *table,
                          const PsRequestMessage &request,
Z
zhaocaibei123 已提交
106
                          PsResponseMessage &response,  // NOLINT
107 108 109
                          brpc::Controller *cntl);
  int32_t PullSparse(Table *table,
                     const PsRequestMessage &request,
Z
zhaocaibei123 已提交
110
                     PsResponseMessage &response,  // NOLINT
111 112 113
                     brpc::Controller *cntl);
  int32_t PullGeoParam(Table *table,
                       const PsRequestMessage &request,
Z
zhaocaibei123 已提交
114
                       PsResponseMessage &response,  // NOLINT
115 116 117
                       brpc::Controller *cntl);
  int32_t Barrier(Table *table,
                  const PsRequestMessage &request,
Z
zhaocaibei123 已提交
118
                  PsResponseMessage &response,  // NOLINT
119 120 121
                  brpc::Controller *cntl);
  int32_t PushSparse(Table *table,
                     const PsRequestMessage &request,
Z
zhaocaibei123 已提交
122
                     PsResponseMessage &response,  // NOLINT
123 124 125
                     brpc::Controller *cntl);
  int32_t LoadOneTable(Table *table,
                       const PsRequestMessage &request,
Z
zhaocaibei123 已提交
126
                       PsResponseMessage &response,  // NOLINT
127 128 129
                       brpc::Controller *cntl);
  int32_t LoadAllTable(Table *table,
                       const PsRequestMessage &request,
Z
zhaocaibei123 已提交
130
                       PsResponseMessage &response,  // NOLINT
131 132 133
                       brpc::Controller *cntl);
  int32_t SaveOneTable(Table *table,
                       const PsRequestMessage &request,
Z
zhaocaibei123 已提交
134
                       PsResponseMessage &response,  // NOLINT
135 136 137
                       brpc::Controller *cntl);
  int32_t SaveAllTable(Table *table,
                       const PsRequestMessage &request,
Z
zhaocaibei123 已提交
138
                       PsResponseMessage &response,  // NOLINT
139 140 141
                       brpc::Controller *cntl);
  int32_t ShrinkTable(Table *table,
                      const PsRequestMessage &request,
Z
zhaocaibei123 已提交
142
                      PsResponseMessage &response,  // NOLINT
143 144 145
                      brpc::Controller *cntl);
  int32_t ClearOneTable(Table *table,
                        const PsRequestMessage &request,
Z
zhaocaibei123 已提交
146
                        PsResponseMessage &response,  // NOLINT
147 148 149
                        brpc::Controller *cntl);
  int32_t ClearAllTable(Table *table,
                        const PsRequestMessage &request,
Z
zhaocaibei123 已提交
150
                        PsResponseMessage &response,  // NOLINT
151 152 153
                        brpc::Controller *cntl);
  int32_t StopServer(Table *table,
                     const PsRequestMessage &request,
Z
zhaocaibei123 已提交
154
                     PsResponseMessage &response,  // NOLINT
155 156 157
                     brpc::Controller *cntl);
  int32_t StartProfiler(Table *table,
                        const PsRequestMessage &request,
Z
zhaocaibei123 已提交
158
                        PsResponseMessage &response,  // NOLINT
159 160 161
                        brpc::Controller *cntl);
  int32_t StopProfiler(Table *table,
                       const PsRequestMessage &request,
Z
zhaocaibei123 已提交
162
                       PsResponseMessage &response,  // NOLINT
163 164 165 166
                       brpc::Controller *cntl);

  int32_t PrintTableStat(Table *table,
                         const PsRequestMessage &request,
Z
zhaocaibei123 已提交
167
                         PsResponseMessage &response,  // NOLINT
168 169 170 171
                         brpc::Controller *cntl);

  int32_t PushGlobalStep(Table *table,
                         const PsRequestMessage &request,
Z
zhaocaibei123 已提交
172
                         PsResponseMessage &response,  // NOLINT
173 174 175 176
                         brpc::Controller *cntl);

  int32_t CacheShuffle(Table *table,
                       const PsRequestMessage &request,
Z
zhaocaibei123 已提交
177
                       PsResponseMessage &response,  // NOLINT
178 179 180 181
                       brpc::Controller *cntl);

  int32_t SaveCacheTable(Table *table,
                         const PsRequestMessage &request,
Z
zhaocaibei123 已提交
182
                         PsResponseMessage &response,  // NOLINT
183 184 185 186
                         brpc::Controller *cntl);

  int32_t GetCacheThreshold(Table *table,
                            const PsRequestMessage &request,
Z
zhaocaibei123 已提交
187
                            PsResponseMessage &response,  // NOLINT
Z
zhaocaibei123 已提交
188 189
                            brpc::Controller *cntl);

Z
zhaocaibei123 已提交
190 191 192 193 194 195 196 197 198 199
  int32_t Revert(Table *table,
                 const PsRequestMessage &request,
                 PsResponseMessage &response,  // NOLINT
                 brpc::Controller *cntl);

  int32_t CheckSavePrePatchDone(Table *table,
                                const PsRequestMessage &request,
                                PsResponseMessage &response,  // NOLINT
                                brpc::Controller *cntl);

T
tangwei12 已提交
200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220
  bool _is_initialize_shard_info;
  std::mutex _initialize_shard_mutex;
  std::unordered_map<int32_t, serviceHandlerFunc> _service_handler_map;
  std::unordered_map<int32_t, serviceHandlerFunc> _msg_handler_map;
  std::vector<float> _ori_values;
};

class DownpourPServerBrpcClosure : public PServerClosure {
 public:
  DownpourPServerBrpcClosure(size_t num, PServerCallBack callback)
      : PServerClosure(callback) {
    _waiting_num = num;
    _cntls.resize(num);
    _requests.resize(num);
    _responses.resize(num);
    for (size_t i = 0; i < num; ++i) {
      _cntls[i].reset(new brpc::Controller());
    }
  }
  virtual ~DownpourPServerBrpcClosure() {}

Z
zhaocaibei123 已提交
221
  void Run() override {
T
tangwei12 已提交
222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240
    if (_waiting_num.fetch_sub(1) == 1) {
      _callback(this);
      delete this;
    }
  }
  PsRequestMessage *request(size_t i) { return &_requests[i]; }
  PsResponseMessage *response(size_t i) { return &_responses[i]; }
  brpc::Controller *cntl(size_t i) { return _cntls[i].get(); }
  int check_response(size_t request_idx, int cmd_id) { return 1; }
  int check_save_response(size_t request_idx, int cmd_id) { return 1; }

 private:
  std::atomic<int32_t> _waiting_num;
  std::vector<PsRequestMessage> _requests;
  std::vector<PsResponseMessage> _responses;
  std::vector<std::shared_ptr<brpc::Controller>> _cntls;
};
}  // namespace distributed
}  // namespace paddle