// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <ThreadPool.h>
#include <array>
#include <atomic>
#include <future>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"

namespace brpc {
class Channel;
class Controller;
}  // namespace brpc
namespace google {
namespace protobuf {
class Closure;
class RpcController;
}  // namespace protobuf
}  // namespace google

namespace paddle {
namespace distributed {

struct Region;

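// brpc service hosted inside every worker. configure() binds it to the local
// PSClient and the worker's rank; service() handles incoming PsService
// requests, e.g. messages sent from other clients via
// send_client2client_msg().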
class DownpourPsClientService : public PsService {
 public:
  DownpourPsClientService() {}
  virtual ~DownpourPsClientService() {}

  virtual int32_t configure(PSClient *client, size_t rank_id) {
    _client = client;
    _rank = rank_id;
    return 0;
  }
  void service(::google::protobuf::RpcController *controller,
               const PsRequestMessage *request, PsResponseMessage *response,
               ::google::protobuf::Closure *done) override;

 protected:
  size_t _rank;
  PSClient *_client;
};

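// Aggregates the state of `num` parallel brpc calls that make up one logical
// request. Every sub-call gets its own Controller/Request/Response slot and
// passes this closure as its `done` argument; Run() counts the calls down
// and, when the last one finishes, invokes the user callback with `this` and
// then deletes the closure.
//
// Rough usage sketch (variable names below are illustrative, not part of the
// API):
//
//   auto *closure = new DownpourBrpcClosure(n, [n](void *done) {
//     auto *c = static_cast<DownpourBrpcClosure *>(done);
//     int ret = 0;
//     for (size_t i = 0; i < n; ++i) {
//       ret |= c->check_response(i, cmd_id);  // inspect each sub-response
//     }
//     // report `ret` back, e.g. through a promise captured by the callback
//   });
//   for (size_t i = 0; i < n; ++i) {
//     // fill closure->request(i), then issue the rpc with closure->cntl(i),
//     // closure->response(i) and `closure` as the done callback.
//   }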
class DownpourBrpcClosure : public PSClientClosure {
 public:
  DownpourBrpcClosure(size_t num, PSClientCallBack callback)
      : PSClientClosure(callback) {
    _waiting_num = num;

    _cntls.resize(num);
    _requests.resize(num);
    _responses.resize(num);
    for (size_t i = 0; i < num; ++i) {
      _cntls[i].reset(new brpc::Controller());
    }
  }
  virtual ~DownpourBrpcClosure() {}
  void Run() override {
    if (_waiting_num.fetch_sub(1) == 1) {
      _callback(this);
      delete this;
    }
  }
  PsRequestMessage *request(size_t i) { return &_requests[i]; }
  PsResponseMessage *response(size_t i) { return &_responses[i]; }
  brpc::Controller *cntl(size_t i) { return _cntls[i].get(); }
  int check_response(size_t request_idx, int cmd_id);
  int check_save_response(size_t request_idx, int cmd_id);
  std::string get_response(size_t request_idx, int cmd_id);

 private:
  std::atomic<int32_t> _waiting_num;
  std::vector<PsRequestMessage> _requests;
  std::vector<PsResponseMessage> _responses;
  std::vector<std::shared_ptr<brpc::Controller>> _cntls;
};

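// Key/value buffers for one pending push_sparse request, grouped by the
// pserver shard that each key hashes to.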
struct SharedSparsePushData {
  SharedSparsePushData() {}
  ~SharedSparsePushData() noexcept {}
  size_t kv_num;
  std::vector<uint64_t> key_list;
  std::vector<std::string> value_list;
};
struct SparsePushTaskData {
  // sparse data, sharded by key hash
  std::vector<SharedSparsePushData> shared_data;
};

// object pool for push sparse tasks
struct SparseTaskPool {
  std::shared_ptr<SparsePushTaskData> get() {
    std::lock_guard<std::mutex> lock(_mutex);
    if (_pool.empty()) {
      return std::make_shared<SparsePushTaskData>();
    } else {
      auto ret = _pool.back();
      _pool.pop_back();
      return ret;
    }
  }
  void push(std::shared_ptr<SparsePushTaskData> data) {
    std::lock_guard<std::mutex> lock(_mutex);
    _pool.push_back(std::move(data));
  }
  std::vector<std::shared_ptr<SparsePushTaskData>> _pool;
  std::mutex _mutex;
};

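// Deleter for smart pointers that own arrays allocated with new[].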
template <class T>
struct array_deleter {
  void operator()(T *&x) const { delete[] x; }  // NOLINT
};

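// PSClient implementation on top of brpc. Each pserver is reached through
// three dedicated channels (sparse / dense / cmd, see get_*_channel below).
// Dense and sparse pushes are queued per table and drained by the background
// threads _async_push_dense_thread / _async_push_sparse_thread, which batch
// and merge queued tasks where possible. The embedded brpc server
// (_server/_service) receives messages sent by other clients.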
class BrpcPsClient : public PSClient {
 public:
  BrpcPsClient() {}
  virtual ~BrpcPsClient() {
    if (_running) {
      flush();
      _running = false;
    }
    if (_async_push_dense_thread.joinable()) {
      _async_push_dense_thread.join();
    }
    if (_async_push_sparse_thread.joinable()) {
      _async_push_sparse_thread.join();
    }
    if (_server_started) {
      _server.Stop(1000);
      _server.Join();
      _server_started = false;
    }
  }
  virtual int32_t create_client2client_connection(
      int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry);
  std::future<int32_t> shrink(uint32_t table_id,
                              const std::string threshold) override;
  std::future<int32_t> load(const std::string &epoch,
                            const std::string &mode) override;
  std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
                            const std::string &mode) override;

  std::future<int32_t> save(const std::string &epoch,
                            const std::string &mode) override;

  std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
                            const std::string &mode) override;

  std::future<int32_t> clear() override;

  std::future<int32_t> clear(uint32_t table_id) override;

  std::future<int32_t> stop_server() override;

  std::future<int32_t> start_profiler() override;
  std::future<int32_t> stop_profiler() override;

  void finalize_worker() override;

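  // Pull the dense table `table_id` into the caller-provided regions;
  // push_dense_param uploads parameter values (rather than gradients).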
  virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num,
                                          size_t table_id);

  virtual std::future<int32_t> push_dense_param(const Region *regions,
                                                size_t region_num,
                                                size_t table_id);

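  // Asynchronous dense push: the gradient is enqueued into
  // _push_dense_task_queue_map and sent by push_dense_task_consume() on the
  // background thread.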
  virtual std::future<int32_t> push_dense(const Region *regions,
                                          size_t region_num, size_t table_id);
  void push_dense_task_consume();
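  // Pull sparse values for `keys` from `table_id`; the i-th result is written
  // through select_values[i]. is_training marks pulls issued during training.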
  virtual std::future<int32_t> pull_sparse(float **select_values,
                                           size_t table_id,
                                           const uint64_t *keys, size_t num,
                                           bool is_training);
  virtual std::future<int32_t> pull_sparse_param(float **select_values,
                                                 size_t table_id,
                                                 const uint64_t *keys,
                                                 size_t num, bool is_training);

  virtual std::future<int32_t> print_table_stat(uint32_t table_id);

  virtual std::future<int32_t> barrier(size_t table_id, uint32_t barrier_type);

  virtual std::future<int32_t> pull_geo_param(size_t table_id,
                                              std::vector<float> *values,
                                              std::vector<uint64_t> *keys,
                                              int pserver_idx);
  virtual std::future<int32_t> push_global_step(int table_id,
                                                int64_t *total_send_data,
                                                void *done);
  virtual std::future<int32_t> flush();

  std::future<int32_t> send_client2client_msg(int msg_type, int to_client_id,
                                              const std::string &msg) override;

  // for local save sparse
  virtual int32_t recv_and_save_table(const uint64_t table_id,
                                      const std::string &path);

  void print_queue_size();
  void print_queue_size_thread();

 protected:
  virtual size_t get_server_nums() { return _server_channels.size(); }
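  // _server_channels keeps three channels per pserver:
  //   [0] sparse requests, [1] dense requests, [2] control commands.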
  inline brpc::Channel *get_sparse_channel(size_t server_id) {
    return _server_channels[server_id][0].get();
  }
  inline brpc::Channel *get_dense_channel(size_t server_id) {
    return _server_channels[server_id][1].get();
  }
  inline brpc::Channel *get_cmd_channel(size_t server_id) {
    return _server_channels[server_id][2].get();
  }
  int32_t initialize() override;

 private:
  // virtual int32_t initialize() override;

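  // Number of dense dimensions handled by each shard; the +1 rounds up so
  // that shard_num * dense_dim_per_shard always covers dense_dim_total.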
  inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total,
                                      uint32_t shard_num) {
    return dense_dim_total / shard_num + 1;
  }

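  // Broadcast a control command (cmd_id plus string params) to the pservers;
  // send_save_cmd is the variant used by the save/load style commands.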
  std::future<int32_t> send_cmd(uint32_t table_id, int cmd_id,
                                const std::vector<std::string> &param);

  std::future<int32_t> send_save_cmd(uint32_t table_id, int cmd_id,
                                     const std::vector<std::string> &param);

  bool _running = false;
  bool _flushing = false;
  std::atomic<uint32_t> _async_call_num;  // number of in-flight async requests

  // async push dense tasks
  std::thread _async_push_dense_thread;
  typedef AsyncRequestTask<std::shared_ptr<std::vector<float>>> DenseAsyncTask;
  std::unordered_map<uint32_t, paddle::framework::Channel<DenseAsyncTask *>>
      _push_dense_task_queue_map;
  // async push sparse tasks
  std::thread _async_push_sparse_thread;
  typedef AsyncRequestTask<std::shared_ptr<SparsePushTaskData>> SparseAsyncTask;
  std::unordered_map<uint32_t, paddle::framework::Channel<SparseAsyncTask *>>
      _push_sparse_task_queue_map;
  std::unordered_map<uint32_t, uint32_t> _push_sparse_merge_count_map;

  std::thread _print_thread;

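  // Helpers for push_sparse_task_consume: merge the queued sparse tasks that
  // belong to one shard of `table_id`, then send the merged request through
  // the given closure.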
  int push_sparse_async_shard_merge(
      std::vector<std::shared_ptr<SparseAsyncTask>> &task_list,       // NOLINT
      std::vector<int> &request_kv_num, int table_id, int shard_idx,  // NOLINT
      ValueAccessor *accessor);

  int push_sparse_async_shard_push(
      std::vector<std::shared_ptr<SparseAsyncTask>> &task_list,       // NOLINT
      std::vector<int> &request_kv_num, int table_id, int shard_idx,  // NOLINT
      DownpourBrpcClosure *closure, ValueAccessor *accessor);

  SparseTaskPool _sparse_task_pool;

  std::vector<std::shared_ptr<brpc::Channel>>
      _client_channels;  // client2client
  std::vector<std::array<std::shared_ptr<brpc::Channel>, 3>>
      _server_channels;  // client2server
  std::future<int32_t> push_dense_raw_gradient(int table_id,
                                               float *total_send_data,
                                               size_t total_send_data_size,
                                               void *done) override;

  std::future<int32_t> push_sparse_raw_gradient(size_t table_id,
                                                const uint64_t *keys,
                                                const float **update_values,
                                                size_t num,
                                                void *done) override;

  std::future<int32_t> push_sparse_raw_gradient_partial(
      size_t table_id, const uint64_t *keys, const float **update_values,
      uint32_t num, void *done, int pserver_idx) override;

  std::future<int32_t> push_sparse_param(size_t table_id, const uint64_t *keys,
                                         const float **update_values,
                                         size_t num, void *done) override;
  std::future<int32_t> push_sparse(size_t table_id, const uint64_t *keys,
                                   const float **update_values,
                                   size_t num) override;
  void push_sparse_task_consume();

 private:
  int32_t start_client_service();

  void push_dense_raw_gradient(std::shared_ptr<DenseAsyncTask> &task,  // NOLINT
                               float *total_send_data,
                               size_t total_send_data_size,
                               DownpourBrpcClosure *closure);
  float _mae = 0;
  float _mse = 0;
  uint16_t _push_times = 0;
  brpc::Server _server;
  DownpourPsClientService _service;
  bool _server_started = false;
  std::atomic_uint grad_num_{0};
};
}  // namespace distributed
}  // namespace paddle