// device_worker.h
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <atomic>
#include <fstream>
#include <future>  // NOLINT
#include <map>
#include <memory>
#include <mutex>  // NOLINT
#include <set>
#include <string>
#include <thread>         // NOLINT
#include <unordered_map>  // NOLINT
#include <unordered_set>  // NOLINT
#include <utility>        // NOLINT
#include <vector>

#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/heter_service.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/reader/blocking_queue.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/platform/timer.h"

// Forward declarations of framework/platform types referenced in this header.
namespace paddle {
namespace framework {
class LoDTensor;
class ProgramDesc;
class Scope;
class Tensor;
}  // namespace framework
namespace platform {
class DeviceContext;
}  // namespace platform
}  // namespace paddle

#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/platform/nccl_helper.h"
#endif

namespace paddle {
namespace framework {

63
std::string PrintLodTensor(Tensor* tensor, int64_t start, int64_t end);
64 65 66
std::pair<int64_t, int64_t> GetTensorBound(LoDTensor* tensor, int index);
bool CheckValidOutput(LoDTensor* tensor, size_t batch_size);

67 68
class FleetWrapper;

#ifdef PADDLE_WITH_PSLIB
class HeterWrapper;
#endif

class PullDenseWorker {
 public:
  virtual ~PullDenseWorker() {}
  virtual void Initialize(const TrainerDesc& param);
T
Thunderbrook 已提交
77 78
#ifdef PADDLE_WITH_CUDA
  void AddStream(const cudaStream_t stream) { copy_streams_.push_back(stream); }
T
Thunderbrook 已提交
79
#endif
T
Thunderbrook 已提交
80

T
Thunderbrook 已提交
81
#if (defined PADDLE_WITH_CUDA) || (defined PADDLE_WITH_XPU)
T
Thunderbrook 已提交
82 83 84 85 86 87
  void AddPlace(const paddle::platform::Place place) {
    places_.push_back(place);
  }

  void AddThreadScope(Scope* scope) { thread_scopes_.push_back(scope); }
#endif
88 89
  int Start();
  void Stop();
90
  void SetRootScope(Scope* scope) { root_scope_ = scope; }
91 92 93
  void IncreaseThreadVersion(int thread_id, uint64_t table_id);
  void ResetThreadVersion(uint64_t table_id);
  void Wait(std::vector<::std::future<int32_t>>* status_vec);
94
  void PullDense(bool force_update = false);
T
Thunderbrook 已提交
95
  void CreatePinVar();
T
Thunderbrook 已提交
96
  void MergeDenseParam();
97 98
  int GetThreadIdByScope(const Scope* scope);
  void SetThreadIdByScope(const Scope* scope, int tid);
99 100 101 102 103 104 105
  static std::shared_ptr<PullDenseWorker> GetInstance() {
    if (NULL == s_instance_) {
      s_instance_.reset(new paddle::framework::PullDenseWorker());
    }
    return s_instance_;
  }

106 107
  static std::shared_ptr<PullDenseWorker> s_instance_;

108
 private:
109
  PullDenseWorker() : root_scope_(NULL) {}
110 111 112 113 114 115
  void Run();
  bool CheckUpdateParam(uint64_t table_id);

 private:
  std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
  PullDenseWorkerParameter param_;
H
heqiaozhi 已提交
116
  DownpourWorkerParameter dwp_param_;
117 118 119
  Scope* root_scope_;
  bool running_;

D
dongdaxiang 已提交
120 121 122 123 124
  static std::map<uint64_t, uint64_t> last_versions_;
  static std::map<uint64_t, uint64_t> current_version_;
  static std::mutex mutex_for_version_;
  static std::map<uint64_t, std::vector<uint64_t>> training_versions_;
  static std::map<uint64_t, std::vector<std::string>> dense_value_names_;
125 126 127 128 129 130 131 132 133 134 135 136 137 138

  std::thread t_;
  int thread_num_;
  int sleep_time_ms_;
  int threshold_;

