// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <ThreadPool.h>
#include <array>
#include <atomic>
#include <cstdint>
#include <future>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>

#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
// Forward declarations of the brpc types referenced in this header.
namespace brpc {
class Channel;
class Controller;
}  // namespace brpc
// Forward declarations of the protobuf RPC types that appear in the
// signature of DownpourPsClientService::service().
namespace google {
namespace protobuf {
class Closure;
class RpcController;
}  // namespace protobuf
}  // namespace google

namespace paddle {
namespace distributed {

// Forward declaration: this header only refers to Region through pointers.
struct Region;

class DownpourPsClientService : public PsService {
 public:
  DownpourPsClientService() {}
  virtual ~DownpourPsClientService() {}

  virtual int32_t configure(PSClient *client, size_t rank_id) {
    _client = client;
    _rank = rank_id;
    return 0;
  }
Z
zhaocaibei123 已提交
58 59 60
  void service(::google::protobuf::RpcController *controller,
               const PsRequestMessage *request, PsResponseMessage *response,
               ::google::protobuf::Closure *done) override;
T
tangwei12 已提交
61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80

 protected:
  size_t _rank;
  PSClient *_client;
};

class DownpourBrpcClosure : public PSClientClosure {
 public:
  DownpourBrpcClosure(size_t num, PSClientCallBack callback)
      : PSClientClosure(callback) {
    _waiting_num = num;

    _cntls.resize(num);
    _requests.resize(num);
    _responses.resize(num);
    for (size_t i = 0; i < num; ++i) {
      _cntls[i].reset(new brpc::Controller());
    }
  }
  virtual ~DownpourBrpcClosure() {}
Z
zhaocaibei123 已提交
81
  void Run() override {
T
tangwei12 已提交
82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
    if (_waiting_num.fetch_sub(1) == 1) {
      _callback(this);
      delete this;
    }
  }
  PsRequestMessage *request(size_t i) { return &_requests[i]; }
  PsResponseMessage *response(size_t i) { return &_responses[i]; }
  brpc::Controller *cntl(size_t i) { return _cntls[i].get(); }
  int check_response(size_t request_idx, int cmd_id);
  int check_save_response(size_t request_idx, int cmd_id);
  std::string get_response(size_t request_idx, int cmd_id);

