/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <atomic>
#include <fstream>
#include <future>  // NOLINT
#include <map>
#include <memory>
#include <mutex>  // NOLINT
#include <set>
#include <string>
#include <thread>         // NOLINT
#include <unordered_map>  // NOLINT
#include <unordered_set>  // NOLINT
#include <utility>        // NOLINT
#include <vector>

#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/heter_util.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/reader/blocking_queue.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/platform/timer.h"

namespace paddle {
namespace framework {
class ProgramDesc;
class Scope;
}  // namespace framework
}  // namespace paddle

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif

namespace paddle {
namespace framework {

58
std::string PrintLodTensor(Tensor* tensor, int64_t start, int64_t end);
59 60 61
std::pair<int64_t, int64_t> GetTensorBound(LoDTensor* tensor, int index);
bool CheckValidOutput(LoDTensor* tensor, size_t batch_size);

62 63
class FleetWrapper;

T
Thunderbrook 已提交
64 65 66 67
#ifdef PADDLE_WITH_PSLIB
class HeterWrapper;
#endif

68 69 70 71
class PullDenseWorker {
 public:
  virtual ~PullDenseWorker() {}
  virtual void Initialize(const TrainerDesc& param);
72 73
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  void AddStream(const gpuStream_t stream) { copy_streams_.push_back(stream); }
T
Thunderbrook 已提交
74
#endif
T
Thunderbrook 已提交
75

76 77
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
    defined(PADDLE_WITH_XPU)
T
Thunderbrook 已提交
78 79 80 81 82 83
  void AddPlace(const paddle::platform::Place place) {
    places_.push_back(place);
  }

  void AddThreadScope(Scope* scope) { thread_scopes_.push_back(scope); }
#endif
84 85
  int Start();
  void Stop();
86
  void SetRootScope(Scope* scope) { root_scope_ = scope; }
87 88 89
  void IncreaseThreadVersion(int thread_id, uint64_t table_id);
  void ResetThreadVersion(uint64_t table_id);
  void Wait(std::vector<::std::future<int32_t>>* status_vec);
90
  void PullDense(bool force_update = false);
T
Thunderbrook 已提交
91
  void CreatePinVar();
T
Thunderbrook 已提交
92
  void MergeDenseParam();
93 94
  int GetThreadIdByScope(const Scope* scope);
  void SetThreadIdByScope(const Scope* scope, int tid);
95 96 97 98 99 100 101
  static std::shared_ptr<PullDenseWorker> GetInstance() {
    if (NULL == s_instance_) {
      s_instance_.reset(new paddle::framework::PullDenseWorker());
    }
    return s_instance_;
  }

102 103
  static std::shared_ptr<PullDenseWorker> s_instance_;

104
 private:
105
  PullDenseWorker() : root_scope_(NULL) {}
106 107 108 109 110 111
  void Run();
  bool CheckUpdateParam(uint64_t table_id);

 private:
  std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
  PullDenseWorkerParameter param_;
H
heqiaozhi 已提交
112
  DownpourWorkerParameter dwp_param_;
113 114 115
  Scope* root_scope_;
  bool running_;

D
dongdaxiang 已提交
116 117 118 119 120
  static std::map<uint64_t, uint64_t> last_versions_;
  static std::map<uint64_t, uint64_t> current_version_;
  static std::mutex mutex_for_version_;
  static std::map<uint64_t, std::vector<uint64_t>> training_versions_;
  static std::map<uint64_t, std::vector<std::string>> dense_value_names_;
121 122 123 124 125 126 127 128 129 130 131 132 133 134

  std::thread t_;
  int thread_num_;
  int sleep_time_ms_;
  int threshold_;