  std::vector<::std::future<int32_t>> pull_dense_status_;
  uint32_t pull_dense_fail_times_ = 0;
  std::vector<float> base_norm_param_;
  std::vector<float> mean_;
  std::vector<float> scale_;
  float squared_sum_epsilon_ = 1e-4;
  std::mutex mutex_for_mean_scale_;
  float total_batch_num_ = 0;
139
  std::unordered_map<const Scope*, int> scope_to_thread_id_;
T
Thunderbrook 已提交
140 141 142

#ifdef PADDLE_WITH_CUDA
  std::vector<cudaStream_t> copy_streams_;
T
Thunderbrook 已提交
143
#endif
T
Thunderbrook 已提交
144 145
  std::vector<paddle::platform::Place> places_;
  std::vector<Scope*> thread_scopes_;
146 147 148 149 150
};

// should incorporate different type of device
class DeviceWorker {
 public:
151 152 153 154
  DeviceWorker() {
    no_cvm_ = true;
    use_cvm_ = false;
  }
155 156
  virtual ~DeviceWorker() {}
  virtual void Initialize(const TrainerDesc& desc) = 0;
H
hutuxian 已提交
157
  virtual void InitRandomDumpConfig(const TrainerDesc& desc);
158 159
  virtual void SetDeviceIndex(int tid) = 0;
  virtual void TrainFiles() = 0;
D
dongdaxiang 已提交
160
  virtual void PrintFetchVars() = 0;
161 162 163 164 165
  virtual void TrainFilesWithProfiler() = 0;
  virtual void CreateDeviceResource(const ProgramDesc& main_prog) = 0;
  // will make this zero copy in the future
  virtual void BindingDataFeedMemory() = 0;
  virtual void SetRootScope(Scope* root_scope);
J
jiaqi 已提交
166
  virtual void SetDataFeed(DataFeed* data_feed);
T
Thunderbrook 已提交
167 168
  virtual void SetWorkerNum(int num) {}
  virtual void CacheProgram(const ProgramDesc& main_program) {}
T
Thunderbrook 已提交
169
  virtual void ProduceTasks() {}
T
Thunderbrook 已提交
170
  virtual void GetXpuOpIndex() {}
T
Thunderbrook 已提交
171 172 173 174
#ifdef PADDLE_WITH_CUDA
  virtual void SetStream(const cudaStream_t stream) {}
  virtual void SetEvent(const cudaEvent_t event) {}
#endif
H
hutuxian 已提交
175 176 177 178 179 180 181 182 183 184 185 186 187 188 189
  virtual void SetNeedDumpField(bool need_dump_field) {
    need_dump_field_ = need_dump_field;
  }
  virtual void SetNeedDumpParam(bool need_dump_param) {
    need_dump_param_ = need_dump_param;
  }
  virtual void SetDumpFieldVector(const std::vector<std::string>& dump_fields) {
    dump_fields_ = &dump_fields;
  }
  virtual void SetDumpParamVector(const std::vector<std::string>& dump_param) {
    dump_param_ = &dump_param;
  }
  virtual void SetChannelWriter(ChannelObject<std::string>* queue) {
    writer_.Reset(queue);
  }
190 191 192
  virtual void SetPlace(const paddle::platform::Place& place) {
    place_ = place;
  }
193 194 195
  virtual void SetReaderPlace(const paddle::platform::Place& place) {
    device_reader_->SetPlace(place);
  }
196
  virtual Scope* GetThreadScope() { return thread_scope_; }
T
Thunderbrook 已提交
197
  DataFeed* device_reader_ = nullptr;
198 199

 protected:
H
hutuxian 已提交
200 201 202
  virtual void DumpParam(const Scope& scope, const int batch_id);
  virtual void DumpField(const Scope& scope, int dump_mode,
                         int dump_interval = 10000);
J
jiaqi 已提交
203
  Scope* root_scope_ = nullptr;
204
  Scope* thread_scope_;
205
  paddle::platform::Place place_;
D
dongdaxiang 已提交
206 207
  int64_t batch_num_;
  FetchConfig fetch_config_;
208
  bool use_cvm_;
209
  bool no_cvm_;
T
Thunderbrook 已提交
210
  TrainerDesc trainer_desc_;
H
hutuxian 已提交
211 212 213 214 215 216

  // dump params or grads for debug
  bool need_dump_param_;
  bool need_dump_field_;
  const std::vector<std::string>* dump_param_;
  const std::vector<std::string>* dump_fields_;
217
  std::vector<std::string> all_param_;
H
hutuxian 已提交
218 219 220 221

