/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#ifdef PADDLE_WITH_HETERPS

#include <algorithm>
#include <atomic>
#include <cassert>
#include <ctime>
#include <iostream>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#ifdef PADDLE_WITH_GLOO
#include <gloo/broadcast.h>

#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
#include "paddle/fluid/distributed/ps/thirdparty/round_robin.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/fleet/heter_context.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h"
#include "paddle/fluid/framework/heter_util.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/fleet/heter_ps/mem_pool.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#endif
#ifdef PADDLE_WITH_XPU_KP
#include "paddle/fluid/platform/device/xpu/enforce_xpu.h"
#endif
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h"  // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_PSCORE
#include "paddle/fluid/distributed/ps/wrapper/fleet.h"
#endif
#ifdef PADDLE_WITH_PSLIB
#include "afs_api.h"
#endif
#ifdef PADDLE_WITH_PSLIB
#include "downpour_accessor.h"  // NOLINT
#endif

namespace paddle {
namespace framework {

// Rounds LEN up to a multiple of ALIGNVAL; the result is a uint64_t.
// Implemented as a bit mask, so it is only a true round-up when ALIGNVAL is a
// power of two (e.g. 8).
#define TYPEALIGN(ALIGNVAL, LEN)    \
  (~((uint64_t)((ALIGNVAL)-1)) &    \
   ((uint64_t)(LEN) + ((ALIGNVAL)-1)))

class Dataset;

#ifdef PADDLE_WITH_PSLIB
class AfsWrapper {
 public:
  AfsWrapper() {}
  virtual ~AfsWrapper() {}
76 77 78 79
  void init(const std::string& fs_name,
            const std::string& fs_user,
            const std::string& pass_wd,
            const std::string& conf);
T
Thunderbrook 已提交
80 81 82 83 84 85 86 87 88
  int remove(const std::string& path);
  int mkdir(const std::string& path);
  std::vector<std::string> list(const std::string& path);

  int exist(const std::string& path);
  int upload(const std::string& local_file, const std::string& afs_file);

  int download(const std::string& local_file, const std::string& afs_file);

89 90 91 92
  int touchz(const std::string& path);
  std::string cat(const std::string& path);
  int mv(const std::string& old_path, const std::string& dest_path);

T
Thunderbrook 已提交
93 94 95 96 97
 private:
  paddle::ps::AfsApiWrapper afs_handler_;
};
#endif

class PSGPUWrapper {
 public:
F
Fan Zhang 已提交
100
  virtual ~PSGPUWrapper();
T
Thunderbrook 已提交
101 102 103 104

