/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <memory>
#ifdef PADDLE_WITH_PSLIB
#include <archive.h>
#include <pslib.h>
#endif
#include <ThreadPool.h>

#include <atomic>
#include <ctime>
#include <map>
#include <mutex>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/heter_util.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h"  // for DISABLE_COPY_AND_ASSIGN
#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/platform/device/gpu/gpu_types.h"
#endif
#include "paddle/fluid/framework/fleet/heter_ps/log_patch.h"

namespace paddle {
namespace framework {
class Scope;
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace framework {

// A wrapper class for pslib.h; this class follows the singleton pattern,
// i.e. it is initialized only once in the current process.
// Example:
//    std::shared_ptr<FleetWrapper> fleet_ptr =
//        FleetWrapper::GetInstance();
//    std::string dist_desc;
//    fleet_ptr->InitServer(dist_desc, 0);
// Interface design principles:
// Pull
//   Sync:  PullSparseVarsSync
//   Async: PullSparseVarsAsync (not implemented currently)
// Push
//   Sync:  PushSparseVarsSync
//   Async: PushSparseVarsAsync (not implemented currently)
//   Async: PushSparseVarsWithLabelAsync (with special usage)
// Push dense variables to server in async mode
// Param<in>: scope, table_id, var_names
// Param<out>: push_sparse_status

class FleetWrapper {
 public:
  virtual ~FleetWrapper() {}
  FleetWrapper() {
    scale_sparse_gradient_with_batch_size_ = true;
    // seconds the trainer sleeps before exiting on a pslib failure (core dump)
    sleep_seconds_before_fail_exit_ = 300;
    // pslib request-to-server timeout in ms
    client2client_request_timeout_ms_ = 500000;
    // pslib connect-to-server timeout in ms
    client2client_connect_timeout_ms_ = 10000;
    // max retries for pslib requests
    client2client_max_retry_ = 3;
    pull_local_thread_num_ = 25;
  }

  // set client-to-client communication config
  void SetClient2ClientConfig(int request_timeout_ms,
                              int connect_timeout_ms,
                              int max_retry);

  void SetPullLocalThreadNum(int thread_num) {
    pull_local_thread_num_ = thread_num;
  }
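
  // Usage sketch (illustrative; the values mirror the constructor defaults):
  //   auto fleet = FleetWrapper::GetInstance();
  //   fleet->SetClient2ClientConfig(/*request_timeout_ms=*/500000,
  //                                 /*connect_timeout_ms=*/10000,
  //                                 /*max_retry=*/3);
  //   fleet->SetPullLocalThreadNum(25);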

#ifdef PADDLE_WITH_PSLIB
  void HeterPullSparseVars(int workerid,
                           std::shared_ptr<HeterTask> task,
                           const uint64_t table_id,
                           const std::vector<std::string>& var_names,
                           int fea_dim,
                           const std::vector<std::string>& var_emb_names);

  void HeterPushSparseVars(
      std::shared_ptr<HeterTask> task,
      const Scope& scope,
      const uint64_t table_id,
      const std::vector<std::string>& sparse_key_names,
      const std::vector<std::string>& sparse_grad_names,
      const int emb_dim,
      std::vector<::std::future<int32_t>>* push_sparse_status,
      const bool use_cvm,
      const bool dump_slot,
      const bool no_cvm);
#endif

  typedef std::function<void(int, int)> HeterCallBackFunc;
  int RegisterHeterCallback(HeterCallBackFunc handler);

  // Pull sparse variables from server in sync mode
  // Param<in>: scope, table_id, var_names, fea_keys, fea_dim, var_emb_names
  // Param<out>: fea_values
  void PullSparseVarsSync(const Scope& scope,
                          const uint64_t table_id,
                          const std::vector<std::string>& var_names,
                          std::vector<uint64_t>* fea_keys,
                          std::vector<std::vector<float>>* fea_values,
                          int fea_dim,
                          const std::vector<std::string>& var_emb_names);
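  // Usage sketch (illustrative; scope/var_names/var_emb_names are assumed to
  // exist in the caller, and fea_dim 11 is an example value):
  //   std::vector<uint64_t> fea_keys;
  //   std::vector<std::vector<float>> fea_values;
  //   fleet->PullSparseVarsSync(scope, /*table_id=*/0, var_names, &fea_keys,
  //                             &fea_values, /*fea_dim=*/11, var_emb_names);
  //   // fea_values now holds one fea_dim-sized vector per key in fea_keys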

  // Pull sparse variables from server in async mode
  // Param<in>: scope, table_id, var_names, fea_keys, fea_dim
  // Param<out>: fea_values (completion signaled via the returned std::future)
  std::future<int32_t> PullSparseVarsAsync(
      const Scope& scope,
      const uint64_t table_id,
      const std::vector<std::string>& var_names,
      std::vector<uint64_t>* fea_keys,
      std::vector<std::vector<float>>* fea_values,
      int fea_dim);

  // Pull sparse variables from server in sync mode,
  // pulling immediately into tensors
  void PullSparseToTensorSync(const uint64_t table_id,
                              int fea_dim,
                              uint64_t padding_id,
                              platform::Place place,
                              std::vector<const LoDTensor*>* inputs,  // NOLINT
                              std::vector<LoDTensor*>* outputs);      // NOLINT

  // pull dense variables from server in sync mode
  // Param<in>: scope, table_id, var_names
  // Param<out>: void
  void PullDenseVarsSync(const Scope& scope,
                         const uint64_t table_id,
                         const std::vector<std::string>& var_names);

  // pull dense variables from server in async mode
  // Param<in>: scope, table_id, var_names
  // Param<out>: pull_dense_status
  void PullDenseVarsAsync(
      const Scope& scope,
      const uint64_t table_id,
      const std::vector<std::string>& var_names,
      std::vector<::std::future<int32_t>>* pull_dense_status,
      bool in_cpu);
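  // Usage sketch (illustrative; blocks until every dense pull completes):
  //   std::vector<::std::future<int32_t>> pull_dense_status;
  //   fleet->PullDenseVarsAsync(scope, table_id, var_names,
  //                             &pull_dense_status, /*in_cpu=*/true);
  //   for (auto& t : pull_dense_status) {
  //     t.wait();
  //   }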

  // push dense parameters (not gradients) to server in sync mode
  void PushDenseParamSync(const Scope& scope,
                          const uint64_t table_id,
                          const std::vector<std::string>& var_names);

// Push dense variables to server in async mode
// Param<in>: scope, table_id, var_names, scale_datanorm, batch_size
// Param<out>: push_sparse_status
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  void PushDenseVarsAsync(
      const Scope& scope,
      const uint64_t table_id,
      const std::vector<std::string>& var_names,
      std::vector<::std::future<int32_t>>* push_sparse_status,
      float scale_datanorm,
      int batch_size,
      const paddle::platform::Place& place,
      gpuStream_t stream,
      gpuEvent_t event);
#endif
#ifdef PADDLE_WITH_XPU
  void PushDenseVarsAsync(
      const Scope& scope,
      const uint64_t table_id,
      const std::vector<std::string>& var_names,
      std::vector<::std::future<int32_t>>* push_sparse_status,
      float scale_datanorm,
      int batch_size,
      const paddle::platform::Place& place);
#endif
  void PushDenseVarsAsync(
      const Scope& scope,
      const uint64_t table_id,
      const std::vector<std::string>& var_names,
      std::vector<::std::future<int32_t>>* push_sparse_status,
      float scale_datanorm,
      int batch_size);
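  // Usage sketch for the CPU overload (illustrative; scale_datanorm = -1 is
  // shown only as an example value):
  //   std::vector<::std::future<int32_t>> push_status;
  //   fleet->PushDenseVarsAsync(scope, table_id, var_names, &push_status,
  //                             /*scale_datanorm=*/-1.0f, /*batch_size=*/32);
  //   for (auto& t : push_status) t.wait();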

  // push dense variables to server in sync mode
  void PushDenseVarsSync(Scope* scope,
                         const uint64_t table_id,
                         const std::vector<std::string>& var_names);

  // trainer-local cache of sparse tables pulled from the server
  std::vector<std::unordered_map<uint64_t, std::vector<float>>> local_tables_;
  void PullSparseToLocal(const uint64_t table_id, int fea_value_dim);
  void PullSparseVarsFromLocal(const Scope& scope,
                               const uint64_t table_id,
                               const std::vector<std::string>& var_names,
                               std::vector<uint64_t>* fea_keys,
                               std::vector<std::vector<float>>* fea_values,
                               int fea_value_dim);
  void ClearLocalTable();
  std::vector<std::unordered_map<uint64_t, std::vector<float>>>&
  GetLocalTable() {
    return local_tables_;
  }

  // This is specially designed for click/show stats in server
  // Param<in>: scope, table_id, fea_keys, fea_labels, sparse_key_names,
  //            sparse_grad_names, batch_size, use_cvm, dump_slot
  // Param<out>: push_values, push_sparse_status
  void PushSparseVarsWithLabelAsync(
      const Scope& scope,
      const uint64_t table_id,
      const std::vector<uint64_t>& fea_keys,
      const std::vector<float>& fea_labels,
      const std::vector<std::string>& sparse_key_names,
      const std::vector<std::string>& sparse_grad_names,
      const int emb_dim,
      std::vector<std::vector<float>>* push_values,
      std::vector<::std::future<int32_t>>* push_sparse_status,
      const int batch_size,
      const bool use_cvm,
      const bool dump_slot,
      std::vector<uint64_t>* sparse_push_keys,
      const bool no_cvm,
      const bool scale_sparse_gradient_with_batch_size);
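  // Call sketch (illustrative only; all locals are hypothetical and the flag
  // values are examples, not recommendations):
  //   std::vector<std::vector<float>> push_values;
  //   std::vector<::std::future<int32_t>> push_status;
  //   std::vector<uint64_t> sparse_push_keys;
  //   fleet->PushSparseVarsWithLabelAsync(
  //       scope, table_id, fea_keys, fea_labels, sparse_key_names,
  //       sparse_grad_names, emb_dim, &push_values, &push_status, batch_size,
  //       /*use_cvm=*/true, /*dump_slot=*/false, &sparse_push_keys,
  //       /*no_cvm=*/false, /*scale_sparse_gradient_with_batch_size=*/true);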

  // Push sparse variables to server in async mode
  void PushSparseFromTensorWithLabelAsync(
      const Scope& scope,
      const uint64_t table_id,
      int fea_dim,
      uint64_t padding_id,
      bool scale_sparse,
      const std::string& accesor,
      const std::string& click_name,
      platform::Place place,
      const std::vector<std::string>& input_names,
      std::vector<const LoDTensor*>* inputs,    // NOLINT
      std::vector<const LoDTensor*>* outputs);  // NOLINT

  // Push sparse variables to server in async mode
  // Param<in>: scope, table_id, fea_keys, sparse_grad_names
  // Param<out>: push_values, push_sparse_status
  /*
  void PushSparseVarsAsync(
          const Scope& scope,
          const uint64_t table_id,
          const std::vector<uint64_t>& fea_keys,
          const std::vector<std::string>& sparse_grad_names,
          std::vector<std::vector<float>>* push_values,
          std::vector<::std::future<int32_t>>* push_sparse_status);
  */

  // init server
  void InitServer(const std::string& dist_desc, int index);
  // init worker (trainer)
  void InitWorker(const std::string& dist_desc,
                  const std::vector<uint64_t>& host_sign_list,
                  int node_num,
                  int index);
  // stop server
  void StopServer();
  // finalize worker so that the worker can be stopped
  void FinalizeWorker();
  // run server
  uint64_t RunServer();
  // run server with ip and port
  uint64_t RunServer(const std::string& ip, uint32_t port);
  // gather server ips
  void GatherServers(const std::vector<uint64_t>& host_sign_list, int node_num);
  // gather client ips
  void GatherClients(const std::vector<uint64_t>& host_sign_list);
  // get client info
  std::vector<uint64_t> GetClientsInfo();
  // create client-to-client connections
  void CreateClient2ClientConnection();
  // flush all push requests
  void ClientFlush();
  // load from paddle model
  void LoadFromPaddleModel(Scope& scope,
                           const uint64_t table_id,  // NOLINT
                           std::vector<std::string> var_list,
                           std::string model_path,
                           std::string model_proto_file,
                           std::vector<std::string> table_var_list,
                           bool load_combine);

  void PrintTableStat(const uint64_t table_id);
  void SetFileNumOneShard(const uint64_t table_id, int file_num);
  // mode = 0, load all features
  // mode = 1, load delta features, i.e. load the diff
  void LoadModel(const std::string& path, const int mode);
  // mode = 0, load all features
  // mode = 1, load delta features, i.e. load the diff
  void LoadModelOneTable(const uint64_t table_id,
                         const std::string& path,
                         const int mode);
  // mode = 0, save all features
  // mode = 1, save delta features, i.e. save the diff
  void SaveModel(const std::string& path, const int mode);
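  // Usage sketch (illustrative path; mode semantics per the comments above):
  //   fleet->LoadModel("hdfs:/some/model/path", /*mode=*/0);  // load all
  //   // ... train ...
  //   fleet->SaveModel("hdfs:/some/model/path", /*mode=*/1);  // save delta
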
  void SaveMultiTableOnePath(const std::vector<int>& table_ids,
                             const std::string& path,
                             const int mode);
  // mode = 0, save all features
  // mode = 1, save delta features, i.e. save the diff
  void SaveModelOneTable(const uint64_t table_id,
                         const std::string& path,
                         const int mode);
  // save model with a prefix
  void SaveModelOneTablePrefix(const uint64_t table_id,
                               const std::string& path,
                               const int mode,
                               const std::string& prefix);
  // get the save-cache threshold
  double GetCacheThreshold(int table_id);
  // shuffle cache model between servers
  void CacheShuffle(int table_id,
                    const std::string& path,
                    const int mode,
                    const double cache_threshold);
  // save cache model
  // a cache model can speed up online prediction
  int32_t SaveCache(int table_id, const std::string& path, const int mode);
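  // A typical cache-save flow stitching the three calls above together
  // (a sketch; the exact orchestration lives in the caller, not this class):
  //   double threshold = fleet->GetCacheThreshold(table_id);
  //   fleet->CacheShuffle(table_id, path, /*mode=*/0, threshold);
  //   int32_t ret = fleet->SaveCache(table_id, path, /*mode=*/0);
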
  // save sparse table filtered by user-defined whitelist
  int32_t SaveWithWhitelist(int table_id,
                            const std::string& path,
                            const int mode,
                            const std::string& whitelist_path);
  void LoadWithWhitelist(const uint64_t table_id,
                         const std::string& path,
                         const int mode);
  // copy feasign key/value pairs from src_table_id to dest_table_id
  int32_t CopyTable(const uint64_t src_table_id, const uint64_t dest_table_id);
  // copy only the listed feasign key/value pairs from src_table_id to
  // dest_table_id
  int32_t CopyTableByFeasign(const uint64_t src_table_id,
                             const uint64_t dest_table_id,
                             const std::vector<uint64_t>& feasign_list);
  // clear all models and release their memory
  void ClearModel();
  // clear one table
  void ClearOneTable(const uint64_t table_id);
  // shrink sparse table
  void ShrinkSparseTable(int table_id);
  // shrink dense table
  void ShrinkDenseTable(int table_id,
                        Scope* scope,
                        std::vector<std::string> var_list,
                        float decay,
                        int emb_dim);

  typedef std::function<int32_t(int, int, const std::string&)> MsgHandlerFunc;
  // register a client-to-client message handler
  int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler);
  // send a client-to-client message
  std::future<int32_t> SendClientToClientMsg(int msg_type,
                                             int to_client_id,
                                             const std::string& msg);
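  // Usage sketch (illustrative; msg_type 0 and the handler body are made up):
  //   fleet->RegisterClientToClientMsgHandler(
  //       /*msg_type=*/0,
  //       [](int msg_type, int client_id, const std::string& msg) -> int32_t {
  //         return 0;  // consume the message, report success
  //       });
  //   auto f = fleet->SendClientToClientMsg(/*msg_type=*/0,
  //                                         /*to_client_id=*/1, "hello");
  //   f.wait();
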
  // confirm all the updated params in the current pass
  void Confirm();
  // revert all the updated params in the current pass
  void Revert();
  // FleetWrapper singleton
  static std::shared_ptr<FleetWrapper> GetInstance() {
    {
      std::lock_guard<std::mutex> lk(ins_mutex);
      if (nullptr == s_instance_) {
        s_instance_.reset(new paddle::framework::FleetWrapper());
      }
    }
    return s_instance_;
  }
  // this performs better than rand_r, especially for large data
  std::default_random_engine& LocalRandomEngine();
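  // Usage sketch (illustrative): draw a uniform sample from the engine
  //   std::uniform_real_distribution<float> dist(0.0f, 1.0f);
  //   float r = dist(fleet->LocalRandomEngine());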

  void SetDate(const uint64_t table_id, const std::string& date);

#ifdef PADDLE_WITH_PSLIB
  static std::shared_ptr<paddle::distributed::PSlib> pslib_ptr_;
#endif

 private:
  static std::shared_ptr<FleetWrapper> s_instance_;
  static std::mutex ins_mutex;
#ifdef PADDLE_WITH_PSLIB
  std::map<uint64_t, std::vector<paddle::ps::Region>> _regions;
#endif

  size_t GetAbsoluteSum(size_t start,
                        size_t end,
                        size_t level,
                        const framework::LoD& lod);

 protected:
  static bool is_initialized_;
  bool scale_sparse_gradient_with_batch_size_;
  int32_t sleep_seconds_before_fail_exit_;
  int client2client_request_timeout_ms_;
  int client2client_connect_timeout_ms_;
  int client2client_max_retry_;
  std::unique_ptr<::ThreadPool> local_pull_pool_{nullptr};
  int pull_local_thread_num_;
  std::unique_ptr<::ThreadPool> pull_to_local_pool_{nullptr};
  int local_table_shard_num_;
  DISABLE_COPY_AND_ASSIGN(FleetWrapper);
};

}  // end namespace framework
}  // end namespace paddle