/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#ifdef PADDLE_WITH_HETERPS

#include <algorithm>
#include <atomic>
#include <cassert>
#include <ctime>
#include <iostream>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#ifdef PADDLE_WITH_GLOO
#include <gloo/broadcast.h>

#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
#include "paddle/fluid/distributed/ps/thirdparty/round_robin.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/fleet/heter_context.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h"
#include "paddle/fluid/framework/heter_util.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/fleet/heter_ps/mem_pool.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#endif
#ifdef PADDLE_WITH_XPU_KP
#include "paddle/fluid/platform/device/xpu/enforce_xpu.h"
#endif
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h"  // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_PSCORE
#include "paddle/fluid/distributed/ps/wrapper/fleet.h"
#endif
#ifdef PADDLE_WITH_PSLIB
#include "afs_api.h"
#include "downpour_accessor.h"  // NOLINT
#endif
#include "paddle/fluid/framework/fleet/heter_ps/log_patch.h"

namespace paddle {
namespace framework {

// Rounds LEN up to the next multiple of ALIGNVAL; all arithmetic is done in
// uint64_t (wrapping modulo 2^64). The bit-mask form is only correct when
// ALIGNVAL is a power of two, e.g. TYPEALIGN(8, 13) == 16.
#define TYPEALIGN(ALIGNVAL, LEN)                                          \
  ((static_cast<uint64_t>(LEN) + static_cast<uint64_t>(ALIGNVAL) - 1) & \
   ~(static_cast<uint64_t>(ALIGNVAL) - 1))

class Dataset;

#ifdef PADDLE_WITH_PSLIB
// Thin wrapper that owns a paddle::ps::AfsApiWrapper handle (PSLIB builds
// only). All file-system operations are declared here and defined out of
// line; their return-code conventions are not visible in this header.
class AfsWrapper {
 public:
  AfsWrapper() = default;
  virtual ~AfsWrapper() = default;

  // Configures the underlying AFS client (file system name, user, password
  // and config path).
  void init(const std::string& fs_name,
            const std::string& fs_user,
            const std::string& pass_wd,
            const std::string& conf);

  // Path-level operations on the remote file system.
  int remove(const std::string& path);
  int mkdir(const std::string& path);
  std::vector<std::string> list(const std::string& path);

  int exist(const std::string& path);
  int upload(const std::string& local_file, const std::string& afs_file);

  int download(const std::string& local_file, const std::string& afs_file);

  int touchz(const std::string& path);
  std::string cat(const std::string& path);
  int mv(const std::string& old_path, const std::string& dest_path);

 private:
  paddle::ps::AfsApiWrapper afs_handler_;
};
#endif

class PSGPUWrapper {
 public:
F
Fan Zhang 已提交
101
  virtual ~PSGPUWrapper();
T
Thunderbrook 已提交
102 103 104 105