  PSGPUWrapper() {
    HeterPs_ = NULL;
    sleep_seconds_before_fail_exit_ = 300;
Y
yaoxuefeng 已提交
105 106 107 108
    pull_thread_pool_.resize(thread_keys_shard_num_);
    for (size_t i = 0; i < pull_thread_pool_.size(); i++) {
      pull_thread_pool_[i].reset(new ::ThreadPool(1));
    }
T
Thunderbrook 已提交
109 110 111 112
    hbm_thread_pool_.resize(thread_keys_shard_num_);
    for (size_t i = 0; i < hbm_thread_pool_.size(); i++) {
      hbm_thread_pool_[i].reset(new ::ThreadPool(1));
    }
T
Thunderbrook 已提交
113 114
  }

115 116
  void PullSparse(const paddle::platform::Place& place,
                  const int table_id,
Y
yaoxuefeng 已提交
117 118 119
                  const std::vector<const uint64_t*>& keys,
                  const std::vector<float*>& values,
                  const std::vector<int64_t>& slot_lengths,
120 121 122 123
                  const std::vector<int>& slot_dim,
                  const int hidden_size);
  void PullSparse(const paddle::platform::Place& place,
                  const int table_id,
T
Thunderbrook 已提交
124 125 126 127
                  const std::vector<const uint64_t*>& keys,
                  const std::vector<float*>& values,
                  const std::vector<int64_t>& slot_lengths,
                  const int hidden_size);
128 129
  void PushSparseGrad(const paddle::platform::Place& place,
                      const int table_id,
T
Thunderbrook 已提交
130 131 132
                      const std::vector<const uint64_t*>& keys,
                      const std::vector<const float*>& grad_values,
                      const std::vector<int64_t>& slot_lengths,
133 134 135 136 137 138 139
                      const int hidden_size,
                      const int batch_size);
  void CopyKeys(const paddle::platform::Place& place,
                uint64_t** origin_keys,
                uint64_t* total_keys,
                const int64_t* gpu_len,
                int slot_num,
T
Thunderbrook 已提交
140
                int total_len);
141 142
  void CopyForPull(const paddle::platform::Place& place,
                   uint64_t** gpu_keys,
T
Thunderbrook 已提交
143
                   const std::vector<float*>& values,
144 145 146 147
                   const FeatureValue* total_values_gpu,
                   const int64_t* gpu_len,
                   const int slot_num,
                   const int hidden_size,
T
Thunderbrook 已提交
148
                   const int64_t total_length);
149 150
  void CopyForPull(const paddle::platform::Place& place,
                   uint64_t** gpu_keys,
Y
yaoxuefeng 已提交
151
                   const std::vector<float*>& values,
152 153 154 155 156 157
                   const FeatureValue* total_values_gpu,
                   const int64_t* gpu_len,
                   const int slot_num,
                   const int hidden_size,
                   const int64_t total_length,
                   int* gpu_dim);
T
Thunderbrook 已提交
158 159 160 161
  void CopyForPush(const paddle::platform::Place& place,
                   const std::vector<const float*>& grad_values,
                   FeaturePushValue* total_grad_values_gpu,
                   const std::vector<int64_t>& slot_lengths,
162 163
                   const int hidden_size,
                   const int64_t total_length,
T
Thunderbrook 已提交
164
                   const int batch_size);
Y
yaoxuefeng 已提交
165 166 167 168
  void CopyForPush(const paddle::platform::Place& place,
                   const std::vector<const float*>& grad_values,
                   FeaturePushValue* total_grad_values_gpu,
                   const std::vector<int64_t>& slot_lengths,
169 170
                   const uint64_t total_length,
                   const int batch_size,
Y
yaoxuefeng 已提交
171
                   size_t grad_value_size);
T
Thunderbrook 已提交
172

173
  void BuildGPUTask(std::shared_ptr<HeterContext> gpu_task);
174 175
  void PreBuildTask(std::shared_ptr<HeterContext> gpu_task);
  void BuildPull(std::shared_ptr<HeterContext> gpu_task);
176 177 178 179
  void LoadIntoMemory(bool is_shuffle);
  void BeginPass();
  void EndPass();
  void start_build_thread();
180
  void pre_build_thread();
181
  void build_task();
182 183 184 185 186 187 188 189 190 191