  int dump_mode_ = 0;
  int dump_interval_ = 10000;
  ChannelWriter<std::string> writer_;
222 223 224 225 226 227 228 229 230
};

class CPUWorkerBase : public DeviceWorker {
 public:
  CPUWorkerBase() {}
  virtual ~CPUWorkerBase() {}
  virtual void SetDeviceIndex(int tid) { thread_id_ = tid; }
  virtual void TrainFiles() = 0;
  virtual void TrainFilesWithProfiler() {}
D
dongdaxiang 已提交
231
  virtual void PrintFetchVars() {}
232 233 234 235 236 237 238 239 240
  virtual void CreateDeviceResource(const ProgramDesc& main_prog) {}

 protected:
  int thread_id_;
};

class HogwildWorker : public CPUWorkerBase {
 public:
  HogwildWorker() {}
241 242 243 244 245 246
  virtual ~HogwildWorker() {
    for (OperatorBase* op : ops_) {
      delete op;
    }
    std::vector<OperatorBase*>().swap(ops_);
  }
D
dongdaxiang 已提交
247
  virtual void Initialize(const TrainerDesc& desc);
248 249
  virtual void TrainFiles();
  virtual void TrainFilesWithProfiler();
D
dongdaxiang 已提交
250
  virtual void PrintFetchVars();
251 252
  virtual void CreateDeviceResource(const ProgramDesc& main_prog);
  virtual void BindingDataFeedMemory();
253 254
  template <typename T>
  void SetZero(LoDTensor* tensor, LoDTensor* root_tensor, int tensor_dim);
255 256 257 258

 protected:
  void CreateThreadOperators(const ProgramDesc& program);
  void CreateThreadScope(const ProgramDesc& program);
259

260 261
  std::vector<std::string> op_names_;
  std::vector<OperatorBase*> ops_;
262
  bool thread_barrier_;
263
  // Scope* thread_scope_;
264 265
  HogwildWorkerParameter param_;
  std::vector<std::string> skip_ops_;
266
  std::map<std::string, int> stat_var_name_map_;
267 268 269 270 271 272
};

class DownpourWorker : public HogwildWorker {
 public:
  DownpourWorker() {}
  virtual ~DownpourWorker() {}
273
  virtual void Initialize(const TrainerDesc& desc);
274
  virtual void TrainFiles();
275
  virtual void TrainFilesWithProfiler();
276 277 278 279 280 281 282

 protected:
  std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
  std::shared_ptr<paddle::framework::PullDenseWorker> pull_dense_worker_;
  void FillSparseValue(size_t table_id);
  void PushGradients();
  void CollectLabelInfo(size_t table_id);
283
  void AdjustInsWeight();
X
xujiaqi01 已提交
284 285 286
  void CopySparseTable();
  void CopyDenseTable();
  void CopyDenseVars();
287

288
  DownpourWorkerParameter param_;
289 290 291 292
  // copy table
  CopyTableConfig copy_table_config_;
  std::vector<std::pair<uint64_t, uint64_t>> copy_sparse_tables_;
  std::unordered_map<uint64_t, std::unordered_set<uint64_t>> feasign_set_;
293 294
  // actually pushed feasign of each table
  std::map<uint64_t, std::vector<uint64_t>> sparse_push_keys_;
295
  std::map<uint64_t, std::vector<std::string>> sparse_key_names_;
296 297 298 299
  // feasign
  std::map<uint64_t, std::vector<uint64_t>> features_;
  // feasign embedding
  std::map<uint64_t, std::vector<std::vector<float>>> feature_values_;
300 301 302 303 304 305 306 307 308
  std::map<uint64_t, std::vector<std::string>> sparse_value_names_;
  // adjust ins weight
  AdjustInsWeightConfig adjust_ins_weight_config_;
  // check nan and inf during training
  std::vector<std::string> check_nan_var_names_;
  bool need_to_push_sparse_;
  // feasign stats
  std::map<uint64_t, std::vector<float>> feature_labels_;
  std::map<uint64_t, std::vector<std::string>> sparse_grad_names_;
309 310
  // feasign embedding gradient
  std::map<uint64_t, std::vector<std::vector<float>>> feature_grads_;
311 312 313 314 315 316
  std::vector<::std::future<int32_t>> push_sparse_status_;
  bool dump_slot_;
  bool need_to_push_dense_;
  std::map<uint64_t, std::vector<std::string>> dense_grad_names_;
  float scale_datanorm_;
  std::vector<::std::future<int32_t>> push_dense_status_;
317 318
  // skipped ops
  std::vector<std::string> skip_ops_;
319 320 321 322 323
  // just save the value in param_ for easy access
  std::map<uint64_t, std::string> label_var_name_;
  std::map<uint64_t, std::vector<std::string>> dense_value_names_;
  std::map<uint64_t, uint64_t> table_dependency_;
  std::vector<std::pair<uint64_t, uint64_t>> copy_dense_tables_;
Z
zhang wenhui 已提交
324 325 326 327
  // multitask
  std::map<int32_t, uint64_t> cond2table_map_;
  std::set<uint64_t> condvalue_set_;
  bool flag_partial_push_;
328 329 330 331 332 333

