brpc_ps_client.h 11.8 KB
Newer Older
T
tangwei12 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

Z
zhaocaibei123 已提交
17
#include <ThreadPool.h>
T
tangwei12 已提交
18 19 20 21 22 23 24
#include <memory>
#include <string>
#include <vector>

#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
25 26
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"
Z
zhaocaibei123 已提交
27
#include "paddle/fluid/framework/channel.h"
28 29 30
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
T
tangwei12 已提交
31

32 33 34 35 36 37 38 39 40 41 42
namespace brpc {
class Channel;
class Controller;
}  // namespace brpc
namespace google {
namespace protobuf {
class Closure;
class RpcController;
}  // namespace protobuf
}  // namespace google

T
tangwei12 已提交
43 44 45
namespace paddle {
namespace distributed {

46 47
struct Region;

T
tangwei12 已提交
48 49 50 51 52 53 54 55 56 57
class DownpourPsClientService : public PsService {
 public:
  DownpourPsClientService() {}
  virtual ~DownpourPsClientService() {}

  virtual int32_t configure(PSClient *client, size_t rank_id) {
    _client = client;
    _rank = rank_id;
    return 0;
  }
Z
zhaocaibei123 已提交
58 59 60
  void service(::google::protobuf::RpcController *controller,
               const PsRequestMessage *request, PsResponseMessage *response,
               ::google::protobuf::Closure *done) override;
T
tangwei12 已提交
61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80

 protected:
  size_t _rank;
  PSClient *_client;
};

class DownpourBrpcClosure : public PSClientClosure {
 public:
  DownpourBrpcClosure(size_t num, PSClientCallBack callback)
      : PSClientClosure(callback) {
    _waiting_num = num;

    _cntls.resize(num);
    _requests.resize(num);
    _responses.resize(num);
    for (size_t i = 0; i < num; ++i) {
      _cntls[i].reset(new brpc::Controller());
    }
  }
  virtual ~DownpourBrpcClosure() {}
Z
zhaocaibei123 已提交
81
  void Run() override {
T
tangwei12 已提交
82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
    if (_waiting_num.fetch_sub(1) == 1) {
      _callback(this);
      delete this;
    }
  }
  PsRequestMessage *request(size_t i) { return &_requests[i]; }
  PsResponseMessage *response(size_t i) { return &_responses[i]; }
  brpc::Controller *cntl(size_t i) { return _cntls[i].get(); }
  int check_response(size_t request_idx, int cmd_id);
  int check_save_response(size_t request_idx, int cmd_id);
  std::string get_response(size_t request_idx, int cmd_id);