  std::vector<::std::future<int32_t>> pull_dense_status_;
  uint32_t pull_dense_fail_times_ = 0;
  std::vector<float> base_norm_param_;
  std::vector<float> mean_;
  std::vector<float> scale_;
  float squared_sum_epsilon_ = 1e-4;
  std::mutex mutex_for_mean_scale_;
  float total_batch_num_ = 0;
135
  std::unordered_map<const Scope*, int> scope_to_thread_id_;
T
Thunderbrook 已提交
136

137 138
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  std::vector<gpuStream_t> copy_streams_;
T
Thunderbrook 已提交
139
#endif
T
Thunderbrook 已提交
140 141
  std::vector<paddle::platform::Place> places_;
  std::vector<Scope*> thread_scopes_;
142 143 144 145 146
};

// should incorporate different type of device
class DeviceWorker {
 public:
147 148 149 150
  DeviceWorker() {
    no_cvm_ = true;
    use_cvm_ = false;
  }
151 152
  virtual ~DeviceWorker() {}
  virtual void Initialize(const TrainerDesc& desc) = 0;
H
hutuxian 已提交
153
  virtual void InitRandomDumpConfig(const TrainerDesc& desc);
154 155
  virtual void SetDeviceIndex(int tid) = 0;
  virtual void TrainFiles() = 0;
D
dongdaxiang 已提交
156
  virtual void PrintFetchVars() = 0;
157 158 159 160 161
  virtual void TrainFilesWithProfiler() = 0;
  virtual void CreateDeviceResource(const ProgramDesc& main_prog) = 0;
  // will make this zero copy in the future
  virtual void BindingDataFeedMemory() = 0;
  virtual void SetRootScope(Scope* root_scope);
J
jiaqi 已提交
162
  virtual void SetDataFeed(DataFeed* data_feed);
T
Thunderbrook 已提交
163 164
  virtual void SetWorkerNum(int num) {}
  virtual void CacheProgram(const ProgramDesc& main_program) {}
T
Thunderbrook 已提交
165
  virtual void ProduceTasks() {}
T
Thunderbrook 已提交
166
  virtual void GetXpuOpIndex() {}
T
Thunderbrook 已提交
167
  virtual void Schedule(int taskid) {}
168 169 170
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  virtual void SetStream(const gpuStream_t stream) {}
  virtual void SetEvent(const gpuEvent_t event) {}
T
Thunderbrook 已提交
171
#endif
H
hutuxian 已提交
172 173 174 175 176 177 178 179 180 181 182 183 184 185 186
  virtual void SetNeedDumpField(bool need_dump_field) {
    need_dump_field_ = need_dump_field;
  }
  virtual void SetNeedDumpParam(bool need_dump_param) {
    need_dump_param_ = need_dump_param;
  }
  virtual void SetDumpFieldVector(const std::vector<std::string>& dump_fields) {
    dump_fields_ = &dump_fields;
  }
  virtual void SetDumpParamVector(const std::vector<std::string>& dump_param) {
    dump_param_ = &dump_param;
  }
  virtual void SetChannelWriter(ChannelObject<std::string>* queue) {
    writer_.Reset(queue);
  }
187 188 189
  virtual void SetPlace(const paddle::platform::Place& place) {
    place_ = place;
  }
190 191 192
  virtual void SetReaderPlace(const paddle::platform::Place& place) {
    device_reader_->SetPlace(place);
  }
193 194 195
  virtual void SetDeviceContext(platform::DeviceContext* dev_ctx) {
    dev_ctx_ = dev_ctx;
  }
196
  virtual Scope* GetThreadScope() { return thread_scope_; }
T
Thunderbrook 已提交
197
  DataFeed* device_reader_ = nullptr;
198 199

 protected:
H
hutuxian 已提交
200 201 202
  virtual void DumpParam(const Scope& scope, const int batch_id);
  virtual void DumpField(const Scope& scope, int dump_mode,
                         int dump_interval = 10000);
J
jiaqi 已提交
203
  Scope* root_scope_ = nullptr;
204
  Scope* thread_scope_;
205
  paddle::platform::Place place_;
T
tangwei12 已提交
206
  int64_t batch_num_ = 0;
D
dongdaxiang 已提交
207
  FetchConfig fetch_config_;
208
  bool use_cvm_;
209
  bool no_cvm_;
210
  bool scale_sparse_gradient_with_batch_size_;
T
Thunderbrook 已提交
211
  TrainerDesc trainer_desc_;
H
hutuxian 已提交
212 213 214 215 216 217