 private:
  // std::vector<std::string> dump_param_;
  // just save the value in param_ for easy access
  // std::map<uint64_t, std::string> label_var_name_;
  // std::map<uint64_t, std::vector<std::string>> dense_value_names_;
334 335

  std::shared_ptr<PullDenseWorker> _pull_dense_worker;
336 337

  std::vector<float> nid_show_;
338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357
  // std::map<uint64_t, uint64_t> table_dependency_;
  // std::vector<std::pair<uint64_t, uint64_t>> copy_dense_tables_;
};

// DownpourWorker variant that declares CreateThreadOperatorsWithRerank() and
// keeps per-loss operator groups.
class DownpourWorkerOpt : public DownpourWorker {
 public:
  DownpourWorkerOpt() {}
  virtual ~DownpourWorkerOpt() {}
  virtual void CreateDeviceResource(const ProgramDesc& main_prog);
  virtual void Initialize(const TrainerDesc& desc);
  virtual void TrainFiles();

 protected:
  void CreateThreadOperatorsWithRerank(const ProgramDesc& program);
  // Presumably indexed in parallel with loss_names_ — confirm in the .cc.
  // NOTE(review): these raw OperatorBase pointers are not deleted by any
  // destructor in this header; confirm they alias entries of ops_.
  std::vector<std::vector<OperatorBase*>> loss_ops_;
  std::vector<std::vector<std::string>> loss_op_names_;
  std::vector<std::string> loss_names_;
  std::string async_wait_name_;
  // Presumably -1 means "no async wait op found".
  int async_index_ = -1;
  uint64_t async_tid_ = 0;
};

#ifdef PADDLE_WITH_PSLIB
// CPU-side worker for PSLIB heterogeneous training. It declares task
// scheduling (Schedule, JumpContext) over the run_queue_/wait_queue_ lists
// and holds a HeterWrapper handle; the implementation is outside this header.
// NOTE: param_/skip_ops_ shadow the HogwildWorker members of the same name.
// need_dump_param_, dump_param_, need_dump_field_, dump_fields_ and writer_
// shadow the DeviceWorker members; the dump lists are held by value here.
class HeterCpuWorker : public HogwildWorker {
 public:
  HeterCpuWorker() {}
  virtual ~HeterCpuWorker() {}
  virtual void Initialize(const TrainerDesc& desc);
  virtual void TrainFiles();
  virtual void TrainFilesWithProfiler();
  virtual void SetNeedDump(bool need_dump_field);
  virtual void SetChannelWriter(ChannelObject<std::string>* queue);
  virtual void SetWorkerNum(int num) { worker_num_ = num; }
  virtual void Schedule(int taskid);
  virtual void JumpContext(std::shared_ptr<HeterTask> task);
  // Replaces the cached copy of `main_program`.
  virtual void CacheProgram(const ProgramDesc& main_program) {
    // The copy is rebuilt in place, presumably because ProgramDesc has no
    // usable copy assignment. Destroy the current value first; placement-new
    // over a live object leaks whatever that object owns.
    // NOTE(review): if the copy constructor throws, program_ is left
    // destroyed.
    program_.~ProgramDesc();
    new (&program_) ProgramDesc(main_program);
  }
  virtual void GetXpuOpIndex();

