/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <memory>
#ifdef PADDLE_WITH_PSLIB
#include <archive.h>
#include <pslib.h>
#endif
#include <ThreadPool.h>

#include <atomic>
#include <ctime>
#include <map>
#include <mutex>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/heter_util.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h"  // for DISABLE_COPY_AND_ASSIGN
#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/platform/device/gpu/gpu_types.h"
#endif
#include "paddle/fluid/framework/fleet/heter_ps/log_patch.h"

namespace paddle {
namespace framework {
class Scope;
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace framework {

// A wrapper class for pslib.h. This class follows the singleton pattern,
// i.e. it is initialized only once in the current process.
// Example:
//    std::shared_ptr<FleetWrapper> fleet_ptr =
//         FleetWrapper::GetInstance();
//    std::string dist_desc;
//    fleet_ptr->InitServer(dist_desc, 0);
// Interface design principles:
// Pull
//   Sync: PullSparseVarsSync
//   Async: PullSparseVarsAsync (not implemented currently)
// Push
//   Sync: PushSparseVarsSync
//   Async: PushSparseVarsAsync (not implemented currently)
//   Async: PushSparseVarsWithLabelAsync (with special usage)
// Push dense variables to server in async mode
// Param<in>: scope, table_id, var_names
// Param<out>: push_sparse_status

class FleetWrapper {
 public:
  virtual ~FleetWrapper() {}
  FleetWrapper() {
    scale_sparse_gradient_with_batch_size_ = true;
    // the trainer sleeps for a while before a fail exit,
    // giving pslib time to produce a core dump
    sleep_seconds_before_fail_exit_ = 300;
    // pslib client-to-client request timeout in ms
    client2client_request_timeout_ms_ = 500000;
    // pslib client-to-client connect timeout in ms
    client2client_connect_timeout_ms_ = 10000;
    // max retry count for pslib client-to-client requests
    client2client_max_retry_ = 3;
    pull_local_thread_num_ = 25;
  }

  // set client-to-client communication config
  void SetClient2ClientConfig(int request_timeout_ms,
                              int connect_timeout_ms,
                              int max_retry);
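
  // Example (a minimal sketch; the values mirror the constructor defaults
  // and are illustrative, not recommended settings):
  //    auto fleet_ptr = FleetWrapper::GetInstance();
  //    fleet_ptr->SetClient2ClientConfig(/*request_timeout_ms=*/500000,
  //                                      /*connect_timeout_ms=*/10000,
  //                                      /*max_retry=*/3);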

  void SetPullLocalThreadNum(int thread_num) {
    pull_local_thread_num_ = thread_num;
  }

#ifdef PADDLE_WITH_PSLIB
  void HeterPullSparseVars(int workerid,
                           std::shared_ptr<HeterTask> task,
                           const uint64_t table_id,
                           const std::vector<std::string>& var_names,
                           int fea_dim,
                           const std::vector<std::string>& var_emb_names);

  void HeterPushSparseVars(
      std::shared_ptr<HeterTask> task,
      const Scope& scope,
      const uint64_t table_id,
      const std::vector<std::string>& sparse_key_names,
      const std::vector<std::string>& sparse_grad_names,
      const int emb_dim,
      std::vector<::std::future<int32_t>>* push_sparse_status,
      const bool use_cvm,
      const bool dump_slot,
      const bool no_cvm);
#endif

  typedef std::function<void(int, int)> HeterCallBackFunc;
  int RegisterHeterCallback(HeterCallBackFunc handler);

  // Pull sparse variables from server in sync mode
  // Param<in>: scope, table_id, var_names, fea_keys, fea_dim, var_emb_names
  // Param<out>: fea_values
  void PullSparseVarsSync(const Scope& scope,
                          const uint64_t table_id,
                          const std::vector<std::string>& var_names,
                          std::vector<uint64_t>* fea_keys,
                          std::vector<std::vector<float>>* fea_values,
                          int fea_dim,
                          const std::vector<std::string>& var_emb_names);
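  // Example (a minimal sketch; the table id, slot variable names, embedding
  // names, and fea_dim below are illustrative):
  //    std::vector<uint64_t> fea_keys;
  //    std::vector<std::vector<float>> fea_values;
  //    fleet_ptr->PullSparseVarsSync(scope, /*table_id=*/0,
  //                                  {"slot_0", "slot_1"}, &fea_keys,
  //                                  &fea_values, /*fea_dim=*/9,
  //                                  {"emb_0", "emb_1"});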

  // Pull sparse variables from server in async mode
  // Param<in>: scope, table_id, var_names, fea_keys, fea_dim
  // Param<out>: fea_values, plus a std::future for the request status
  std::future<int32_t> PullSparseVarsAsync(
      const Scope& scope,
      const uint64_t table_id,
      const std::vector<std::string>& var_names,
      std::vector<uint64_t>* fea_keys,
      std::vector<std::vector<float>>* fea_values,
      int fea_dim);

  // Pull sparse variables from server in sync mode,
  // writing the results directly into tensors
  void PullSparseToTensorSync(
      const uint64_t table_id,
      int fea_dim,
      uint64_t padding_id,
      platform::Place place,
      std::vector<const phi::DenseTensor*>* inputs,  // NOLINT
      std::vector<phi::DenseTensor*>* outputs);      // NOLINT

  // pull dense variables from server in sync mode
  // Param<in>: scope, table_id, var_names
  // Param<out>: void
  void PullDenseVarsSync(const Scope& scope,
                         const uint64_t table_id,
                         const std::vector<std::string>& var_names);

  // pull dense variables from server in async mode
  // Param<in>: scope, table_id, var_names
  // Param<out>: pull_dense_status
  void PullDenseVarsAsync(
      const Scope& scope,
      const uint64_t table_id,
      const std::vector<std::string>& var_names,
      std::vector<::std::future<int32_t>>* pull_dense_status,
      bool in_cpu);
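  // Example (a minimal sketch; table id and variable name are illustrative,
  // and the caller blocks on the returned futures):
  //    std::vector<::std::future<int32_t>> pull_dense_status;
  //    fleet_ptr->PullDenseVarsAsync(scope, /*table_id=*/1, {"fc_0.w_0"},
  //                                  &pull_dense_status, /*in_cpu=*/true);
  //    for (auto& status : pull_dense_status) status.wait();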

  // push dense parameters (not gradients) to server in sync mode
  void PushDenseParamSync(const Scope& scope,
                          const uint64_t table_id,
                          const std::vector<std::string>& var_names);

// Push dense variables to server in async mode
// Param<in>: scope, table_id, var_names, scale_datanorm, batch_size
// Param<out>: push_sparse_status
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  void PushDenseVarsAsync(
      const Scope& scope,
      const uint64_t table_id,
      const std::vector<std::string>& var_names,
      std::vector<::std::future<int32_t>>* push_sparse_status,
      float scale_datanorm,
      int batch_size,
      const paddle::platform::Place& place,
      gpuStream_t stream,
      gpuEvent_t event);
#endif
#ifdef PADDLE_WITH_XPU
  void PushDenseVarsAsync(
      const Scope& scope,
      const uint64_t table_id,
      const std::vector<std::string>& var_names,
      std::vector<::std::future<int32_t>>* push_sparse_status,
      float scale_datanorm,
      int batch_size,
      const paddle::platform::Place& place);
#endif
  void PushDenseVarsAsync(
      const Scope& scope,
      const uint64_t table_id,
      const std::vector<std::string>& var_names,
      std::vector<::std::future<int32_t>>* push_sparse_status,
      float scale_datanorm,
      int batch_size);
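  // Example (a minimal sketch; gradient variable names and batch size are
  // illustrative; a negative scale_datanorm is assumed to disable datanorm
  // scaling):
  //    std::vector<::std::future<int32_t>> push_status;
  //    fleet_ptr->PushDenseVarsAsync(scope, /*table_id=*/1,
  //                                  {"fc_0.w_0@GRAD"}, &push_status,
  //                                  /*scale_datanorm=*/-1.0f,
  //                                  /*batch_size=*/32);
  //    for (auto& status : push_status) status.wait();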

  // push dense variables to server in sync mode
  void PushDenseVarsSync(Scope* scope,
                         const uint64_t table_id,
                         const std::vector<std::string>& var_names);

  // local feature-value cache tables used by the pull-to-local path
  std::vector<std::unordered_map<uint64_t, std::vector<float>>> local_tables_;
  void PullSparseToLocal(const uint64_t table_id, int fea_value_dim);
  void PullSparseVarsFromLocal(const Scope& scope,
                               const uint64_t table_id,
                               const std::vector<std::string>& var_names,
                               std::vector<uint64_t>* fea_keys,
                               std::vector<std::vector<float>>* fea_values,
                               int fea_value_dim);
  void ClearLocalTable();
  std::vector<std::unordered_map<uint64_t, std::vector<float>>>&
  GetLocalTable() {
    return local_tables_;
  }

  // Push sparse variables with labels to server in async mode;
  // this is specially designed for click/show stats in server.
  // Param<in>: scope, table_id, fea_keys, fea_labels, sparse_key_names,
  //            sparse_grad_names, batch_size, use_cvm, dump_slot
  // Param<out>: push_values, push_sparse_status
  void PushSparseVarsWithLabelAsync(
      const Scope& scope,
      const uint64_t table_id,
      const std::vector<uint64_t>& fea_keys,
      const std::vector<float>& fea_labels,
      const std::vector<std::string>& sparse_key_names,
      const std::vector<std::string>& sparse_grad_names,
      const int emb_dim,
      std::vector<std::vector<float>>* push_values,
      std::vector<::std::future<int32_t>>* push_sparse_status,
      const int batch_size,
      const bool use_cvm,
      const bool dump_slot,
      std::vector<uint64_t>* sparse_push_keys,
      const bool no_cvm,
      const bool scale_sparse_gradient_with_batch_size);

  // Push sparse variables to server in async mode
  void PushSparseFromTensorWithLabelAsync(
      const Scope& scope,
      const uint64_t table_id,
      int fea_dim,
      uint64_t padding_id,
      bool scale_sparse,
      const std::string& accesor,
      const std::string& click_name,
      platform::Place place,
      const std::vector<std::string>& input_names,
      std::vector<const phi::DenseTensor*>* inputs,    // NOLINT
      std::vector<const phi::DenseTensor*>* outputs);  // NOLINT

  // Push sparse variables to server in Async mode
  // Param<In>: scope, table_id, fea_keys, sparse_grad_names
  // Param<Out>: push_values, push_sparse_status
  /*
  void PushSparseVarsAsync(
          const Scope& scope,
          const uint64_t table_id,
          const std::vector<uint64_t>& fea_keys,
          const std::vector<std::string>& sparse_grad_names,
          std::vector<std::vector<float>>* push_values,
          std::vector<::std::future<int32_t>>* push_sparse_status);
  */

  // init server
  void InitServer(const std::string& dist_desc, int index);
  // init worker (trainer)
  void InitWorker(const std::string& dist_desc,
                  const std::vector<uint64_t>& host_sign_list,
                  int node_num,
                  int index);
  // stop server
  void StopServer();
  // finalize worker so that the worker can be stopped
  void FinalizeWorker();
  // run server
  uint64_t RunServer();
  // run server with ip and port
  uint64_t RunServer(const std::string& ip, uint32_t port);
  // gather server ips
  void GatherServers(const std::vector<uint64_t>& host_sign_list, int node_num);
  // gather client ips
  void GatherClients(const std::vector<uint64_t>& host_sign_list);
  // get client info
  std::vector<uint64_t> GetClientsInfo();
  // create client-to-client connections
  void CreateClient2ClientConnection();
  // flush all push requests
  void ClientFlush();
  // load from paddle model
  void LoadFromPaddleModel(Scope& scope,             // NOLINT
                           const uint64_t table_id,  // NOLINT
                           std::vector<std::string> var_list,
                           std::string model_path,
                           std::string model_proto_file,
                           std::vector<std::string> table_var_list,
                           bool load_combine);
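  // Example (a minimal sketch of the server-side lifecycle; dist_desc is the
  // serialized distributed config and is elided here):
  //    fleet_ptr->InitServer(dist_desc, /*index=*/0);
  //    uint64_t sign = fleet_ptr->RunServer();
  //    // ... workers train against the server ...
  //    fleet_ptr->StopServer();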

  void PrintTableStat(uint64_t table_id, uint32_t pass_id, size_t threshold);
  void SetFileNumOneShard(const uint64_t table_id, int file_num);
  // mode = 0, load all features
  // mode = 1, load delta features, i.e. load the diff
  void LoadModel(const std::string& path, const int mode);
  // mode = 0, load all features
  // mode = 1, load delta features, i.e. load the diff
  void LoadModelOneTable(const uint64_t table_id,
                         const std::string& path,
                         const int mode);
  // mode = 0, save all features
  // mode = 1, save delta features, i.e. save the diff
  void SaveModel(const std::string& path, const int mode);
  void SaveMultiTableOnePath(const std::vector<int>& table_ids,
                             const std::string& path,
                             const int mode);
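  // Example (a minimal sketch; the path is illustrative):
  //    fleet_ptr->SaveModel("/path/to/model", /*mode=*/0);  // save all
  //    fleet_ptr->LoadModel("/path/to/model", /*mode=*/0);  // load all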
  // mode = 0, save all features
  // mode = 1, save delta features, i.e. save the diff
  void SaveModelOneTable(const uint64_t table_id,
                         const std::string& path,
                         const int mode);
  // save model with prefix
  void SaveModelOneTablePrefix(const uint64_t table_id,
                               const std::string& path,
                               const int mode,
                               const std::string& prefix);
  // get save cache threshold
  double GetCacheThreshold(int table_id);
  // shuffle cache model between servers
  void CacheShuffle(int table_id,
                    const std::string& path,
                    const int mode,
                    const double cache_threshold);
  // save cache model; a cache model can speed up online prediction
  int32_t SaveCache(int table_id, const std::string& path, const int mode);
  // save sparse table filtered by a user-defined whitelist
  int32_t SaveWithWhitelist(int table_id,
                            const std::string& path,
                            const int mode,
                            const std::string& whitelist_path);
  void LoadWithWhitelist(const uint64_t table_id,
                         const std::string& path,
                         const int mode);
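  // Example (a minimal sketch of the cache-save flow implied by the calls
  // above; table id and path are illustrative):
  //    double threshold = fleet_ptr->GetCacheThreshold(/*table_id=*/0);
  //    fleet_ptr->CacheShuffle(/*table_id=*/0, "/path/to/model", /*mode=*/0,
  //                            threshold);
  //    fleet_ptr->SaveCache(/*table_id=*/0, "/path/to/model", /*mode=*/0);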
  // copy feasign key/value from src_table_id to dest_table_id
  int32_t CopyTable(const uint64_t src_table_id, const uint64_t dest_table_id);
  // copy key/value for the given feasigns from src_table_id to dest_table_id
  int32_t CopyTableByFeasign(const uint64_t src_table_id,
                             const uint64_t dest_table_id,
                             const std::vector<uint64_t>& feasign_list);
  // clear all models and release their memory
  void ClearModel();
  // clear one table
  void ClearOneTable(const uint64_t table_id);
  // shrink sparse table
  void ShrinkSparseTable(int table_id);
  // shrink dense table
  void ShrinkDenseTable(int table_id,
                        Scope* scope,
                        std::vector<std::string> var_list,
                        float decay,
                        int emb_dim);

  typedef std::function<int32_t(int, int, const std::string&)> MsgHandlerFunc;
  // register a client-to-client message handler
  int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler);
  // send a client-to-client message
  std::future<int32_t> SendClientToClientMsg(int msg_type,
                                             int to_client_id,
                                             const std::string& msg);
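  // Example (a minimal sketch; msg_type 0 and the no-op handler body are
  // illustrative):
  //    fleet_ptr->RegisterClientToClientMsgHandler(
  //        /*msg_type=*/0,
  //        [](int msg_type, int from_client, const std::string& msg)
  //            -> int32_t { return 0; });
  //    auto status = fleet_ptr->SendClientToClientMsg(/*msg_type=*/0,
  //                                                   /*to_client_id=*/1,
  //                                                   "hello");
  //    status.wait();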
  // confirm all the updated params in the current pass
  void Confirm();
  // revert all the updated params in the current pass
  void Revert();

  std::string GetDistDesc() const {
    CHECK(is_initialized_) << "FleetWrapper should be initialized first!";
    return dist_desc_;
  }

  // FleetWrapper singleton
  static std::shared_ptr<FleetWrapper> GetInstance() {
    {
      std::lock_guard<std::mutex> lk(ins_mutex);
      if (nullptr == s_instance_) {
        s_instance_.reset(new paddle::framework::FleetWrapper());
      }
    }
    return s_instance_;
  }
  // this performs better than rand_r, especially on large data
  std::default_random_engine& LocalRandomEngine();

  void SetDate(const uint64_t table_id, const std::string& date);

#ifdef PADDLE_WITH_PSLIB
  static std::shared_ptr<paddle::distributed::PSlib> pslib_ptr_;
#endif

 private:
  static std::shared_ptr<FleetWrapper> s_instance_;
  std::string dist_desc_;
  static std::mutex ins_mutex;
#ifdef PADDLE_WITH_PSLIB
  std::map<uint64_t, std::vector<paddle::ps::Region>> _regions;
#endif

  size_t GetAbsoluteSum(size_t start,
                        size_t end,
                        size_t level,
                        const framework::LoD& lod);

 protected:
  static bool is_initialized_;
  bool scale_sparse_gradient_with_batch_size_;
  int32_t sleep_seconds_before_fail_exit_;
  int client2client_request_timeout_ms_;
  int client2client_connect_timeout_ms_;
  int client2client_max_retry_;
  std::unique_ptr<::ThreadPool> local_pull_pool_{nullptr};
  int pull_local_thread_num_;
  std::unique_ptr<::ThreadPool> pull_to_local_pool_{nullptr};
  int local_table_shard_num_;
  DISABLE_COPY_AND_ASSIGN(FleetWrapper);
};

}  // end namespace framework
}  // end namespace paddle