  PSGPUWrapper() {
    HeterPs_ = NULL;
    sleep_seconds_before_fail_exit_ = 300;
Y
yaoxuefeng 已提交
106 107 108 109
    pull_thread_pool_.resize(thread_keys_shard_num_);
    for (size_t i = 0; i < pull_thread_pool_.size(); i++) {
      pull_thread_pool_[i].reset(new ::ThreadPool(1));
    }
T
Thunderbrook 已提交
110 111 112 113
    hbm_thread_pool_.resize(thread_keys_shard_num_);
    for (size_t i = 0; i < hbm_thread_pool_.size(); i++) {
      hbm_thread_pool_[i].reset(new ::ThreadPool(1));
    }
T
Thunderbrook 已提交
114 115
  }

116 117
  void PullSparse(const paddle::platform::Place& place,
                  const int table_id,
Y
yaoxuefeng 已提交
118 119 120
                  const std::vector<const uint64_t*>& keys,
                  const std::vector<float*>& values,
                  const std::vector<int64_t>& slot_lengths,
121 122 123 124
                  const std::vector<int>& slot_dim,
                  const int hidden_size);
  void PullSparse(const paddle::platform::Place& place,
                  const int table_id,
T
Thunderbrook 已提交
125 126 127 128
                  const std::vector<const uint64_t*>& keys,
                  const std::vector<float*>& values,
                  const std::vector<int64_t>& slot_lengths,
                  const int hidden_size);
129 130
  void PushSparseGrad(const paddle::platform::Place& place,
                      const int table_id,
T
Thunderbrook 已提交
131 132 133
                      const std::vector<const uint64_t*>& keys,
                      const std::vector<const float*>& grad_values,
                      const std::vector<int64_t>& slot_lengths,
134 135 136 137 138 139 140
                      const int hidden_size,
                      const int batch_size);
  void CopyKeys(const paddle::platform::Place& place,
                uint64_t** origin_keys,
                uint64_t* total_keys,
                const int64_t* gpu_len,
                int slot_num,
T
Thunderbrook 已提交
141
                int total_len);
142 143
  void CopyForPull(const paddle::platform::Place& place,
                   uint64_t** gpu_keys,
T
Thunderbrook 已提交
144
                   const std::vector<float*>& values,
145 146 147 148
                   const FeatureValue* total_values_gpu,
                   const int64_t* gpu_len,
                   const int slot_num,
                   const int hidden_size,
T
Thunderbrook 已提交
149
                   const int64_t total_length);
150 151
  void CopyForPull(const paddle::platform::Place& place,
                   uint64_t** gpu_keys,
Y
yaoxuefeng 已提交
152
                   const std::vector<float*>& values,
153 154 155 156 157 158
                   const FeatureValue* total_values_gpu,
                   const int64_t* gpu_len,
                   const int slot_num,
                   const int hidden_size,
                   const int64_t total_length,
                   int* gpu_dim);
T
Thunderbrook 已提交
159 160 161 162
  void CopyForPush(const paddle::platform::Place& place,
                   const std::vector<const float*>& grad_values,
                   FeaturePushValue* total_grad_values_gpu,
                   const std::vector<int64_t>& slot_lengths,
163 164
                   const int hidden_size,
                   const int64_t total_length,
T
Thunderbrook 已提交
165
                   const int batch_size);
Y
yaoxuefeng 已提交
166 167 168 169
  void CopyForPush(const paddle::platform::Place& place,
                   const std::vector<const float*>& grad_values,
                   FeaturePushValue* total_grad_values_gpu,
                   const std::vector<int64_t>& slot_lengths,
170 171
                   const uint64_t total_length,
                   const int batch_size,
Y
yaoxuefeng 已提交
172
                   size_t grad_value_size);
T
Thunderbrook 已提交
173

174
  void BuildGPUTask(std::shared_ptr<HeterContext> gpu_task);
175 176
  void PreBuildTask(std::shared_ptr<HeterContext> gpu_task);
  void BuildPull(std::shared_ptr<HeterContext> gpu_task);
177 178 179 180
  void LoadIntoMemory(bool is_shuffle);
  void BeginPass();
  void EndPass();
  void start_build_thread();
181
  void pre_build_thread();
182
  void build_task();
183 184 185 186 187 188 189 190 191 192