 protected:
  std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
  std::shared_ptr<paddle::framework::HeterWrapper> heter_ptr_;
  std::shared_ptr<paddle::framework::PullDenseWorker> pull_dense_worker_;
  void FillSparseValue(std::shared_ptr<HeterTask> task, size_t table_id);
  void PushGradients();
  void CollectLabelInfo(std::shared_ptr<HeterTask> task, size_t table_id);
  void AdjustInsWeight(std::shared_ptr<HeterTask> task);
  void DumpParam();
  void CopySparseTable();
  void CopyDenseTable();
  void CopyDenseVars();

 private:
  // Scalars get defaults so they are never read uninitialized; the real
  // values presumably come from Initialize()/setters.
  int mpi_rank_ = 0;
  int worker_num_ = 0;
  // -1 until filled in (presumably by GetXpuOpIndex()).
  int xpu_begin_op_index_ = -1;
  int xpu_end_op_index_ = -1;
  ProgramDesc program_;
  HeterObjectPool<HeterTask> object_pool_;
  HeterList<int, std::shared_ptr<HeterTask>> run_queue_;
  HeterList<int, std::shared_ptr<HeterTask>> wait_queue_;
  bool need_dump_param_ = false;
  std::vector<std::string> dump_param_;
  bool need_to_push_dense_ = false;
  bool need_dump_field_ = false;
  bool dump_slot_ = false;
  bool need_to_push_sparse_ = false;
  std::vector<std::string> dump_fields_;
  ChannelWriter<std::string> writer_;
  DownpourWorkerParameter param_;
  float scale_datanorm_ = 0;
  // just save the value in param_ for easy access
  std::map<uint64_t, std::string> label_var_name_;
  std::map<uint64_t, std::vector<std::string>> sparse_key_names_;
  std::map<uint64_t, std::vector<std::string>> sparse_value_names_;
  std::map<uint64_t, std::vector<std::string>> sparse_grad_names_;
  std::map<uint64_t, std::vector<std::string>> dense_value_names_;
  std::map<uint64_t, std::vector<std::string>> dense_grad_names_;
  platform::Place root_place_;
  // actually pushed feasign of each table
  std::map<uint64_t, std::vector<uint64_t>> sparse_push_keys_;

  // skipped ops
  std::vector<std::string> skip_ops_;

  std::vector<::std::future<int32_t>> push_sparse_status_;
  std::vector<::std::future<int32_t>> push_dense_status_;

  // adjust ins weight
  AdjustInsWeightConfig adjust_ins_weight_config_;
  std::vector<float> nid_show_;
  // check nan and inf during training
  std::vector<std::string> check_nan_var_names_;
  // copy table
  CopyTableConfig copy_table_config_;
  std::map<uint64_t, uint64_t> table_dependency_;
  std::vector<std::pair<uint64_t, uint64_t>> copy_sparse_tables_;
  std::vector<std::pair<uint64_t, uint64_t>> copy_dense_tables_;
  std::unordered_map<uint64_t, std::unordered_set<uint64_t>> feasign_set_;
};
#endif

#if (defined PADDLE_WITH_CUDA || defined PADDLE_WITH_XPU) && \
    (defined PADDLE_WITH_PSLIB)
// Device-side worker for PSLIB heterogeneous training. It feeds HeterTask
// objects through the pull_queue_/push_queue_ channels and accumulates the
// timing counters reset by ResetStat().
// NOTE: param_/skip_ops_ shadow the HogwildWorker members of the same name.
// need_dump_param_, dump_param_, need_dump_field_, dump_fields_ and writer_
// shadow the DeviceWorker members; the dump lists are held by value here.
class HeterBoxWorker : public HogwildWorker {
 public:
  HeterBoxWorker() {}
  virtual ~HeterBoxWorker() {}
  virtual void Initialize(const TrainerDesc& desc);
  virtual void TrainFiles();
  virtual void SetNeedDump(bool need_dump_field);
  virtual void SetChannelWriter(ChannelObject<std::string>* queue);
  virtual void SetWorkerNum(int num) { worker_num_ = num; }
  // Replaces the cached copy of `main_program`.
  virtual void CacheProgram(const ProgramDesc& main_program) {
    // The copy is rebuilt in place, presumably because ProgramDesc has no
    // usable copy assignment. Destroy the current value first; placement-new
    // over a live object leaks whatever that object owns.
    // NOTE(review): if the copy constructor throws, program_ is left
    // destroyed.
    program_.~ProgramDesc();
    new (&program_) ProgramDesc(main_program);
  }
  void ProduceTasks() override;
#ifdef PADDLE_WITH_CUDA
  // cudaStream_t/cudaEvent_t exist only in CUDA builds, and the base-class
  // hooks are likewise CUDA-only; the enclosing guard also admits XPU-only
  // builds, which must not see these.
  void SetStream(const cudaStream_t stream) override { copy_stream_ = stream; }
  void SetEvent(const cudaEvent_t event) override { event_ = event; }
#endif
  virtual void TrainFilesWithProfiler() {}
  void ResetStat();

