// fleet_wrapper.h
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <memory>
#ifdef PADDLE_WITH_PSLIB
#include <archive.h>
#include <pslib.h>
#endif
#include <ThreadPool.h>

#include <atomic>
#include <ctime>
#include <functional>
#include <future>
#include <map>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/heter_util.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h"  // for DISABLE_COPY_AND_ASSIGN
#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/platform/device/gpu/gpu_types.h"
#endif

// Forward declaration of paddle::framework::Scope, which the FleetWrapper
// interface below takes by reference or pointer. The full definition comes
// from "paddle/fluid/framework/scope.h".
namespace paddle {
namespace framework {
class Scope;
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace framework {

// A wrapper class for pslib.h, this class follows Singleton pattern
// i.e. only initialized once in the current process
// Example:
//    std::shared_ptr<FleetWrapper> fleet_ptr =
//         FleetWrapper::GetInstance();
//    string dist_desc;
//    fleet_ptr->InitServer(dist_desc, 0);
// interface design principles:
// Pull
//   Sync: PullSparseVarsSync
//   Async: PullSparseVarsAsync(not implemented currently)
// Push
//   Sync: PushSparseVarsSync
64 65
//   Async: PushSparseVarsAsync(not implemented currently)
//   Async: PushSparseVarsWithLabelAsync(with special usage)
66 67 68 69 70 71 72
// Push dense variables to server in Async mode
// Param<in>: scope, table_id, var_names
// Param<out>: push_sparse_status

class FleetWrapper {
 public:
  virtual ~FleetWrapper() {}
73 74 75 76
  FleetWrapper() {
    scale_sparse_gradient_with_batch_size_ = true;
    // trainer sleep some time for pslib core dump
    sleep_seconds_before_fail_exit_ = 300;
77 78 79 80 81 82
    // pslib request server timeout ms
    client2client_request_timeout_ms_ = 500000;
    // pslib connect server timeout_ms
    client2client_connect_timeout_ms_ = 10000;
    // pslib request max retry
    client2client_max_retry_ = 3;
83
    pull_local_thread_num_ = 25;
84
  }
85

X
xujiaqi01 已提交
86
  // set client to client communication config
87 88
  void SetClient2ClientConfig(int request_timeout_ms,
                              int connect_timeout_ms,
89 90
                              int max_retry);

91 92 93
  void SetPullLocalThreadNum(int thread_num) {
    pull_local_thread_num_ = thread_num;
  }
94

T
Thunderbrook 已提交
95
#ifdef PADDLE_WITH_PSLIB
96 97
  void HeterPullSparseVars(int workerid,
                           std::shared_ptr<HeterTask> task,
T
Thunderbrook 已提交
98 99 100 101 102 103
                           const uint64_t table_id,
                           const std::vector<std::string>& var_names,
                           int fea_dim,
                           const std::vector<std::string>& var_emb_names);

  void HeterPushSparseVars(
104 105 106 107 108 109
      std::shared_ptr<HeterTask> task,
      const Scope& scope,
      const uint64_t table_id,
      const std::vector<std::string>& sparse_key_names,
      const std::vector<std::string>& sparse_grad_names,
      const int emb_dim,
T
Thunderbrook 已提交
110
      std::vector<::std::future<int32_t>>* push_sparse_status,
111 112 113
      const bool use_cvm,
      const bool dump_slot,
      const bool no_cvm);
T
Thunderbrook 已提交
114 115 116 117 118
#endif

  typedef std::function<void(int, int)> HeterCallBackFunc;
  int RegisterHeterCallback(HeterCallBackFunc handler);

X
xujiaqi01 已提交
119
  // Pull sparse variables from server in sync mode
120
  // Param<in>: scope, table_id, var_names, fea_keys, fea_dim, var_emb_names
121
  // Param<out>: fea_values
122 123
  void PullSparseVarsSync(const Scope& scope,
                          const uint64_t table_id,
124 125 126
                          const std::vector<std::string>& var_names,
                          std::vector<uint64_t>* fea_keys,
                          std::vector<std::vector<float>>* fea_values,
127 128
                          int fea_dim,
                          const std::vector<std::string>& var_emb_names);
129 130 131 132

  // Pull sparse variables from server in async mode
  // Param<in>: scope, table_id, var_names, fea_keys, fea_dim
  // Param<out>: fea_values std::future
133
  std::future<int32_t> PullSparseVarsAsync(
134 135
      const Scope& scope,
      const uint64_t table_id,
136 137
      const std::vector<std::string>& var_names,
      std::vector<uint64_t>* fea_keys,
138 139
      std::vector<std::vector<float>>* fea_values,
      int fea_dim);
140 141 142

  // Pull sparse variables from server in sync mode
  // pull immediately to tensors
143 144 145 146
  void PullSparseToTensorSync(const uint64_t table_id,
                              int fea_dim,
                              uint64_t padding_id,
                              platform::Place place,
147 148 149
                              std::vector<const LoDTensor*>* inputs,  // NOLINT
                              std::vector<LoDTensor*>* outputs);      // NOLINT

X
xujiaqi01 已提交
150
  // pull dense variables from server in sync mod
151 152
  // Param<in>: scope, table_id, var_names
  // Param<out>: void
153 154
  void PullDenseVarsSync(const Scope& scope,
                         const uint64_t table_id,
155 156
                         const std::vector<std::string>& var_names);

X
xujiaqi01 已提交
157 158 159
  // pull dense variables from server in async mod
  // Param<in>: scope, table_id, var_names
  // Param<out>: pull_dense_status
160
  void PullDenseVarsAsync(
161 162
      const Scope& scope,
      const uint64_t table_id,
163
      const std::vector<std::string>& var_names,
164 165
      std::vector<::std::future<int32_t>>* pull_dense_status,
      bool in_cpu);
166

X
xujiaqi01 已提交
167
  // push dense parameters(not gradients) to server in sync mode
168 169
  void PushDenseParamSync(const Scope& scope,
                          const uint64_t table_id,
D
dongdaxiang 已提交
170
                          const std::vector<std::string>& var_names);
171

T
Thunderbrook 已提交
172 173 174
// Push dense variables to server in async mode
// Param<in>: scope, table_id, var_names, scale_datanorm, batch_size
// Param<out>: push_sparse_status
175
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
T
Thunderbrook 已提交
176
  void PushDenseVarsAsync(
177 178
      const Scope& scope,
      const uint64_t table_id,
T
Thunderbrook 已提交
179 180
      const std::vector<std::string>& var_names,
      std::vector<::std::future<int32_t>>* push_sparse_status,
181 182 183 184
      float scale_datanorm,
      int batch_size,
      const paddle::platform::Place& place,
      gpuStream_t stream,
185
      gpuEvent_t event);
T
Thunderbrook 已提交
186 187 188
#endif
#ifdef PADDLE_WITH_XPU
  void PushDenseVarsAsync(
189 190
      const Scope& scope,
      const uint64_t table_id,
T
Thunderbrook 已提交
191 192
      const std::vector<std::string>& var_names,
      std::vector<::std::future<int32_t>>* push_sparse_status,
193 194
      float scale_datanorm,
      int batch_size,
T
Thunderbrook 已提交
195
      const paddle::platform::Place& place);
T
Thunderbrook 已提交
196
#endif
197
  void PushDenseVarsAsync(
198 199
      const Scope& scope,
      const uint64_t table_id,
200
      const std::vector<std::string>& var_names,
201
      std::vector<::std::future<int32_t>>* push_sparse_status,
202 203
      float scale_datanorm,
      int batch_size);
204

X
xujiaqi01 已提交
205
  // push dense variables to server in sync mode
206 207
  void PushDenseVarsSync(Scope* scope,
                         const uint64_t table_id,
D
dongdaxiang 已提交
208 209
                         const std::vector<std::string>& var_names);

X
xujiaqi01 已提交
210
  // Push sparse variables with labels to server in async mode
211 212
  std::vector<std::unordered_map<uint64_t, std::vector<float>>> local_tables_;
  void PullSparseToLocal(const uint64_t table_id, int fea_value_dim);
213 214
  void PullSparseVarsFromLocal(const Scope& scope,
                               const uint64_t table_id,
215 216 217 218 219 220 221 222 223
                               const std::vector<std::string>& var_names,
                               std::vector<uint64_t>* fea_keys,
                               std::vector<std::vector<float>>* fea_values,
                               int fea_value_dim);
  void ClearLocalTable();
  std::vector<std::unordered_map<uint64_t, std::vector<float>>>&
  GetLocalTable() {
    return local_tables_;
  }
224

225
  // This is specially designed for click/show stats in server
X
xujiaqi01 已提交
226 227
  // Param<in>: scope, table_id, fea_keys, fea_labels, sparse_key_names,
  //            sparse_grad_names, batch_size, use_cvm, dump_slot
228 229
  // Param<out>: push_values, push_sparse_status
  void PushSparseVarsWithLabelAsync(
230 231
      const Scope& scope,
      const uint64_t table_id,
232 233 234
      const std::vector<uint64_t>& fea_keys,
      const std::vector<float>& fea_labels,
      const std::vector<std::string>& sparse_key_names,
235 236
      const std::vector<std::string>& sparse_grad_names,
      const int emb_dim,
237
      std::vector<std::vector<float>>* push_values,
238
      std::vector<::std::future<int32_t>>* push_sparse_status,
239 240 241 242 243
      const int batch_size,
      const bool use_cvm,
      const bool dump_slot,
      std::vector<uint64_t>* sparse_push_keys,
      const bool no_cvm,
244
      const bool scale_sparse_gradient_with_batch_size);
245

246 247
  // Push sparse variables to server in async mode
  void PushSparseFromTensorWithLabelAsync(
248 249 250 251 252 253 254 255
      const Scope& scope,
      const uint64_t table_id,
      int fea_dim,
      uint64_t padding_id,
      bool scale_sparse,
      const std::string& accesor,
      const std::string& click_name,
      platform::Place place,
256 257 258 259
      const std::vector<std::string>& input_names,
      std::vector<const LoDTensor*>* inputs,    // NOLINT
      std::vector<const LoDTensor*>* outputs);  // NOLINT

260 261 262 263 264 265 266 267 268 269 270 271 272
  // Push sparse variables to server in Async mode
  // Param<In>: scope, table_id, fea_keys, sparse_grad_names
  // Param<Out>: push_values, push_sparse_status
  /*
  void PushSparseVarsAsync(
          const Scope& scope,
          const uint64_t table_id,
          const std::vector<uint64_t>& fea_keys,
          const std::vector<std::string>& sparse_grad_names,
          std::vector<std::vector<float>>* push_values,
          std::vector<::std::future<int32_t>>* push_sparse_status);
  */

X
xujiaqi01 已提交
273
  // init server
274
  void InitServer(const std::string& dist_desc, int index);
X
xujiaqi01 已提交
275
  // init trainer
276
  void InitWorker(const std::string& dist_desc,
277 278
                  const std::vector<uint64_t>& host_sign_list,
                  int node_num,
279
                  int index);
X
xujiaqi01 已提交
280
  // stop server
281
  void StopServer();
282 283
  // finalize worker to make worker can be stop
  void FinalizeWorker();
X
xujiaqi01 已提交
284
  // run server
285
  uint64_t RunServer();
286 287
  // run server with ip port
  uint64_t RunServer(const std::string& ip, uint32_t port);
X
xujiaqi01 已提交
288
  // gather server ip
289
  void GatherServers(const std::vector<uint64_t>& host_sign_list, int node_num);
X
xjqbest 已提交
290
  // gather client ip
X
xjqbest 已提交
291
  void GatherClients(const std::vector<uint64_t>& host_sign_list);
X
xjqbest 已提交
292
  // get client info
X
xjqbest 已提交
293
  std::vector<uint64_t> GetClientsInfo();
X
xjqbest 已提交
294
  // create client to client connection
X
xjqbest 已提交
295
  void CreateClient2ClientConnection();
296 297
  // flush all push requests
  void ClientFlush();
298
  // load from paddle model
299 300
  void LoadFromPaddleModel(Scope& scope,
                           const uint64_t table_id,  // NOLINT
301
                           std::vector<std::string> var_list,
302 303
                           std::string model_path,
                           std::string model_proto_file,
304
                           std::vector<std::string> table_var_list,
305
                           bool load_combine);
306 307

  void PrintTableStat(const uint64_t table_id);
308
  void SetFileNumOneShard(const uint64_t table_id, int file_num);
309
  // mode = 0, load all feature
X
xujiaqi01 已提交
310
  // mode = 1, load delta feature, which means load diff
311
  void LoadModel(const std::string& path, const int mode);
312
  // mode = 0, load all feature
X
xujiaqi01 已提交
313
  // mode = 1, load delta feature, which means load diff
314 315
  void LoadModelOneTable(const uint64_t table_id,
                         const std::string& path,
316
                         const int mode);
317 318 319
  // mode = 0, save all feature
  // mode = 1, save delta feature, which means save diff
  void SaveModel(const std::string& path, const int mode);
320
  void SaveMultiTableOnePath(const std::vector<int>& table_ids,
321 322
                             const std::string& path,
                             const int mode);
X
xujiaqi01 已提交
323 324
  // mode = 0, save all feature
  // mode = 1, save delta feature, which means save diff
325 326
  void SaveModelOneTable(const uint64_t table_id,
                         const std::string& path,
X
xujiaqi01 已提交
327 328
                         const int mode);
  // save model with prefix
329 330 331 332
  void SaveModelOneTablePrefix(const uint64_t table_id,
                               const std::string& path,
                               const int mode,
                               const std::string& prefix);
X
xujiaqi01 已提交
333
  // get save cache threshold
334
  double GetCacheThreshold(int table_id);
X
xujiaqi01 已提交
335
  // shuffle cache model between servers
336 337 338
  void CacheShuffle(int table_id,
                    const std::string& path,
                    const int mode,
339
                    const double cache_threshold);
X
xujiaqi01 已提交
340 341
  // save cache model
  // cache model can speed up online predict
342
  int32_t SaveCache(int table_id, const std::string& path, const int mode);
343
  // save sparse table filtered by user-defined whitelist
344 345 346 347 348 349
  int32_t SaveWithWhitelist(int table_id,
                            const std::string& path,
                            const int mode,
                            const std::string& whitelist_path);
  void LoadWithWhitelist(const uint64_t table_id,
                         const std::string& path,
350
                         const int mode);
X
xujiaqi01 已提交
351 352 353 354 355 356 357
  // copy feasign key/value from src_table_id to dest_table_id
  int32_t CopyTable(const uint64_t src_table_id, const uint64_t dest_table_id);
  // copy feasign key/value from src_table_id to dest_table_id
  int32_t CopyTableByFeasign(const uint64_t src_table_id,
                             const uint64_t dest_table_id,
                             const std::vector<uint64_t>& feasign_list);
  // clear all models, release their memory
358
  void ClearModel();
X
xujiaqi01 已提交
359 360
  // clear one table
  void ClearOneTable(const uint64_t table_id);
X
xujiaqi01 已提交
361
  // shrink sparse table
362
  void ShrinkSparseTable(int table_id);
X
xujiaqi01 已提交
363
  // shrink dense table
364 365 366 367
  void ShrinkDenseTable(int table_id,
                        Scope* scope,
                        std::vector<std::string> var_list,
                        float decay,
368
                        int emb_dim);
369

D
dongdaxiang 已提交
370
  typedef std::function<int32_t(int, int, const std::string&)> MsgHandlerFunc;
X
xujiaqi01 已提交
371
  // register client to client communication
372
  int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler);
X
xjqbest 已提交
373
  // send client to client message
374 375
  std::future<int32_t> SendClientToClientMsg(int msg_type,
                                             int to_client_id,
D
dongdaxiang 已提交
376
                                             const std::string& msg);
377 378 379 380
  // confirm all the updated params in the current pass
  void Confirm();
  // revert all the updated params in the current pass
  void Revert();
X
xujiaqi01 已提交
381
  // FleetWrapper singleton
382 383 384 385 386 387
  static std::shared_ptr<FleetWrapper> GetInstance() {
    if (NULL == s_instance_) {
      s_instance_.reset(new paddle::framework::FleetWrapper());
    }
    return s_instance_;
  }
388 389 390
  // this performs better than rand_r, especially large data
  std::default_random_engine& LocalRandomEngine();

391 392
  void SetDate(const uint64_t table_id, const std::string& date);

393 394 395 396
#ifdef PADDLE_WITH_PSLIB
  static std::shared_ptr<paddle::distributed::PSlib> pslib_ptr_;
#endif

397 398
 private:
  static std::shared_ptr<FleetWrapper> s_instance_;
X
xjqbest 已提交
399
#ifdef PADDLE_WITH_PSLIB
X
xujiaqi01 已提交
400
  std::map<uint64_t, std::vector<paddle::ps::Region>> _regions;
X
xjqbest 已提交
401
#endif
402

403 404 405
  size_t GetAbsoluteSum(size_t start,
                        size_t end,
                        size_t level,
406 407
                        const framework::LoD& lod);

408
 protected:
409
  static bool is_initialized_;
410
  bool scale_sparse_gradient_with_batch_size_;
411
  int32_t sleep_seconds_before_fail_exit_;
412 413 414
  int client2client_request_timeout_ms_;
  int client2client_connect_timeout_ms_;
  int client2client_max_retry_;
415 416 417 418
  std::unique_ptr<::ThreadPool> local_pull_pool_{nullptr};
  int pull_local_thread_num_;
  std::unique_ptr<::ThreadPool> pull_to_local_pool_{nullptr};
  int local_table_shard_num_;
419 420 421 422 423
  DISABLE_COPY_AND_ASSIGN(FleetWrapper);
};

}  // end namespace framework
}  // end namespace paddle