  void Finalize() {
    VLOG(3) << "PSGPUWrapper Begin Finalize.";
    if (s_instance_ == nullptr) {
      return;
    }
    data_ready_channel_->Close();
    buildcpu_ready_channel_->Close();
    gpu_free_channel_->Close();
    running_ = false;
193 194
    VLOG(3) << "begin stop pre_build_threads_";
    pre_build_threads_.join();
195 196 197 198
    s_instance_ = nullptr;
    VLOG(3) << "PSGPUWrapper Finalize Finished.";
  }

T
Thunderbrook 已提交
199
  void InitializeGPU(const std::vector<int>& dev_ids) {
200
    if (s_instance_ != NULL && is_initialized_ == false) {
T
Thunderbrook 已提交
201
      VLOG(3) << "PSGPUWrapper Begin InitializeGPU";
202
      is_initialized_ = true;
T
Thunderbrook 已提交
203 204
      resource_ = std::make_shared<HeterPsResource>(dev_ids);
      resource_->enable_p2p();
205
      keys_tensor.resize(resource_->total_device());
Y
yaoxuefeng 已提交
206 207 208 209 210 211 212 213 214
#ifdef PADDLE_WITH_GLOO
      auto gloo = paddle::framework::GlooWrapper::GetInstance();
      if (gloo->Size() > 1) {
        multi_node_ = 1;
      }
#else
      PADDLE_THROW(
          platform::errors::Unavailable("heter ps need compile with GLOO"));
#endif
F
Fan Zhang 已提交
215
#ifdef PADDLE_WITH_CUDA
216 217 218 219 220
      if (multi_node_) {
        int dev_size = dev_ids.size();
        // init inner comm
        inner_comms_.resize(dev_size);
        inter_ncclids_.resize(dev_size);
221 222
        platform::dynload::ncclCommInitAll(
            &(inner_comms_[0]), dev_size, &dev_ids[0]);
223 224 225 226 227 228 229 230 231 232
// init inter comm
#ifdef PADDLE_WITH_GLOO
        inter_comms_.resize(dev_size);
        if (gloo->Rank() == 0) {
          for (int i = 0; i < dev_size; ++i) {
            platform::dynload::ncclGetUniqueId(&inter_ncclids_[i]);
          }
        }

        PADDLE_ENFORCE_EQ(
233 234
            gloo->IsInitialized(),
            true,
235 236 237 238 239 240 241 242
            platform::errors::PreconditionNotMet(
                "You must initialize the gloo environment first to use it."));
        gloo::BroadcastOptions opts(gloo->GetContext());
        opts.setOutput(&inter_ncclids_[0], dev_size);
        opts.setRoot(0);
        gloo::broadcast(opts);

        for (int i = 0; i < dev_size; ++i) {
243 244
          platform::dynload::ncclCommInitRank(
              &inter_comms_[i], gloo->Size(), inter_ncclids_[i], gloo->Rank());
245 246 247 248 249 250 251
        }
        node_size_ = gloo->Size();
#else
        PADDLE_THROW(
            platform::errors::Unavailable("heter ps need compile with GLOO"));
#endif
      }
F
Fan Zhang 已提交
252
#endif
Y
yaoxuefeng 已提交
253
      heter_devices_ = dev_ids;
254 255 256 257 258 259 260 261 262 263 264
      data_ready_channel_->Open();
      data_ready_channel_->SetCapacity(3);
      buildcpu_ready_channel_->Open();
      buildcpu_ready_channel_->SetCapacity(3);
      gpu_free_channel_->Open();
      gpu_free_channel_->SetCapacity(1);

      current_task_ = nullptr;
      gpu_free_channel_->Put(current_task_);

      table_id_ = 0;
265

266 267
      // start build cpu&gpu ps thread
      start_build_thread();
T
Thunderbrook 已提交
268 269
    }
  }
Y
yaoxuefeng 已提交
270

271 272 273 274 275 276
  void SetSparseSGD(float nonclk_coeff,
                    float clk_coeff,
                    float min_bound,
                    float max_bound,
                    float learning_rate,
                    float initial_g2sum,
Y
yaoxuefeng 已提交
277
                    float initial_range);
278 279 280 281 282 283
  void SetEmbedxSGD(float mf_create_thresholds,
                    float mf_learning_rate,
                    float mf_initial_g2sum,
                    float mf_initial_range,
                    float mf_min_bound,
                    float mf_max_bound);
Y
yaoxuefeng 已提交
284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325
  void InitializeGPUServer(std::unordered_map<std::string, float> config) {
    float nonclk_coeff = (config.find("nonclk_coeff") == config.end())
                             ? 1.0
                             : config["nonclk_coeff"];
    float clk_coeff =
        (config.find("clk_coeff") == config.end()) ? 1.0 : config["clk_coeff"];
    float min_bound = (config.find("min_bound") == config.end())
                          ? -10000.0
                          : config["min_bound"];
    float max_bound = (config.find("max_bound") == config.end())
                          ? 10000.0
                          : config["max_bound"];
    float learning_rate = (config.find("learning_rate") == config.end())
                              ? 1.0
                              : config["learning_rate"];
    float initial_g2sum = (config.find("initial_g2sum") == config.end())
                              ? 1.0
                              : config["initial_g2sum"];
    float initial_range = (config.find("initial_range") == config.end())
                              ? 1.0
                              : config["initial_range"];

    // mf config settings
    float mf_create_thresholds =
        (config.find("mf_create_thresholds") == config.end())
            ? static_cast<float>(1.0)
            : config["mf_create_thresholds"];
    float mf_learning_rate = (config.find("mf_learning_rate") == config.end())
                                 ? 1.0
                                 : config["mf_learning_rate"];
    float mf_initial_g2sum = (config.find("mf_initial_g2sum") == config.end())
                                 ? 1.0
                                 : config["mf_initial_g2sum"];
    float mf_initial_range = (config.find("mf_initial_range") == config.end())
                                 ? 1.0
                                 : config["mf_initial_range"];
    float mf_min_bound = (config.find("mf_min_bound") == config.end())
                             ? 1.0
                             : config["mf_min_bound"];
    float mf_max_bound = (config.find("mf_max_bound") == config.end())
                             ? 1.0
                             : config["mf_max_bound"];
326 327 328 329 330 331 332 333 334 335 336 337 338
    this->SetSparseSGD(nonclk_coeff,
                       clk_coeff,
                       min_bound,
                       max_bound,
                       learning_rate,
                       initial_g2sum,
                       initial_range);
    this->SetEmbedxSGD(mf_create_thresholds,
                       mf_learning_rate,
                       mf_initial_g2sum,
                       mf_initial_range,
                       mf_min_bound,
                       mf_max_bound);
Y
yaoxuefeng 已提交
339
  }
F
Fan Zhang 已提交
340

341 342 343 344 345 346
  void SetDate(int year, int month, int day) {
    year_ = year;
    month_ = month;
    day_ = day;
  }

Y
yaoxuefeng 已提交
347 348
  void SetDataset(Dataset* dataset) { dataset_ = dataset; }

T
Thunderbrook 已提交
349 350 351 352 353 354 355 356 357 358 359 360 361 362 363
  // PSGPUWrapper singleton
  static std::shared_ptr<PSGPUWrapper> GetInstance() {
    if (NULL == s_instance_) {
      s_instance_.reset(new paddle::framework::PSGPUWrapper());
    }
    return s_instance_;
  }
  std::vector<std::unordered_map<uint64_t, std::vector<float>>>& GetLocalTable(
      int table_id) {
    return local_tables_[table_id];
  }
  void SetSlotVector(const std::vector<int>& slot_vector) {
    slot_vector_ = slot_vector;
  }

Y
yaoxuefeng 已提交
364 365
  void SetSlotOffsetVector(const std::vector<int>& slot_offset_vector) {
    slot_offset_vector_ = slot_offset_vector;
Y
yaoxuefeng 已提交
366 367 368 369 370
    std::cout << "yxf set: ";
    for (auto s : slot_offset_vector_) {
      std::cout << s << " | ";
    }
    std::cout << " end " << std::endl;
Y
yaoxuefeng 已提交
371 372
  }

F
Fan Zhang 已提交
373
#ifdef PADDLE_WITH_CUDA
Y
yaoxuefeng 已提交
374 375 376
  void SetSlotDimVector(const std::vector<int>& slot_mf_dim_vector) {
    slot_mf_dim_vector_ = slot_mf_dim_vector;
    assert(slot_mf_dim_vector_.size() == slot_vector_.size());
Y
yaoxuefeng 已提交
377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399
  }