 protected:
  std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
  void FillSparseValue(std::shared_ptr<HeterTask> task, size_t table_id);
  void PushGradients();
  void CollectLabelInfo(std::shared_ptr<HeterTask> task, size_t table_id);
  void AdjustInsWeight(std::shared_ptr<HeterTask> task);
  void DumpParam();
  void CopySparseTable();
  void CopyDenseTable();
  void CopyDenseVars();

 private:
  // Scalars get defaults so they are never read uninitialized; the real
  // values presumably come from Initialize()/setters.
  int mpi_rank_ = 0;
  std::mutex mutex_;
  std::vector<std::string> send_var_list_;
  int worker_num_ = 0;
  ProgramDesc program_;
  HeterObjectPool<HeterTask> object_pool_;
  bool need_dump_param_ = false;
  std::vector<std::string> dump_param_;
  bool need_to_push_dense_ = false;
  bool need_dump_field_ = false;
  bool dump_slot_ = false;
  bool need_to_push_sparse_ = false;
  std::vector<std::string> dump_fields_;
  ChannelWriter<std::string> writer_;
  DownpourWorkerParameter param_;
  float scale_datanorm_ = 0;
  // just save the value in param_ for easy access
  std::map<uint64_t, std::string> label_var_name_;
  std::map<uint64_t, std::vector<std::string>> sparse_key_names_;
  std::map<uint64_t, std::vector<std::string>> sparse_value_names_;
  std::map<uint64_t, std::vector<std::string>> sparse_grad_names_;
  std::map<uint64_t, std::vector<std::string>> dense_value_names_;
  std::map<uint64_t, std::vector<std::string>> dense_grad_names_;
  platform::Place root_place_;
  // actually pushed feasign of each table
  std::map<uint64_t, std::vector<uint64_t>> sparse_push_keys_;

  // skipped ops
  std::vector<std::string> skip_ops_;

  std::vector<::std::future<int32_t>> push_sparse_status_;
  std::vector<::std::future<int32_t>> push_dense_status_;

  // adjust ins weight
  AdjustInsWeightConfig adjust_ins_weight_config_;
  std::vector<float> nid_show_;
  // check nan and inf during training
  std::vector<std::string> check_nan_var_names_;
  // copy table
  CopyTableConfig copy_table_config_;
  std::map<uint64_t, uint64_t> table_dependency_;
  std::vector<std::pair<uint64_t, uint64_t>> copy_sparse_tables_;
  std::vector<std::pair<uint64_t, uint64_t>> copy_dense_tables_;
  std::unordered_map<uint64_t, std::unordered_set<uint64_t>> feasign_set_;
  paddle::framework::Channel<std::shared_ptr<HeterTask>> pull_queue_;
  paddle::framework::Channel<std::shared_ptr<HeterTask>> push_queue_;
#ifdef PADDLE_WITH_CUDA
  cudaEvent_t event_ = nullptr;
  cudaStream_t copy_stream_ = nullptr;
#endif
  int batch_cnt_{0};
  std::atomic<int> done_cnt_{0};