 private:
  std::atomic<int32_t> _waiting_num;
  std::vector<PsRequestMessage> _requests;
  std::vector<PsResponseMessage> _responses;
  std::vector<std::shared_ptr<brpc::Controller>> _cntls;
};

// Keys and values carried by one shard of a sparse push.
// NOTE(review): kv_num is presumably the number of key/value pairs held in
// key_list / value_list — confirm against the implementation file.
struct SharedSparsePushData {
  SharedSparsePushData() {}
  ~SharedSparsePushData() noexcept {}
  // Defaults to 0: the user-provided constructor above does not zero members,
  // so without this initialiser a fresh object would expose an indeterminate
  // count.
  size_t kv_num = 0;
  std::vector<uint64_t> key_list;
  std::vector<std::string> value_list;
};
// Data attached to a sparse push task (see the SparseAsyncTask alias in
// BrpcPsClient): a list of per-shard blocks.
struct SparsePushTaskData {
  // Sparse data, sharded by key hash.
  std::vector<SharedSparsePushData> shared_data;
};

// push sparse 对象池
struct SparseTaskPool {
  std::shared_ptr<SparsePushTaskData> get() {
    std::lock_guard<std::mutex> lock(_mutex);
    if (_pool.empty()) {
      return std::make_shared<SparsePushTaskData>();
    } else {
      auto ret = _pool.back();
      _pool.pop_back();
      return ret;
    }
  }
  void push(std::shared_ptr<SparsePushTaskData> data) {
    std::lock_guard<std::mutex> lock(_mutex);
    _pool.push_back(std::move(data));
  }
  std::vector<std::shared_ptr<SparsePushTaskData>> _pool;
  std::mutex _mutex;
};

// Deleter functor for memory obtained with `new T[n]`; releases it with
// `delete[]`.
//
// The pointer is taken by value rather than by non-const reference so that the
// functor can be called with an rvalue pointer. std::unique_ptr invokes its
// deleter as `get_deleter()(get())`, which passes a prvalue, so a `T *&`
// parameter would make the deleter unusable there. Every call site that passes
// an lvalue pointer keeps working unchanged.
template <class T>
struct array_deleter {
  void operator()(T *x) const { delete[] x; }
};

class BrpcPsClient : public PSClient {
 public:
  BrpcPsClient() {}
  virtual ~BrpcPsClient() {
Z
zhaocaibei123 已提交
141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
    if (_running) {
      flush();
      _running = false;
    }
    if (_async_push_dense_thread.joinable()) {
      _async_push_dense_thread.join();
    }
    if (_async_push_sparse_thread.joinable()) {
      _async_push_sparse_thread.join();
    }
    if (_server_started) {
      _server.Stop(1000);
      _server.Join();
      _server_started = false;
    }
T
tangwei12 已提交
156 157 158
  }
  virtual int32_t create_client2client_connection(
      int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry);
Z
zhaocaibei123 已提交
159 160 161 162 163 164
  std::future<int32_t> shrink(uint32_t table_id,
                              const std::string threshold) override;
  std::future<int32_t> load(const std::string &epoch,
                            const std::string &mode) override;
  std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
                            const std::string &mode) override;
T
tangwei12 已提交
165

Z
zhaocaibei123 已提交
166 167
  std::future<int32_t> save(const std::string &epoch,
                            const std::string &mode) override;
T
tangwei12 已提交
168

Z
zhaocaibei123 已提交
169 170
  std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
                            const std::string &mode) override;
T
tangwei12 已提交
171

Z
zhaocaibei123 已提交
172
  std::future<int32_t> clear() override;
T
tangwei12 已提交
173

Z
zhaocaibei123 已提交
174
  std::future<int32_t> clear(uint32_t table_id) override;
T
tangwei12 已提交
175

Z
zhaocaibei123 已提交
176
  std::future<int32_t> stop_server() override;
T
tangwei12 已提交
177

Z
zhaocaibei123 已提交
178 179
  std::future<int32_t> start_profiler() override;
  std::future<int32_t> stop_profiler() override;
T
tangwei12 已提交
180

Z
zhaocaibei123 已提交
181
  void finalize_worker() override;
T
tangwei12 已提交
182 183 184 185 186 187 188 189

  virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num,
                                          size_t table_id);

  virtual std::future<int32_t> push_dense_param(const Region *regions,
                                                size_t region_num,
                                                size_t table_id);

Z
zhaocaibei123 已提交
190 191 192
  virtual std::future<int32_t> push_dense(const Region *regions,
                                          size_t region_num, size_t table_id);
  void push_dense_task_consume();
T
tangwei12 已提交
193 194
  virtual std::future<int32_t> pull_sparse(float **select_values,
                                           size_t table_id,
195 196
                                           const uint64_t *keys, size_t num,
                                           bool is_training);
T
tangwei12 已提交
197 198 199 200 201 202 203 204 205

  virtual std::future<int32_t> print_table_stat(uint32_t table_id);

  virtual std::future<int32_t> barrier(size_t table_id, uint32_t barrier_type);

  virtual std::future<int32_t> pull_geo_param(size_t table_id,
                                              std::vector<float> *values,
                                              std::vector<uint64_t> *keys,
                                              int pserver_idx);
206 207 208
  virtual std::future<int32_t> push_global_step(int table_id,
                                                int64_t *total_send_data,
                                                void *done);
T
tangwei12 已提交
209 210
  virtual std::future<int32_t> flush();

Z
zhaocaibei123 已提交
211 212
  std::future<int32_t> send_client2client_msg(int msg_type, int to_client_id,
                                              const std::string &msg) override;
T
tangwei12 已提交
213

214 215 216 217
  // for local save sparse
  virtual int32_t recv_and_save_table(const uint64_t table_id,
                                      const std::string &path);

Z
zhaocaibei123 已提交
218 219 220
  void print_queue_size();
  void print_queue_size_thread();

S
seemingwang 已提交
221 222 223 224 225 226 227 228 229 230 231
 protected:
  virtual size_t get_server_nums() { return _server_channels.size(); }
  inline brpc::Channel *get_sparse_channel(size_t server_id) {
    return _server_channels[server_id][0].get();
  }
  inline brpc::Channel *get_dense_channel(size_t server_id) {
    return _server_channels[server_id][1].get();
  }
  inline brpc::Channel *get_cmd_channel(size_t server_id) {
    return _server_channels[server_id][2].get();
  }