  void InitSlotInfo() {
    if (slot_info_initialized_) {
      return;
    }
    SlotRecordDataset* dataset = dynamic_cast<SlotRecordDataset*>(dataset_);
    auto slots_vec = dataset->GetSlots();
    slot_offset_vector_.clear();
    for (auto& slot : slot_vector_) {
      for (size_t i = 0; i < slots_vec.size(); ++i) {
        if (std::to_string(slot) == slots_vec[i]) {
          slot_offset_vector_.push_back(i);
          break;
        }
      }
    }
    std::cout << "psgpu wrapper use slots: ";
    for (auto s : slot_offset_vector_) {
      std::cout << s << " | ";
    }
    std::cout << " end " << std::endl;
    for (size_t i = 0; i < slot_mf_dim_vector_.size(); i++) {
Y
yaoxuefeng 已提交
400 401 402 403 404 405 406 407 408 409 410 411 412 413 414
      slot_dim_map_[slot_vector_[i]] = slot_mf_dim_vector_[i];
    }

    std::unordered_set<int> dims_set;
    for (auto& it : slot_dim_map_) {
      dims_set.insert(it.second);
    }
    size_t num_of_dim = dims_set.size();
    index_dim_vec_.resize(num_of_dim);
    index_dim_vec_.assign(dims_set.begin(), dims_set.end());
    std::sort(index_dim_vec_.begin(), index_dim_vec_.end());
    std::unordered_map<int, int> dim_index_map;
    for (size_t i = 0; i < num_of_dim; i++) {
      dim_index_map[index_dim_vec_[i]] = i;
    }
415 416
    hbm_pools_.resize(resource_->total_device() * num_of_dim);
    mem_pools_.resize(resource_->total_device() * num_of_dim);
Y
yaoxuefeng 已提交
417 418 419 420 421 422 423 424 425 426 427
    max_mf_dim_ = index_dim_vec_.back();
    multi_mf_dim_ = (dim_index_map.size() >= 1) ? dim_index_map.size() : 0;
    resource_->set_multi_mf(multi_mf_dim_, max_mf_dim_);
    slot_index_vec_.resize(slot_mf_dim_vector_.size());
    for (size_t i = 0; i < slot_index_vec_.size(); i++) {
      slot_index_vec_[i] = dim_index_map[slot_mf_dim_vector_[i]];
    }
    val_type_size_ =
        TYPEALIGN(8, sizeof(FeatureValue) + sizeof(float) * (max_mf_dim_ + 1));
    grad_type_size_ =
        TYPEALIGN(8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float)));
Y
yaoxuefeng 已提交
428
    slot_info_initialized_ = true;
Y
yaoxuefeng 已提交
429
  }
F
Fan Zhang 已提交
430
#endif
Y
yaoxuefeng 已提交
431

T
Thunderbrook 已提交
432 433
  void ShowOneTable(int index) { HeterPs_->show_one_table(index); }

T
Thunderbrook 已提交
434 435 436 437 438 439 440 441
  int UseAfsApi() { return use_afs_api_; }

#ifdef PADDLE_WITH_PSLIB
  std::shared_ptr<paddle::ps::AfsReader> OpenReader(
      const std::string& filename) {
    return afs_handler_.open_reader(filename);
  }

442 443 444 445
  void InitAfsApi(const std::string& fs_name,
                  const std::string& fs_user,
                  const std::string& pass_wd,
                  const std::string& conf);
T
Thunderbrook 已提交
446 447
#endif

T
Thunderbrook 已提交
448 449
 private:
  static std::shared_ptr<PSGPUWrapper> s_instance_;
Y
yaoxuefeng 已提交
450
  Dataset* dataset_;
T
Thunderbrook 已提交
451 452 453
#ifdef PADDLE_WITH_PSLIB
  paddle::ps::AfsApiWrapper afs_handler_;
#endif
T
Thunderbrook 已提交
454
  std::unordered_map<
455 456
      uint64_t,
      std::vector<std::unordered_map<uint64_t, std::vector<float>>>>
T
Thunderbrook 已提交
457 458 459 460 461 462
      local_tables_;
  HeterPsBase* HeterPs_;
  std::vector<LoDTensor> keys_tensor;  // Cache for pull_sparse
  std::shared_ptr<HeterPsResource> resource_;
  int32_t sleep_seconds_before_fail_exit_;
  std::vector<int> slot_vector_;
Y
yaoxuefeng 已提交
463 464 465 466 467 468 469 470 471
  std::vector<int> slot_offset_vector_;
  std::vector<int> slot_mf_dim_vector_;
  std::unordered_map<int, int> slot_dim_map_;
  std::vector<int> slot_index_vec_;
  std::vector<int> index_dim_vec_;
  int multi_mf_dim_{0};
  int max_mf_dim_{0};
  size_t val_type_size_{0};
  size_t grad_type_size_{0};
Y
yaoxuefeng 已提交
472 473 474 475 476 477