  // dump params or grads for debug
  bool need_dump_param_;
  bool need_dump_field_;
  const std::vector<std::string>* dump_param_;
  const std::vector<std::string>* dump_fields_;
218
  std::vector<std::string> all_param_;
H
hutuxian 已提交
219 220 221 222

  int dump_mode_ = 0;
  int dump_interval_ = 10000;
  ChannelWriter<std::string> writer_;
223
  platform::DeviceContext* dev_ctx_ = nullptr;
224 225 226 227 228 229 230 231 232
};

class CPUWorkerBase : public DeviceWorker {
 public:
  CPUWorkerBase() {}
  virtual ~CPUWorkerBase() {}
  virtual void SetDeviceIndex(int tid) { thread_id_ = tid; }
  virtual void TrainFiles() = 0;
  virtual void TrainFilesWithProfiler() {}
D
dongdaxiang 已提交
233
  virtual void PrintFetchVars() {}
234 235 236 237 238 239 240 241 242
  virtual void CreateDeviceResource(const ProgramDesc& main_prog) {}

 protected:
  int thread_id_;
};

class HogwildWorker : public CPUWorkerBase {
 public:
  HogwildWorker() {}
243 244 245 246 247 248
  virtual ~HogwildWorker() {
    for (OperatorBase* op : ops_) {
      delete op;
    }
    std::vector<OperatorBase*>().swap(ops_);
  }
D
dongdaxiang 已提交
249
  virtual void Initialize(const TrainerDesc& desc);
250 251
  virtual void TrainFiles();
  virtual void TrainFilesWithProfiler();
D
dongdaxiang 已提交
252
  virtual void PrintFetchVars();
253 254
  virtual void CreateDeviceResource(const ProgramDesc& main_prog);
  virtual void BindingDataFeedMemory();
255 256
  template <typename T>
  void SetZero(LoDTensor* tensor, LoDTensor* root_tensor, int tensor_dim);
257 258 259 260

 protected:
  void CreateThreadOperators(const ProgramDesc& program);
  void CreateThreadScope(const ProgramDesc& program);
261

262 263
  std::vector<std::string> op_names_;
  std::vector<OperatorBase*> ops_;
264
  bool thread_barrier_;
265
  // Scope* thread_scope_;
266 267
  HogwildWorkerParameter param_;
  std::vector<std::string> skip_ops_;
268
  std::map<std::string, int> stat_var_name_map_;
269 270 271 272 273 274
};

class DownpourWorker : public HogwildWorker {
 public:
  DownpourWorker() {}
  virtual ~DownpourWorker() {}
275
  virtual void Initialize(const TrainerDesc& desc);
276
  virtual void TrainFiles();
277
  virtual void TrainFilesWithProfiler();
278 279 280 281 282 283 284