  void Finalize() {
    VLOG(3) << "PSGPUWrapper Begin Finalize.";
    if (s_instance_ == nullptr) {
      return;
    }
    data_ready_channel_->Close();
    buildcpu_ready_channel_->Close();
    gpu_free_channel_->Close();
    running_ = false;
192 193
    VLOG(3) << "begin stop pre_build_threads_";
    pre_build_threads_.join();
194 195 196 197
    s_instance_ = nullptr;
    VLOG(3) << "PSGPUWrapper Finalize Finished.";
  }

T
Thunderbrook 已提交
198
  void InitializeGPU(const std::vector<int>& dev_ids) {
199
    if (s_instance_ != NULL && is_initialized_ == false) {
T
Thunderbrook 已提交
200
      VLOG(3) << "PSGPUWrapper Begin InitializeGPU";
201
      is_initialized_ = true;
T
Thunderbrook 已提交
202 203
      resource_ = std::make_shared<HeterPsResource>(dev_ids);
      resource_->enable_p2p();
204
      keys_tensor.resize(resource_->total_device());
Y
yaoxuefeng 已提交
205 206 207 208 209 210 211 212 213
#ifdef PADDLE_WITH_GLOO
      auto gloo = paddle::framework::GlooWrapper::GetInstance();
      if (gloo->Size() > 1) {
        multi_node_ = 1;
      }
#else
      PADDLE_THROW(
          platform::errors::Unavailable("heter ps need compile with GLOO"));
#endif
F
Fan Zhang 已提交
214
#ifdef PADDLE_WITH_CUDA
215 216 217 218 219
      if (multi_node_) {
        int dev_size = dev_ids.size();
        // init inner comm
        inner_comms_.resize(dev_size);
        inter_ncclids_.resize(dev_size);
220 221
        platform::dynload::ncclCommInitAll(
            &(inner_comms_[0]), dev_size, &dev_ids[0]);
222 223 224 225 226 227 228 229 230 231
// init inter comm
#ifdef PADDLE_WITH_GLOO
        inter_comms_.resize(dev_size);
        if (gloo->Rank() == 0) {
          for (int i = 0; i < dev_size; ++i) {
            platform::dynload::ncclGetUniqueId(&inter_ncclids_[i]);
          }
        }

        PADDLE_ENFORCE_EQ(
232 233
            gloo->IsInitialized(),
            true,
234 235 236 237 238 239 240 241
            platform::errors::PreconditionNotMet(
                "You must initialize the gloo environment first to use it."));
        gloo::BroadcastOptions opts(gloo->GetContext());
        opts.setOutput(&inter_ncclids_[0], dev_size);
        opts.setRoot(0);
        gloo::broadcast(opts);

        for (int i = 0; i < dev_size; ++i) {
242 243
          platform::dynload::ncclCommInitRank(
              &inter_comms_[i], gloo->Size(), inter_ncclids_[i], gloo->Rank());
244 245 246 247 248 249 250
        }
        node_size_ = gloo->Size();
#else
        PADDLE_THROW(
            platform::errors::Unavailable("heter ps need compile with GLOO"));
#endif
      }
F
Fan Zhang 已提交
251
#endif
Y
yaoxuefeng 已提交
252
      heter_devices_ = dev_ids;
253 254 255 256 257 258 259 260 261 262 263
      data_ready_channel_->Open();
      data_ready_channel_->SetCapacity(3);
      buildcpu_ready_channel_->Open();
      buildcpu_ready_channel_->SetCapacity(3);
      gpu_free_channel_->Open();
      gpu_free_channel_->SetCapacity(1);

      current_task_ = nullptr;
      gpu_free_channel_->Put(current_task_);

      table_id_ = 0;
264

265 266
      // start build cpu&gpu ps thread
      start_build_thread();
T
Thunderbrook 已提交
267 268
    }
  }
Y
yaoxuefeng 已提交
269

270 271 272 273 274 275
  void SetSparseSGD(float nonclk_coeff,
                    float clk_coeff,
                    float min_bound,
                    float max_bound,
                    float learning_rate,
                    float initial_g2sum,
Y
yaoxuefeng 已提交
276
                    float initial_range);
277 278 279 280 281 282
  void SetEmbedxSGD(float mf_create_thresholds,
                    float mf_learning_rate,
                    float mf_initial_g2sum,
                    float mf_initial_range,
                    float mf_min_bound,
                    float mf_max_bound);
Y
yaoxuefeng 已提交
283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325
  void InitializeGPUServer(std::unordered_map<std::string, float> config) {
    float nonclk_coeff = (config.find("nonclk_coeff") == config.end())
                             ? 1.0
                             : config["nonclk_coeff"];
    float clk_coeff =
        (config.find("clk_coeff") == config.end()) ? 1.0 : config["clk_coeff"];
    float min_bound = (config.find("min_bound") == config.end())
                          ? -10000.0
                          : config["min_bound"];
    float max_bound = (config.find("max_bound") == config.end())
                          ? 10000.0
                          : config["max_bound"];
    float learning_rate = (config.find("learning_rate") == config.end())
                              ? 1.0
                              : config["learning_rate"];
    float initial_g2sum = (config.find("initial_g2sum") == config.end())
                              ? 1.0
                              : config["initial_g2sum"];
    float initial_range = (config.find("initial_range") == config.end())
                              ? 1.0
                              : config["initial_range"];

    // mf config settings
    float mf_create_thresholds =
        (config.find("mf_create_thresholds") == config.end())
            ? static_cast<float>(1.0)
            : config["mf_create_thresholds"];
    float mf_learning_rate = (config.find("mf_learning_rate") == config.end())
                                 ? 1.0
                                 : config["mf_learning_rate"];
    float mf_initial_g2sum = (config.find("mf_initial_g2sum") == config.end())
                                 ? 1.0
                                 : config["mf_initial_g2sum"];
    float mf_initial_range = (config.find("mf_initial_range") == config.end())
                                 ? 1.0
                                 : config["mf_initial_range"];
    float mf_min_bound = (config.find("mf_min_bound") == config.end())
                             ? 1.0
                             : config["mf_min_bound"];
    float mf_max_bound = (config.find("mf_max_bound") == config.end())
                             ? 1.0
                             : config["mf_max_bound"];
    for (size_t i = 0; i < heter_devices_.size(); i++) {
F
Fan Zhang 已提交
326
#ifdef PADDLE_WITH_CUDA
327
      PADDLE_ENFORCE_GPU_SUCCESS(cudaSetDevice(heter_devices_[i]));
F
Fan Zhang 已提交
328 329 330
#elif defined(PADDLE_WITH_XPU_KP)
      PADDLE_ENFORCE_XPU_SUCCESS(xpu_set_device(heter_devices_[i]));
#endif
331 332 333 334 335 336 337 338 339 340 341 342
      this->SetSparseSGD(nonclk_coeff,
                         clk_coeff,
                         min_bound,
                         max_bound,
                         learning_rate,
                         initial_g2sum,
                         initial_range);
      this->SetEmbedxSGD(mf_create_thresholds,
                         mf_learning_rate,
                         mf_initial_g2sum,
                         mf_initial_range,
                         mf_min_bound,
Y
yaoxuefeng 已提交
343 344 345
                         mf_max_bound);
    }
  }
F
Fan Zhang 已提交
346

347 348 349 350 351 352
  void SetDate(int year, int month, int day) {
    year_ = year;
    month_ = month;
    day_ = day;
  }

Y
yaoxuefeng 已提交
353 354
  void SetDataset(Dataset* dataset) { dataset_ = dataset; }

T
Thunderbrook 已提交
355 356 357 358 359 360 361 362 363 364 365 366 367 368 369
  // PSGPUWrapper singleton
  static std::shared_ptr<PSGPUWrapper> GetInstance() {
    if (NULL == s_instance_) {
      s_instance_.reset(new paddle::framework::PSGPUWrapper());
    }
    return s_instance_;
  }
  std::vector<std::unordered_map<uint64_t, std::vector<float>>>& GetLocalTable(
      int table_id) {
    return local_tables_[table_id];
  }
  void SetSlotVector(const std::vector<int>& slot_vector) {
    slot_vector_ = slot_vector;
  }

Y
yaoxuefeng 已提交
370 371
  void SetSlotOffsetVector(const std::vector<int>& slot_offset_vector) {
    slot_offset_vector_ = slot_offset_vector;
Y
yaoxuefeng 已提交
372 373 374 375 376
    std::cout << "yxf set: ";
    for (auto s : slot_offset_vector_) {
      std::cout << s << " | ";
    }
    std::cout << " end " << std::endl;
Y
yaoxuefeng 已提交
377 378
  }

F
Fan Zhang 已提交
379
#ifdef PADDLE_WITH_CUDA
Y
yaoxuefeng 已提交
380 381 382
  void SetSlotDimVector(const std::vector<int>& slot_mf_dim_vector) {
    slot_mf_dim_vector_ = slot_mf_dim_vector;
    assert(slot_mf_dim_vector_.size() == slot_vector_.size());
Y
yaoxuefeng 已提交
383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405
  }