  // Timing/throughput counters (presumably reset by ResetStat()).
  double total_time_ = 0;
  double read_time_ = 0;
  double pack_time_ = 0;
  double pull_sparse_local_time_ = 0;
  double op_all_time_ = 0;
  double xpu_op_time_ = 0;
  double xpu_wait_time_ = 0;
  double cpu_op_time_ = 0;
  double collect_label_time_ = 0;
  double fill_sparse_time_ = 0;
  double push_sparse_time_ = 0;
  double gpu_2_cpu_time_ = 0;
  double cpu_2_gpu_time_ = 0;
  uint64_t total_inst_ = 0;
};
#endif

#if (defined PADDLE_WITH_NCCL) && (defined PADDLE_WITH_PSLIB)
// GPU parameter-server worker (NCCL + PSLIB builds). Structurally parallel to
// HeterBoxWorker: it owns pull_queue_/push_queue_ task channels, a copy stream
// and event, and timing counters reset by ResetStat().
// NOTE: param_/skip_ops_ shadow the HogwildWorker members of the same name.
// need_dump_param_, dump_param_, need_dump_field_, dump_fields_ and writer_
// shadow the DeviceWorker members; the dump lists are held by value here.
class PSGPUWorker : public HogwildWorker {
 public:
  PSGPUWorker() {}
  virtual ~PSGPUWorker() {}
  virtual void Initialize(const TrainerDesc& desc);
  virtual void TrainFiles();
  virtual void SetNeedDump(bool need_dump_field);
  virtual void SetChannelWriter(ChannelObject<std::string>* queue);
  virtual void SetWorkerNum(int num) { worker_num_ = num; }
  // Replaces the cached copy of `main_program`.
  virtual void CacheProgram(const ProgramDesc& main_program) {
    // The copy is rebuilt in place, presumably because ProgramDesc has no
    // usable copy assignment. Destroy the current value first; placement-new
    // over a live object leaks whatever that object owns.
    // NOTE(review): if the copy constructor throws, program_ is left
    // destroyed.
    program_.~ProgramDesc();
    new (&program_) ProgramDesc(main_program);
  }
  void ProduceTasks() override;
  virtual void SetStream(const cudaStream_t stream) { copy_stream_ = stream; }
  virtual void SetEvent(const cudaEvent_t event) { event_ = event; }
  virtual void TrainFilesWithProfiler() {}
  void ResetStat();

 protected:
  std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
  void PushGradients();
  void DumpParam();
  void CopySparseTable();
  void CopyDenseTable();
  void CopyDenseVars();

 private:
  // Scalars get defaults so they are never read uninitialized; the real
  // values presumably come from Initialize()/setters.
  int mpi_rank_ = 0;
  std::mutex mutex_;
  std::vector<std::string> send_var_list_;
  int worker_num_ = 0;
  ProgramDesc program_;
  HeterObjectPool<HeterTask> object_pool_;
  bool need_dump_param_ = false;
  std::vector<std::string> dump_param_;
  bool need_to_push_dense_ = false;
  bool need_dump_field_ = false;
  bool dump_slot_ = false;
  bool need_to_push_sparse_ = false;
  std::vector<std::string> dump_fields_;
  ChannelWriter<std::string> writer_;
  DownpourWorkerParameter param_;
  float scale_datanorm_ = 0;
  // just save the value in param_ for easy access
  std::map<uint64_t, std::string> label_var_name_;
  std::map<uint64_t, std::vector<std::string>> sparse_key_names_;
  std::map<uint64_t, std::vector<std::string>> sparse_value_names_;
  std::map<uint64_t, std::vector<std::string>> sparse_grad_names_;
  std::map<uint64_t, std::vector<std::string>> dense_value_names_;
  std::map<uint64_t, std::vector<std::string>> dense_grad_names_;
  platform::Place root_place_;
  // actually pushed feasign of each table
  std::map<uint64_t, std::vector<uint64_t>> sparse_push_keys_;

  // skipped ops
  std::vector<std::string> skip_ops_;

  std::vector<::std::future<int32_t>> push_sparse_status_;
  std::vector<::std::future<int32_t>> push_dense_status_;