 protected:
  std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
  std::shared_ptr<paddle::framework::PullDenseWorker> pull_dense_worker_;
  void FillSparseValue(size_t table_id);
  void PushGradients();
  void CollectLabelInfo(size_t table_id);
285
  void AdjustInsWeight();
X
xujiaqi01 已提交
286 287 288
  void CopySparseTable();
  void CopyDenseTable();
  void CopyDenseVars();
289

290
  DownpourWorkerParameter param_;
291 292 293 294
  // copy table
  CopyTableConfig copy_table_config_;
  std::vector<std::pair<uint64_t, uint64_t>> copy_sparse_tables_;
  std::unordered_map<uint64_t, std::unordered_set<uint64_t>> feasign_set_;
295 296
  // actually pushed feasign of each table
  std::map<uint64_t, std::vector<uint64_t>> sparse_push_keys_;
297
  std::map<uint64_t, std::vector<std::string>> sparse_key_names_;
298 299 300 301
  // feasign
  std::map<uint64_t, std::vector<uint64_t>> features_;
  // feasign embedding
  std::map<uint64_t, std::vector<std::vector<float>>> feature_values_;
302 303 304 305 306 307 308 309 310
  std::map<uint64_t, std::vector<std::string>> sparse_value_names_;
  // adjust ins weight
  AdjustInsWeightConfig adjust_ins_weight_config_;
  // check nan and inf during training
  std::vector<std::string> check_nan_var_names_;
  bool need_to_push_sparse_;
  // feasign stats
  std::map<uint64_t, std::vector<float>> feature_labels_;
  std::map<uint64_t, std::vector<std::string>> sparse_grad_names_;
311 312
  // feasign embedding gradient
  std::map<uint64_t, std::vector<std::vector<float>>> feature_grads_;
313 314 315 316 317 318
  std::vector<::std::future<int32_t>> push_sparse_status_;
  bool dump_slot_;
  bool need_to_push_dense_;
  std::map<uint64_t, std::vector<std::string>> dense_grad_names_;
  float scale_datanorm_;
  std::vector<::std::future<int32_t>> push_dense_status_;
319 320
  // skipped ops
  std::vector<std::string> skip_ops_;
321 322 323 324 325
  // just save the value in param_ for easy access
  std::map<uint64_t, std::string> label_var_name_;
  std::map<uint64_t, std::vector<std::string>> dense_value_names_;
  std::map<uint64_t, uint64_t> table_dependency_;
  std::vector<std::pair<uint64_t, uint64_t>> copy_dense_tables_;
Z
zhang wenhui 已提交
326 327 328 329
  // multitask
  std::map<int32_t, uint64_t> cond2table_map_;
  std::set<uint64_t> condvalue_set_;
  bool flag_partial_push_;
330 331 332 333 334 335

 private:
  // std::vector<std::string> dump_param_;
  // just save the value in param_ for easy access
  // std::map<uint64_t, std::string> label_var_name_;
  // std::map<uint64_t, std::vector<std::string>> dense_value_names_;
336 337

  std::shared_ptr<PullDenseWorker> _pull_dense_worker;
338 339

  std::vector<float> nid_show_;
340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359
  // std::map<uint64_t, uint64_t> table_dependency_;
  // std::vector<std::pair<uint64_t, uint64_t>> copy_dense_tables_;
};

class DownpourWorkerOpt : public DownpourWorker {
 public:
  DownpourWorkerOpt() {}
  virtual ~DownpourWorkerOpt() {}
  virtual void CreateDeviceResource(const ProgramDesc& main_prog);
  virtual void Initialize(const TrainerDesc& desc);
  virtual void TrainFiles();

 protected:
  void CreateThreadOperatorsWithRerank(const ProgramDesc& program);
  std::vector<std::vector<OperatorBase*>> loss_ops_;
  std::vector<std::vector<std::string>> loss_op_names_;
  std::vector<std::string> loss_names_;
  std::string async_wait_name_;
  int async_index_ = -1;
  uint64_t async_tid_ = 0;
360 361
};

#ifdef PADDLE_WITH_PSLIB
// Heterogeneous PSLIB worker that moves HeterTask objects between a run
// queue and a wait queue (see Schedule/JumpContext in the .cc).
class HeterCpuWorker : public HogwildWorker {
 public:
  HeterCpuWorker() {}
  ~HeterCpuWorker() override {}
  void Initialize(const TrainerDesc& desc) override;
  void TrainFiles() override;
  void TrainFilesWithProfiler() override;
  virtual void SetNeedDump(bool need_dump_field);
  void SetChannelWriter(ChannelObject<std::string>* queue) override;
  void SetWorkerNum(int num) override { worker_num_ = num; }
  void Schedule(int taskid) override;
  virtual void JumpContext(std::shared_ptr<HeterTask> task);
  // Replaces program_ with a copy of main_program. The existing ProgramDesc
  // is destroyed first. A bare placement new over a live object skips its
  // destructor and leaks whatever it already owns.
  // NOTE(review): this pattern is used because ProgramDesc appears not to be
  // copy-assignable; if it is, prefer `program_ = main_program;`.
  void CacheProgram(const ProgramDesc& main_program) override {
    program_.~ProgramDesc();
    new (&program_) ProgramDesc(main_program);
  }
  void GetXpuOpIndex() override;