Z
zhaocaibei123 已提交
232
  int32_t initialize() override;
T
tangwei12 已提交
233

S
seemingwang 已提交
234 235 236
 private:
  // virtual int32_t initialize() override;

T
tangwei12 已提交
237 238 239 240 241 242 243 244 245 246 247 248 249
  inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total,
                                      uint32_t shard_num) {
    return dense_dim_total / shard_num + 1;
  }

  std::future<int32_t> send_cmd(uint32_t table_id, int cmd_id,
                                const std::vector<std::string> &param);

  std::future<int32_t> send_save_cmd(uint32_t table_id, int cmd_id,
                                     const std::vector<std::string> &param);

  bool _running = false;
  bool _flushing = false;
Z
zhaocaibei123 已提交
250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276
  std::atomic<uint32_t> _async_call_num;  // 异步请求计数

  // 异步push dense task
  std::thread _async_push_dense_thread;
  typedef AsyncRequestTask<std::shared_ptr<std::vector<float>>> DenseAsyncTask;
  std::unordered_map<uint32_t, paddle::framework::Channel<DenseAsyncTask *>>
      _push_dense_task_queue_map;
  // 异步push sparse task
  std::thread _async_push_sparse_thread;
  typedef AsyncRequestTask<std::shared_ptr<SparsePushTaskData>> SparseAsyncTask;
  std::unordered_map<uint32_t, paddle::framework::Channel<SparseAsyncTask *>>
      _push_sparse_task_queue_map;
  std::unordered_map<uint32_t, uint32_t> _push_sparse_merge_count_map;

  std::thread _print_thread;

  int push_sparse_async_shard_merge(
      std::vector<std::shared_ptr<SparseAsyncTask>> &task_list,       // NOLINT
      std::vector<int> &request_kv_num, int table_id, int shard_idx,  // NOLINT
      ValueAccessor *accessor);

  int push_sparse_async_shard_push(
      std::vector<std::shared_ptr<SparseAsyncTask>> &task_list,       // NOLINT
      std::vector<int> &request_kv_num, int table_id, int shard_idx,  // NOLINT
      DownpourBrpcClosure *closure, ValueAccessor *accessor);

  SparseTaskPool _sparse_task_pool;
T
tangwei12 已提交
277 278 279 280 281

  std::vector<std::shared_ptr<brpc::Channel>>
      _client_channels;  // client2client
  std::vector<std::array<std::shared_ptr<brpc::Channel>, 3>>
      _server_channels;  // client2server
Z
zhaocaibei123 已提交
282 283 284 285 286 287 288 289 290 291 292 293
  std::future<int32_t> push_dense_raw_gradient(int table_id,
                                               float *total_send_data,
                                               size_t total_send_data_size,
                                               void *done) override;

  std::future<int32_t> push_sparse_raw_gradient(size_t table_id,
                                                const uint64_t *keys,
                                                const float **update_values,
                                                size_t num,
                                                void *done) override;

  std::future<int32_t> push_sparse_raw_gradient_partial(
T
tangwei12 已提交
294 295 296
      size_t table_id, const uint64_t *keys, const float **update_values,
      uint32_t num, void *done, int pserver_idx) override;

Z
zhaocaibei123 已提交
297 298 299 300 301 302 303
  std::future<int32_t> push_sparse_param(size_t table_id, const uint64_t *keys,
                                         const float **update_values,
                                         size_t num, void *done) override;
  std::future<int32_t> push_sparse(size_t table_id, const uint64_t *keys,
                                   const float **update_values,
                                   size_t num) override;
  void push_sparse_task_consume();
T
tangwei12 已提交
304 305 306 307

 private:
  int32_t start_client_service();

Z
zhaocaibei123 已提交
308 309 310 311
  void push_dense_raw_gradient(std::shared_ptr<DenseAsyncTask> &task,  // NOLINT
                               float *total_send_data,
                               size_t total_send_data_size,
                               DownpourBrpcClosure *closure);
T
tangwei12 已提交
312 313 314 315 316
  float _mae = 0;
  float _mse = 0;
  uint16_t _push_times = 0;
  brpc::Server _server;
  DownpourPsClientService _service;
Z
zhaocaibei123 已提交
317
  bool _server_started = false;
T
tangwei12 已提交
318 319 320 321
  std::atomic_uint grad_num_{0};
};
}  // namespace distributed
}  // namespace paddle