  void InitSlotInfo() {
    if (slot_info_initialized_) {
      return;
    }
    SlotRecordDataset* dataset = dynamic_cast<SlotRecordDataset*>(dataset_);
    auto slots_vec = dataset->GetSlots();
    slot_offset_vector_.clear();
    for (auto& slot : slot_vector_) {
      for (size_t i = 0; i < slots_vec.size(); ++i) {
        if (std::to_string(slot) == slots_vec[i]) {
          slot_offset_vector_.push_back(i);
          break;
        }
      }
    }
    std::cout << "psgpu wrapper use slots: ";
    for (auto s : slot_offset_vector_) {
      std::cout << s << " | ";
    }
    std::cout << " end " << std::endl;
    for (size_t i = 0; i < slot_mf_dim_vector_.size(); i++) {
Y
yaoxuefeng 已提交
406 407 408 409 410 411 412 413 414 415 416 417 418 419 420
      slot_dim_map_[slot_vector_[i]] = slot_mf_dim_vector_[i];
    }

    std::unordered_set<int> dims_set;
    for (auto& it : slot_dim_map_) {
      dims_set.insert(it.second);
    }
    size_t num_of_dim = dims_set.size();
    index_dim_vec_.resize(num_of_dim);
    index_dim_vec_.assign(dims_set.begin(), dims_set.end());
    std::sort(index_dim_vec_.begin(), index_dim_vec_.end());
    std::unordered_map<int, int> dim_index_map;
    for (size_t i = 0; i < num_of_dim; i++) {
      dim_index_map[index_dim_vec_[i]] = i;
    }
421 422
    hbm_pools_.resize(resource_->total_device() * num_of_dim);
    mem_pools_.resize(resource_->total_device() * num_of_dim);
Y
yaoxuefeng 已提交
423 424 425 426 427 428 429 430 431 432 433
    max_mf_dim_ = index_dim_vec_.back();
    multi_mf_dim_ = (dim_index_map.size() >= 1) ? dim_index_map.size() : 0;
    resource_->set_multi_mf(multi_mf_dim_, max_mf_dim_);
    slot_index_vec_.resize(slot_mf_dim_vector_.size());
    for (size_t i = 0; i < slot_index_vec_.size(); i++) {
      slot_index_vec_[i] = dim_index_map[slot_mf_dim_vector_[i]];
    }
    val_type_size_ =
        TYPEALIGN(8, sizeof(FeatureValue) + sizeof(float) * (max_mf_dim_ + 1));
    grad_type_size_ =
        TYPEALIGN(8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float)));
Y
yaoxuefeng 已提交
434
    slot_info_initialized_ = true;
Y
yaoxuefeng 已提交
435
  }
F
Fan Zhang 已提交
436
#endif
Y
yaoxuefeng 已提交
437

T
Thunderbrook 已提交
438 439
  void ShowOneTable(int index) { HeterPs_->show_one_table(index); }

T
Thunderbrook 已提交
440 441 442 443 444 445 446 447
  int UseAfsApi() { return use_afs_api_; }

#ifdef PADDLE_WITH_PSLIB
  std::shared_ptr<paddle::ps::AfsReader> OpenReader(
      const std::string& filename) {
    return afs_handler_.open_reader(filename);
  }

448 449 450 451
  void InitAfsApi(const std::string& fs_name,
                  const std::string& fs_user,
                  const std::string& pass_wd,
                  const std::string& conf);
T
Thunderbrook 已提交
452 453
#endif

T
Thunderbrook 已提交
454 455
 private:
  static std::shared_ptr<PSGPUWrapper> s_instance_;
Y
yaoxuefeng 已提交
456
  Dataset* dataset_;
T
Thunderbrook 已提交
457 458 459
#ifdef PADDLE_WITH_PSLIB
  paddle::ps::AfsApiWrapper afs_handler_;
#endif
T
Thunderbrook 已提交
460
  std::unordered_map<
461 462
      uint64_t,
      std::vector<std::unordered_map<uint64_t, std::vector<float>>>>
T
Thunderbrook 已提交
463 464 465 466 467 468
      local_tables_;
  HeterPsBase* HeterPs_;
  std::vector<LoDTensor> keys_tensor;  // Cache for pull_sparse
  std::shared_ptr<HeterPsResource> resource_;
  int32_t sleep_seconds_before_fail_exit_;
  std::vector<int> slot_vector_;
Y
yaoxuefeng 已提交
469 470 471 472 473 474 475 476 477
  std::vector<int> slot_offset_vector_;
  std::vector<int> slot_mf_dim_vector_;
  std::unordered_map<int, int> slot_dim_map_;
  std::vector<int> slot_index_vec_;
  std::vector<int> index_dim_vec_;
  int multi_mf_dim_{0};
  int max_mf_dim_{0};
  size_t val_type_size_{0};
  size_t grad_type_size_{0};
Y
yaoxuefeng 已提交
478 479 480 481 482 483