 protected:
  std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
  std::shared_ptr<paddle::framework::HeterWrapper> heter_ptr_;
  std::shared_ptr<paddle::framework::PullDenseWorker> pull_dense_worker_;
  void FillSparseValue(std::shared_ptr<HeterTask> task, size_t table_id);
  void PushGradients();
  void CollectLabelInfo(std::shared_ptr<HeterTask> task, size_t table_id);
  void AdjustInsWeight(std::shared_ptr<HeterTask> task);
  // Hides the inherited DeviceWorker::DumpParam(const Scope&, int) overload
  // for unqualified calls made in this class.
  void DumpParam();
  void CopySparseTable();
  void CopyDenseTable();
  void CopyDenseVars();

 private:
  int mpi_rank_ = 0;
  int worker_num_ = 0;
  int xpu_begin_op_index_ = 0;
  int xpu_end_op_index_ = 0;
  ProgramDesc program_;
  HeterObjectPool<HeterTask> object_pool_;
  HeterList<int, std::shared_ptr<HeterTask>> run_queue_;
  HeterList<int, std::shared_ptr<HeterTask>> wait_queue_;
  // NOTE(review): need_dump_param_, dump_param_, need_dump_field_,
  // dump_fields_, writer_, param_ and skip_ops_ below hide same-named members
  // of DeviceWorker / HogwildWorker. dump_param_ and dump_fields_ even change
  // type (vector here vs. const pointer in DeviceWorker), so the base-class
  // setters (SetDumpParamVector, SetChannelWriter, ...) do not write these.
  bool need_dump_param_ = false;
  std::vector<std::string> dump_param_;
  bool need_to_push_dense_ = false;
  bool need_dump_field_ = false;
  bool dump_slot_ = false;
  bool need_to_push_sparse_ = false;
  std::vector<std::string> dump_fields_;
  ChannelWriter<std::string> writer_;
  DownpourWorkerParameter param_;
  float scale_datanorm_;
  // just save the value in param_ for easy access
  std::map<uint64_t, std::string> label_var_name_;
  std::map<uint64_t, std::vector<std::string>> sparse_key_names_;
  std::map<uint64_t, std::vector<std::string>> sparse_value_names_;
  std::map<uint64_t, std::vector<std::string>> sparse_grad_names_;
  std::map<uint64_t, std::vector<std::string>> dense_value_names_;
  std::map<uint64_t, std::vector<std::string>> dense_grad_names_;
  platform::Place root_place_;
  // actually pushed feasign of each table
  std::map<uint64_t, std::vector<uint64_t>> sparse_push_keys_;

  // skipped ops
  std::vector<std::string> skip_ops_;

  std::vector<::std::future<int32_t>> push_sparse_status_;
  std::vector<::std::future<int32_t>> push_dense_status_;