  double time_1 = 0.0;
  double time_2 = 0.0;
  double time_3 = 0.0;
  double time_4 = 0.0;

T
Thunderbrook 已提交
478
  int multi_node_{0};
479
  int node_size_;
480
  uint64_t table_id_;
F
Fan Zhang 已提交
481
#ifdef PADDLE_WITH_CUDA
482 483 484
  std::vector<ncclComm_t> inner_comms_;
  std::vector<ncclComm_t> inter_comms_;
  std::vector<ncclUniqueId> inter_ncclids_;
F
Fan Zhang 已提交
485
#endif
Y
yaoxuefeng 已提交
486 487 488
  std::vector<int> heter_devices_;
  std::unordered_set<std::string> gpu_ps_config_keys_;
  HeterObjectPool<HeterContext> gpu_task_pool_;
489
  std::vector<std::vector<robin_hood::unordered_set<uint64_t>>> thread_keys_;
490 491
  std::vector<std::vector<std::vector<robin_hood::unordered_set<uint64_t>>>>
      thread_dim_keys_;
Y
yaoxuefeng 已提交
492 493 494
  int thread_keys_thread_num_ = 37;
  int thread_keys_shard_num_ = 37;
  uint64_t max_fea_num_per_pass_ = 5000000000;
495 496 497
  int year_;
  int month_;
  int day_;
Y
yaoxuefeng 已提交
498
  bool slot_info_initialized_ = false;
T
Thunderbrook 已提交
499
  int use_afs_api_ = 0;
T
Thunderbrook 已提交
500

F
Fan Zhang 已提交
501
#ifdef PADDLE_WITH_CUDA
Y
yaoxuefeng 已提交
502 503 504
  std::vector<MemoryPool*> mem_pools_;
  std::vector<HBMMemoryPool*> hbm_pools_;  // in multi mfdim, one table need hbm
                                           // pools of totol dims number
F
Fan Zhang 已提交
505
#endif
Y
yaoxuefeng 已提交
506

507 508 509 510 511 512 513 514 515 516 517 518 519
  std::shared_ptr<
      paddle::framework::ChannelObject<std::shared_ptr<HeterContext>>>
      data_ready_channel_ =
          paddle::framework::MakeChannel<std::shared_ptr<HeterContext>>();
  std::shared_ptr<
      paddle::framework::ChannelObject<std::shared_ptr<HeterContext>>>
      buildcpu_ready_channel_ =
          paddle::framework::MakeChannel<std::shared_ptr<HeterContext>>();
  std::shared_ptr<
      paddle::framework::ChannelObject<std::shared_ptr<HeterContext>>>
      gpu_free_channel_ =
          paddle::framework::MakeChannel<std::shared_ptr<HeterContext>>();
  std::shared_ptr<HeterContext> current_task_ = nullptr;
520
  std::thread pre_build_threads_;
521
  bool running_ = false;
Y
yaoxuefeng 已提交
522
  std::vector<std::shared_ptr<ThreadPool>> pull_thread_pool_;
T
Thunderbrook 已提交
523
  std::vector<std::shared_ptr<ThreadPool>> hbm_thread_pool_;
524

T
Thunderbrook 已提交
525 526 527 528 529 530 531
 protected:
  static bool is_initialized_;
};

}  // end namespace framework
}  // end namespace paddle
#endif