  double time_1 = 0.0;
  double time_2 = 0.0;
  double time_3 = 0.0;
  double time_4 = 0.0;

T
Thunderbrook 已提交
484
  int multi_node_{0};
485
  int node_size_;
486
  uint64_t table_id_;
F
Fan Zhang 已提交
487
#ifdef PADDLE_WITH_CUDA
488 489 490
  std::vector<ncclComm_t> inner_comms_;
  std::vector<ncclComm_t> inter_comms_;
  std::vector<ncclUniqueId> inter_ncclids_;
F
Fan Zhang 已提交
491
#endif
Y
yaoxuefeng 已提交
492 493 494
  std::vector<int> heter_devices_;
  std::unordered_set<std::string> gpu_ps_config_keys_;
  HeterObjectPool<HeterContext> gpu_task_pool_;
495
  std::vector<std::vector<robin_hood::unordered_set<uint64_t>>> thread_keys_;
496 497
  std::vector<std::vector<std::vector<robin_hood::unordered_set<uint64_t>>>>
      thread_dim_keys_;
Y
yaoxuefeng 已提交
498 499 500
  int thread_keys_thread_num_ = 37;
  int thread_keys_shard_num_ = 37;
  uint64_t max_fea_num_per_pass_ = 5000000000;
501 502 503
  int year_;
  int month_;
  int day_;
Y
yaoxuefeng 已提交
504
  bool slot_info_initialized_ = false;
T
Thunderbrook 已提交
505
  int use_afs_api_ = 0;
T
Thunderbrook 已提交
506

F
Fan Zhang 已提交
507
#ifdef PADDLE_WITH_CUDA
Y
yaoxuefeng 已提交
508 509 510
  std::vector<MemoryPool*> mem_pools_;
  std::vector<HBMMemoryPool*> hbm_pools_;  // in multi mfdim, one table need hbm
                                           // pools of totol dims number
F
Fan Zhang 已提交
511
#endif
Y
yaoxuefeng 已提交
512

513 514 515 516 517 518 519 520 521 522 523 524 525
  std::shared_ptr<
      paddle::framework::ChannelObject<std::shared_ptr<HeterContext>>>
      data_ready_channel_ =
          paddle::framework::MakeChannel<std::shared_ptr<HeterContext>>();
  std::shared_ptr<
      paddle::framework::ChannelObject<std::shared_ptr<HeterContext>>>
      buildcpu_ready_channel_ =
          paddle::framework::MakeChannel<std::shared_ptr<HeterContext>>();
  std::shared_ptr<
      paddle::framework::ChannelObject<std::shared_ptr<HeterContext>>>
      gpu_free_channel_ =
          paddle::framework::MakeChannel<std::shared_ptr<HeterContext>>();
  std::shared_ptr<HeterContext> current_task_ = nullptr;
526
  std::thread pre_build_threads_;
527
  bool running_ = false;
Y
yaoxuefeng 已提交
528
  std::vector<std::shared_ptr<ThreadPool>> pull_thread_pool_;
T
Thunderbrook 已提交
529
  std::vector<std::shared_ptr<ThreadPool>> hbm_thread_pool_;
530

T
Thunderbrook 已提交
531 532 533 534 535 536 537
 protected:
  static bool is_initialized_;
};

}  // end namespace framework
}  // end namespace paddle
#endif