  // adjust ins weight
  AdjustInsWeightConfig adjust_ins_weight_config_;
  std::vector<float> nid_show_;
  // check nan and inf during training
  std::vector<std::string> check_nan_var_names_;
  // copy table
  CopyTableConfig copy_table_config_;
  std::map<uint64_t, uint64_t> table_dependency_;
  std::vector<std::pair<uint64_t, uint64_t>> copy_sparse_tables_;
  std::vector<std::pair<uint64_t, uint64_t>> copy_dense_tables_;
  std::unordered_map<uint64_t, std::unordered_set<uint64_t>> feasign_set_;
};
#endif

#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL) && \
    (defined PADDLE_WITH_PSLIB)
// GPU parameter-server worker for PSLIB builds with NCCL/RCCL. It holds
// pull/push task channels, a copy stream/event and per-phase timing counters.
class PSGPUWorker : public HogwildWorker {
 public:
  PSGPUWorker() {}
  ~PSGPUWorker() override {}
  void Initialize(const TrainerDesc& desc) override;
  void TrainFiles() override;
  void TrainFilesWithProfiler() override;
  void SetChannelWriter(ChannelObject<std::string>* queue) override;
  void SetWorkerNum(int num) override { worker_num_ = num; }
  // Replaces program_ with a copy of main_program. The existing ProgramDesc
  // is destroyed first. A bare placement new over a live object skips its
  // destructor and leaks whatever it already owns.
  // NOTE(review): this pattern is used because ProgramDesc appears not to be
  // copy-assignable; if it is, prefer `program_ = main_program;`.
  void CacheProgram(const ProgramDesc& main_program) override {
    program_.~ProgramDesc();
    new (&program_) ProgramDesc(main_program);
  }
  void ProduceTasks() override;
  void SetStream(const gpuStream_t stream) override { copy_stream_ = stream; }
  void SetEvent(const gpuEvent_t event) override { event_ = event; }
  void ResetStat();

 protected:
  void PushGradients();
  void CopySparseTable();
  void CopyDenseTable();
  void CopyDenseVars();

 private:
  int mpi_rank_ = 0;
  std::mutex mutex_;
  int worker_num_ = 0;
  ProgramDesc program_;
  HeterObjectPool<HeterTask> object_pool_;
  bool need_to_push_dense_ = false;
  bool dump_slot_ = false;
  bool need_to_push_sparse_ = false;
  DownpourWorkerParameter param_;
  float scale_datanorm_;
  // just save the value in param_ for easy access
  std::map<uint64_t, std::string> label_var_name_;
  std::map<uint64_t, std::vector<std::string>> sparse_key_names_;
  std::map<uint64_t, std::vector<std::string>> sparse_value_names_;
  std::map<uint64_t, std::vector<std::string>> sparse_grad_names_;
  std::map<uint64_t, std::vector<std::string>> dense_value_names_;
  std::map<uint64_t, std::vector<std::string>> dense_grad_names_;
  platform::Place root_place_;
  // actually pushed feasign of each table
  std::map<uint64_t, std::vector<uint64_t>> sparse_push_keys_;

  // skipped ops (hides HogwildWorker::skip_ops_)
  std::vector<std::string> skip_ops_;

  std::vector<::std::future<int32_t>> push_sparse_status_;
  std::vector<::std::future<int32_t>> push_dense_status_;

  // adjust ins weight
  AdjustInsWeightConfig adjust_ins_weight_config_;
  std::vector<float> nid_show_;
  // check nan and inf during training
  std::vector<std::string> check_nan_var_names_;
  // copy table
  CopyTableConfig copy_table_config_;
  std::map<uint64_t, uint64_t> table_dependency_;
  std::vector<std::pair<uint64_t, uint64_t>> copy_sparse_tables_;
  std::vector<std::pair<uint64_t, uint64_t>> copy_dense_tables_;
  std::unordered_map<uint64_t, std::unordered_set<uint64_t>> feasign_set_;
  paddle::framework::Channel<std::shared_ptr<HeterTask>> pull_queue_;
  paddle::framework::Channel<std::shared_ptr<HeterTask>> push_queue_;
  gpuEvent_t event_{};
  gpuStream_t copy_stream_{};
  int batch_cnt_{0};
  std::atomic<int> done_cnt_{0};

  // Timing / counting statistics; start at zero so accumulation is defined
  // even before ResetStat() is called.
  double total_time_ = 0.0;
  double read_time_ = 0.0;
  double pack_time_ = 0.0;
  double pull_sparse_local_time_ = 0.0;
  double op_all_time_ = 0.0;
  double xpu_op_time_ = 0.0;
  double xpu_wait_time_ = 0.0;
  double cpu_op_time_ = 0.0;
  double collect_label_time_ = 0.0;
  double fill_sparse_time_ = 0.0;
  double push_sparse_time_ = 0.0;
  double gpu_2_cpu_time_ = 0.0;
  double cpu_2_gpu_time_ = 0.0;
  uint64_t total_inst_ = 0;
};
#endif

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
    defined(PADDLE_WITH_ASCEND_CL)
// Worker for one pipeline stage (section). It runs forward, backward and
// update ops over a set of micro-batch scopes. schedule_mode_ selects
// F-then-B (RunFThenB) or 1F1B (Run1F1B) scheduling.
class SectionWorker : public DeviceWorker {
 public:
  SectionWorker() {}
  ~SectionWorker() override {}