 private:
  std::atomic<int32_t> _waiting_num;
  std::vector<PsRequestMessage> _requests;
  std::vector<PsResponseMessage> _responses;
  std::vector<std::shared_ptr<brpc::Controller>> _cntls;
};

Z
zhaocaibei123 已提交
101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
// Keys and serialized values of one shard of a sparse push.
struct SharedSparsePushData {
  SharedSparsePushData() {}
  ~SharedSparsePushData() noexcept {}
  // Starts at zero: the user-provided constructor above would otherwise leave
  // this counter with an indeterminate value.
  size_t kv_num = 0;
  std::vector<uint64_t> key_list;
  std::vector<std::string> value_list;
};
// Payload of one sparse push task: a list of SharedSparsePushData entries.
struct SparsePushTaskData {
  std::vector<SharedSparsePushData> shared_data;  // sparse data sharded by key hash
};

// push sparse 对象池
struct SparseTaskPool {
  std::shared_ptr<SparsePushTaskData> get() {
    std::lock_guard<std::mutex> lock(_mutex);
    if (_pool.empty()) {
      return std::make_shared<SparsePushTaskData>();
    } else {
      auto ret = _pool.back();
      _pool.pop_back();
      return ret;
    }
  }
  void push(std::shared_ptr<SparsePushTaskData> data) {
    std::lock_guard<std::mutex> lock(_mutex);
    _pool.push_back(std::move(data));
  }
  std::vector<std::shared_ptr<SparsePushTaskData>> _pool;
  std::mutex _mutex;
};

T
tangwei12 已提交
132 133
// Deleter that releases an array allocated with new[].
//
// The pointer is taken by value so the deleter can be invoked with any
// pointer expression (lvalue, rvalue or const), which is what
// std::unique_ptr requires of its deleter; callers that pass an lvalue
// pointer keep working unchanged.
template <class T>
struct array_deleter {
  void operator()(T *x) const { delete[] x; }
};

class BrpcPsClient : public PSClient {
 public:
  BrpcPsClient() {}
  virtual ~BrpcPsClient() {
Z
zhaocaibei123 已提交
141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
    if (_running) {
      flush();
      _running = false;
    }
    if (_async_push_dense_thread.joinable()) {
      _async_push_dense_thread.join();
    }
    if (_async_push_sparse_thread.joinable()) {
      _async_push_sparse_thread.join();
    }
    if (_server_started) {
      _server.Stop(1000);
      _server.Join();
      _server_started = false;
    }
T
tangwei12 已提交
156 157 158
  }
  virtual int32_t create_client2client_connection(
      int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry);
Z
zhaocaibei123 已提交
159 160 161 162 163 164
  std::future<int32_t> shrink(uint32_t table_id,
                              const std::string threshold) override;
  std::future<int32_t> load(const std::string &epoch,
                            const std::string &mode) override;
  std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
                            const std::string &mode) override;
T
tangwei12 已提交
165

Y
yaoxuefeng 已提交
166 167
  std::future<int32_t> Load(const LoadSaveContext &load_context) override;

Z
zhaocaibei123 已提交
168 169
  std::future<int32_t> save(const std::string &epoch,
                            const std::string &mode) override;
T
tangwei12 已提交
170

Z
zhaocaibei123 已提交
171 172
  std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
                            const std::string &mode) override;
T
tangwei12 已提交
173

Y
yaoxuefeng 已提交
174 175 176
  virtual std::future<int32_t> Save(
      const LoadSaveContext &save_context) override;

Z
zhaocaibei123 已提交
177
  std::future<int32_t> clear() override;
T
tangwei12 已提交
178

Z
zhaocaibei123 已提交
179
  std::future<int32_t> clear(uint32_t table_id) override;
T
tangwei12 已提交
180

Z
zhaocaibei123 已提交
181
  std::future<int32_t> stop_server() override;
T
tangwei12 已提交
182

Z
zhaocaibei123 已提交
183 184
  std::future<int32_t> start_profiler() override;
  std::future<int32_t> stop_profiler() override;
T
tangwei12 已提交
185

Z
zhaocaibei123 已提交
186
  void finalize_worker() override;
T
tangwei12 已提交
187 188 189 190 191 192 193 194

  virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num,
                                          size_t table_id);

  virtual std::future<int32_t> push_dense_param(const Region *regions,
                                                size_t region_num,
                                                size_t table_id);

Z
zhaocaibei123 已提交
195 196 197
  virtual std::future<int32_t> push_dense(const Region *regions,
                                          size_t region_num, size_t table_id);
  void push_dense_task_consume();
T
tangwei12 已提交
198 199
  virtual std::future<int32_t> pull_sparse(float **select_values,
                                           size_t table_id,
200 201
                                           const uint64_t *keys, size_t num,
                                           bool is_training);
Z
zhaocaibei123 已提交
202 203 204 205
  virtual std::future<int32_t> pull_sparse_param(float **select_values,
                                                 size_t table_id,
                                                 const uint64_t *keys,
                                                 size_t num, bool is_training);
T
tangwei12 已提交
206

Y
yaoxuefeng 已提交
207 208 209 210
  virtual std::future<int32_t> Pull(RequestContext &pull_context) override;

  virtual std::future<int32_t> Push(RequestContext &push_context) override;

T
tangwei12 已提交
211 212 213 214 215 216 217 218
  virtual std::future<int32_t> print_table_stat(uint32_t table_id);

  virtual std::future<int32_t> barrier(size_t table_id, uint32_t barrier_type);

  virtual std::future<int32_t> pull_geo_param(size_t table_id,
                                              std::vector<float> *values,
                                              std::vector<uint64_t> *keys,
                                              int pserver_idx);
219 220 221
  virtual std::future<int32_t> push_global_step(int table_id,
                                                int64_t *total_send_data,
                                                void *done);
T
tangwei12 已提交
222 223
  virtual std::future<int32_t> flush();

Z
zhaocaibei123 已提交
224 225
  std::future<int32_t> send_client2client_msg(int msg_type, int to_client_id,
                                              const std::string &msg) override;
T
tangwei12 已提交
226

227 228 229 230
  // for local save sparse
  virtual int32_t recv_and_save_table(const uint64_t table_id,
                                      const std::string &path);

Z
zhaocaibei123 已提交
231 232 233
  void print_queue_size();
  void print_queue_size_thread();

S
seemingwang 已提交
234 235 236 237 238 239 240 241 242 243 244
 protected:
  virtual size_t get_server_nums() { return _server_channels.size(); }
  inline brpc::Channel *get_sparse_channel(size_t server_id) {
    return _server_channels[server_id][0].get();
  }
  inline brpc::Channel *get_dense_channel(size_t server_id) {
    return _server_channels[server_id][1].get();
  }
  inline brpc::Channel *get_cmd_channel(size_t server_id) {
    return _server_channels[server_id][2].get();
  }