  // adjust ins weight
  AdjustInsWeightConfig adjust_ins_weight_config_;
  std::vector<float> nid_show_;
  // check nan and inf during training
  std::vector<std::string> check_nan_var_names_;
  // copy table
  CopyTableConfig copy_table_config_;
  std::map<uint64_t, uint64_t> table_dependency_;
  std::vector<std::pair<uint64_t, uint64_t>> copy_sparse_tables_;
  std::vector<std::pair<uint64_t, uint64_t>> copy_dense_tables_;
  std::unordered_map<uint64_t, std::unordered_set<uint64_t>> feasign_set_;
  paddle::framework::Channel<std::shared_ptr<HeterTask>> pull_queue_;
  paddle::framework::Channel<std::shared_ptr<HeterTask>> push_queue_;
  cudaEvent_t event_ = nullptr;
  cudaStream_t copy_stream_ = nullptr;
  int batch_cnt_{0};
  std::atomic<int> done_cnt_{0};

  // Timing/throughput counters (presumably reset by ResetStat()).
  double total_time_ = 0;
  double read_time_ = 0;
  double pack_time_ = 0;
  double pull_sparse_local_time_ = 0;
  double op_all_time_ = 0;
  double xpu_op_time_ = 0;
  double xpu_wait_time_ = 0;
  double cpu_op_time_ = 0;
  double collect_label_time_ = 0;
  double fill_sparse_time_ = 0;
  double push_sparse_time_ = 0;
  double gpu_2_cpu_time_ = 0;
  double cpu_2_gpu_time_ = 0;
  uint64_t total_inst_ = 0;
};
#endif

// NOTE(review): the second condition tests `WITH_ASCEND_CL`, whereas the other
// backend switches in this file use the `PADDLE_WITH_` prefix. Confirm which
// macro the build defines, and that section_worker.cc uses the same guard,
// before changing it.
#if (defined PADDLE_WITH_NCCL) || (defined WITH_ASCEND_CL)
// Pipeline-parallel worker. It is configured through the setters below with
// its pipeline stage, the number of stages and micro-batches, the schedule
// mode, and the per-micro-batch scopes.
class SectionWorker : public DeviceWorker {
 public:
  SectionWorker() {}
  ~SectionWorker() override {}

  void Initialize(const TrainerDesc& desc) override;

  void BindingDataFeedMemory() override {}
  void CreateDeviceResource(const ProgramDesc& main_prog) override {}

  void TrainFiles() override;
  void TrainFilesWithProfiler() override {}

  void PrintFetchVars() override {}

  const platform::Place& place() const { return place_; }

  void SetDeviceIndex(int tid) override {}
  void SetThreadIndex(int thread_id) { thread_id_ = thread_id; }
  void SetMicrobatchNum(int num) { num_microbatches_ = num; }
  void SetPipelineStageNum(int num) { num_pipeline_stages_ = num; }
  void SetPipelineStage(int stage) { pipeline_stage_ = stage; }
  void SetScheduleMode(int mode) { schedule_mode_ = mode; }
  void SetMicrobatchScopes(const std::vector<Scope*>& scope) {
    microbatch_scopes_ = scope;
  }
  void SetMinibatchScope(const Scope* scope) { minibatch_scope_ = scope; }
  void SetSkipVars(const std::vector<std::string>& skip_vars) {
    skip_vars_ = skip_vars;
  }
  void RunBackward(
      int micro_id, std::unique_ptr<GarbageCollector>&,
      std::unordered_map<const OperatorBase*, std::vector<std::string>>&);
  void RunForward(
      int micro_id, std::unique_ptr<GarbageCollector>&,
      std::unordered_map<const OperatorBase*, std::vector<std::string>>&);
  void RunUpdate(
      std::unique_ptr<GarbageCollector>&,
      std::unordered_map<const OperatorBase*, std::vector<std::string>>&);

 protected:
  // Scalars/pointers get defaults so they are never read uninitialized; the
  // configured values come from Initialize() and the setters above.
  int section_id_ = 0;
  int thread_id_ = 0;
  int num_microbatches_ = 0;
  int num_pipeline_stages_ = 0;
  int pipeline_stage_ = 0;
  int schedule_mode_ = 0;  // 0 for GPipe and 1 for deepspeed
  std::vector<Scope*> microbatch_scopes_;
  std::vector<std::string> skip_vars_;
  const Scope* minibatch_scope_ = nullptr;

  std::vector<std::unique_ptr<OperatorBase>> ops_;
  std::shared_ptr<framework::ProgramDesc> program_;
  static uint64_t batch_id_;

  platform::DeviceContext* dev_ctx_ = nullptr;
};
#endif

}  // namespace framework
}  // namespace paddle