  void Initialize(const TrainerDesc& desc) override;
  void PrepareUnusedVar();

  void BindingDataFeedMemory() override {}
  void CreateDeviceResource(const ProgramDesc& main_prog) override {}

  void TrainFiles() override;
  void TrainFilesWithProfiler() override {}

  void PrintFetchVars() override {}

  const platform::Place& place() const { return place_; }

  void SetDeviceIndex(int tid) override {}
  void SetThreadIndex(int thread_id) { thread_id_ = thread_id; }
  void SetMicrobatchNum(int num) { num_microbatches_ = num; }
  void SetPipelineStageNum(int num) { num_pipeline_stages_ = num; }
  void SetPipelineStage(int stage) { pipeline_stage_ = stage; }
  void SetScheduleMode(int mode) { schedule_mode_ = mode; }
  void SetMicrobatchScopes(const std::vector<Scope*>& scope) {
    microbatch_scopes_ = scope;
  }
  void SetMinibatchScope(const Scope* scope) { minibatch_scope_ = scope; }
  void SetSkipVars(const std::vector<std::string>& skip_vars) {
    skip_vars_ = skip_vars;
  }
  void RunBackward(
      int micro_id, std::unique_ptr<GarbageCollector>&,
      std::unordered_map<const OperatorBase*, std::vector<std::string>>&);
  void RunForward(
      int micro_id, std::unique_ptr<GarbageCollector>&,
      std::unordered_map<const OperatorBase*, std::vector<std::string>>&);
  void RunUpdate(
      std::unique_ptr<GarbageCollector>&,
      std::unordered_map<const OperatorBase*, std::vector<std::string>>&);
  void RunFThenB(std::unique_ptr<GarbageCollector>&);
  void Run1F1B(std::unique_ptr<GarbageCollector>&);

 protected:
  int section_id_ = 0;
  int thread_id_ = 0;
  int num_microbatches_ = 0;
  int num_pipeline_stages_ = 0;
  int pipeline_stage_ = 0;
  int schedule_mode_ = 0;  // 0 for F-then-B and 1 for 1F1B
  std::vector<Scope*> microbatch_scopes_;
  const Scope* minibatch_scope_ = nullptr;

  // skip&backward vars are only used in 1F1B
  std::vector<std::string> skip_vars_;
  std::vector<std::string> backward_send_vars_;

  std::vector<std::unique_ptr<OperatorBase>> ops_;
  std::vector<OperatorBase*> forward_and_lr_ops_;
  std::vector<OperatorBase*> forward_ops_;
  std::vector<OperatorBase*> backward_ops_;
  std::vector<OperatorBase*> optimizer_ops_;
  std::shared_ptr<framework::ProgramDesc> program_;
  std::unordered_map<const OperatorBase*, std::vector<std::string>>
      unused_vars_;
  static uint64_t batch_id_;

  // NOTE(review): this hides DeviceWorker::dev_ctx_. DeviceWorker's
  // SetDeviceContext() writes the base member, not this one.
  platform::DeviceContext* dev_ctx_ = nullptr;
};
#endif

#if defined(PADDLE_WITH_PSCORE)
// Pipeline-stage worker for heterogeneous PS training.
// - Runs forward/backward ops over shared micro-batch scopes.
// - Holds a shared BlockingQueue of (string, int) pairs (thread_queue_).
// NOTE(review): the queue is presumably used to hand work between threads —
// confirm in the .cc.
class HeterSectionWorker : public DeviceWorker {
 public:
  HeterSectionWorker() {}
  ~HeterSectionWorker() override {}