Z
zhaocaibei123 已提交
245
  int32_t initialize() override;
T
tangwei12 已提交
246

S
seemingwang 已提交
247 248 249
 private:
  // virtual int32_t initialize() override;

T
tangwei12 已提交
250 251 252 253 254 255 256 257 258 259 260 261 262
  inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total,
                                      uint32_t shard_num) {
    return dense_dim_total / shard_num + 1;
  }

  std::future<int32_t> send_cmd(uint32_t table_id, int cmd_id,
                                const std::vector<std::string> &param);

  std::future<int32_t> send_save_cmd(uint32_t table_id, int cmd_id,
                                     const std::vector<std::string> &param);

  bool _running = false;
  bool _flushing = false;
Z
zhaocaibei123 已提交
263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289
  std::atomic<uint32_t> _async_call_num;  // 异步请求计数

  // 异步push dense task
  std::thread _async_push_dense_thread;
  typedef AsyncRequestTask<std::shared_ptr<std::vector<float>>> DenseAsyncTask;
  std::unordered_map<uint32_t, paddle::framework::Channel<DenseAsyncTask *>>
      _push_dense_task_queue_map;
  // 异步push sparse task
  std::thread _async_push_sparse_thread;
  typedef AsyncRequestTask<std::shared_ptr<SparsePushTaskData>> SparseAsyncTask;
  std::unordered_map<uint32_t, paddle::framework::Channel<SparseAsyncTask *>>
      _push_sparse_task_queue_map;
  std::unordered_map<uint32_t, uint32_t> _push_sparse_merge_count_map;

  std::thread _print_thread;

  int push_sparse_async_shard_merge(
      std::vector<std::shared_ptr<SparseAsyncTask>> &task_list,       // NOLINT
      std::vector<int> &request_kv_num, int table_id, int shard_idx,  // NOLINT
      ValueAccessor *accessor);

  int push_sparse_async_shard_push(
      std::vector<std::shared_ptr<SparseAsyncTask>> &task_list,       // NOLINT
      std::vector<int> &request_kv_num, int table_id, int shard_idx,  // NOLINT
      DownpourBrpcClosure *closure, ValueAccessor *accessor);

  SparseTaskPool _sparse_task_pool;
T
tangwei12 已提交
290 291 292 293 294

  std::vector<std::shared_ptr<brpc::Channel>>
      _client_channels;  // client2client
  std::vector<std::array<std::shared_ptr<brpc::Channel>, 3>>
      _server_channels;  // client2server
Z
zhaocaibei123 已提交
295 296 297 298 299 300 301 302 303 304 305 306
  std::future<int32_t> push_dense_raw_gradient(int table_id,
                                               float *total_send_data,
                                               size_t total_send_data_size,
                                               void *done) override;

  std::future<int32_t> push_sparse_raw_gradient(size_t table_id,
                                                const uint64_t *keys,
                                                const float **update_values,
                                                size_t num,
                                                void *done) override;

  std::future<int32_t> push_sparse_raw_gradient_partial(
T
tangwei12 已提交
307 308 309
      size_t table_id, const uint64_t *keys, const float **update_values,
      uint32_t num, void *done, int pserver_idx) override;

Z
zhaocaibei123 已提交
310 311 312 313 314 315 316
  std::future<int32_t> push_sparse_param(size_t table_id, const uint64_t *keys,
                                         const float **update_values,
                                         size_t num, void *done) override;
  std::future<int32_t> push_sparse(size_t table_id, const uint64_t *keys,
                                   const float **update_values,
                                   size_t num) override;
  void push_sparse_task_consume();
T
tangwei12 已提交
317 318 319 320

 private:
  int32_t start_client_service();

Z
zhaocaibei123 已提交
321 322 323 324
  void push_dense_raw_gradient(std::shared_ptr<DenseAsyncTask> &task,  // NOLINT
                               float *total_send_data,
                               size_t total_send_data_size,
                               DownpourBrpcClosure *closure);
T
tangwei12 已提交
325 326 327 328 329
  float _mae = 0;
  float _mse = 0;
  uint16_t _push_times = 0;
  brpc::Server _server;
  DownpourPsClientService _service;
Z
zhaocaibei123 已提交
330
  bool _server_started = false;
T
tangwei12 已提交
331 332 333 334
  std::atomic_uint grad_num_{0};
};
}  // namespace distributed
}  // namespace paddle