  void Initialize(const TrainerDesc& desc) override;
  void CreateDeviceResource(const ProgramDesc& main_prog) override {}

  void TrainFiles() override;
  void TrainFilesWithProfiler() override;

  void BindingDataFeedMemory() override {}
  void BindingDataFeedMemory(int micro_id);
  void PrintFetchVars() override;
  const platform::Place& place() const { return place_; }

  void SetDeviceIndex(int tid) override { thread_id_ = tid; }
  void SetThreadNum(int thread_num) { thread_num_ = thread_num; }
  void SetMicrobatchNum(int num) { num_microbatches_ = num; }
  void SetPipelineStageNum(int num) { num_pipeline_stages_ = num; }
  void SetPipelineStage(int stage) { pipeline_stage_ = stage; }
  std::shared_ptr<std::vector<Scope*>> GetMicrobatchScopes() {
    return microbatch_scopes_;
  }
  // Takes shared ownership of the scope list (sink parameter, moved in).
  void SetMicrobatchScopes(
      std::shared_ptr<std::vector<Scope*>> microbatch_scopes) {
    microbatch_scopes_ = std::move(microbatch_scopes);
  }
  using SHARED_THREAD_QUEUE = std::shared_ptr<
      ::paddle::framework::BlockingQueue<std::pair<std::string, int>>>;

  SHARED_THREAD_QUEUE GetThreadQueue() { return thread_queue_; }
  // Takes shared ownership of the queue (sink parameter, moved in).
  void SetThreadQueue(SHARED_THREAD_QUEUE thread_queue) {
    thread_queue_ = std::move(thread_queue);
  }
  void CopyParameters(int microbatch_id, const ProgramDesc& program,
                      const platform::Place& place);
  void SetMinibatchScope(Scope* scope) { minibatch_scope_ = scope; }
  void SetTrainerId(int trainer_id) { this->trainer_id_ = trainer_id; }
  void SetTrainers(int trainers) { this->trainers_ = trainers; }
  void CreateMicrobatchScopes();
  void RunForward(int micro_id);
  void RunBackward(int micro_id);
  void RunListen();
  void MiniBatchBarrier();
  void Run();
  void BatchPostProcess();
  void SetDebug(bool debug) { debug_ = debug; }
  Scope* GetThreadScope() override { return minibatch_scope_; }

  // multi-stream
  // #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  //  void SetStream(const gpuStream_t stream) override {}
  //  void SetEvent(const gpuEvent_t event) override {}
  // #endif

 protected:
  int trainer_id_ = 0;
  int trainers_ = 0;
  int thread_num_ = 0;
  int thread_id_ = 0;
  int num_microbatches_ = 0;
  int num_pipeline_stages_ = 0;
  int pipeline_stage_ = 0;
  bool epoch_finish_ = false;

  std::shared_ptr<std::vector<Scope*>> microbatch_scopes_;
  Scope* minibatch_scope_ = nullptr;
  std::vector<int> micro_ids_{};
  std::unique_ptr<OperatorBase> listen_op_{nullptr};
  std::vector<std::unique_ptr<OperatorBase>> forward_ops_;
  std::vector<std::unique_ptr<OperatorBase>> backward_ops_;
  std::shared_ptr<framework::ProgramDesc> program_;
  SHARED_THREAD_QUEUE thread_queue_;
  static uint64_t batch_id_;
  uint64_t total_ins_num_ = 0;
  // NOTE(review): this hides DeviceWorker::dev_ctx_. DeviceWorker's
  // SetDeviceContext() writes the base member, not this one.
  platform::DeviceContext* dev_ctx_ = nullptr;

  bool debug_ = false;
  std::vector<double> op_total_time_;
  std::vector<std::string> op_name_;
  platform::Timer timeline_;
  double total_time_ = 0.0;
  double read_time_ = 0.0;
};
#endif

}  // namespace framework
}  // namespace paddle