ps_gpu_wrapper.h 30.3 KB
Newer Older
T
Thunderbrook 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
T
Thunderbrook 已提交
16
#ifdef PADDLE_WITH_HETERPS
T
Thunderbrook 已提交
17 18 19 20 21

#include <atomic>
#include <ctime>
#include <map>
#include <memory>
22
#include <mutex>
T
Thunderbrook 已提交
23 24 25
#include <random>
#include <string>
#include <unordered_map>
Y
yaoxuefeng 已提交
26
#include <unordered_set>
27
#include <utility>
T
Thunderbrook 已提交
28
#include <vector>
29 30
#ifdef PADDLE_WITH_GLOO
#include <gloo/broadcast.h>
31

Y
yaoxuefeng 已提交
32
#include "paddle/fluid/framework/data_set.h"
33 34
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
35
#include "paddle/fluid/distributed/ps/thirdparty/round_robin.h"
F
Fan Zhang 已提交
36
#include "paddle/fluid/framework/channel.h"
T
Thunderbrook 已提交
37
#include "paddle/fluid/framework/fleet/heter_context.h"
L
lxsbupt 已提交
38
#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h"
T
Thunderbrook 已提交
39 40
#include "paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h"
F
Fan Zhang 已提交
41 42
#include "paddle/fluid/framework/heter_util.h"
#ifdef PADDLE_WITH_CUDA
Y
yaoxuefeng 已提交
43
#include "paddle/fluid/framework/fleet/heter_ps/mem_pool.h"
F
Fan Zhang 已提交
44 45 46 47 48 49
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#endif
#ifdef PADDLE_WITH_XPU_KP
#include "paddle/fluid/platform/device/xpu/enforce_xpu.h"
#endif
T
Thunderbrook 已提交
50 51 52 53 54
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h"  // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/platform/place.h"
T
Thunderbrook 已提交
55
#ifdef PADDLE_WITH_PSCORE
D
danleifeng 已提交
56 57
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h"
58
#include "paddle/fluid/distributed/ps/wrapper/fleet.h"
D
danleifeng 已提交
59
#include "paddle/fluid/distributed/the_one_ps.pb.h"
T
Thunderbrook 已提交
60
#endif
T
Thunderbrook 已提交
61
#ifdef PADDLE_WITH_PSLIB
Z
zmxdream 已提交
62
#include "afs_api.h"  // NOLINT
T
Thunderbrook 已提交
63
#endif
Y
yaoxuefeng 已提交
64 65 66
#ifdef PADDLE_WITH_PSLIB
#include "downpour_accessor.h"  // NOLINT
#endif
67
#include "paddle/fluid/framework/fleet/heter_ps/log_patch.h"
L
lxsbupt 已提交
68
DECLARE_int32(gpugraph_storage_mode);
T
Thunderbrook 已提交
69 70 71 72

namespace paddle {
namespace framework {

F
Fan Zhang 已提交
73 74
class Dataset;

T
Thunderbrook 已提交
75 76 77 78 79
#ifdef PADDLE_WITH_PSLIB
class AfsWrapper {
 public:
  AfsWrapper() {}
  virtual ~AfsWrapper() {}
80 81 82 83
  void init(const std::string& fs_name,
            const std::string& fs_user,
            const std::string& pass_wd,
            const std::string& conf);
T
Thunderbrook 已提交
84 85 86 87 88 89 90 91 92
  int remove(const std::string& path);
  int mkdir(const std::string& path);
  std::vector<std::string> list(const std::string& path);

  int exist(const std::string& path);
  int upload(const std::string& local_file, const std::string& afs_file);

  int download(const std::string& local_file, const std::string& afs_file);

93 94 95 96
  int touchz(const std::string& path);
  std::string cat(const std::string& path);
  int mv(const std::string& old_path, const std::string& dest_path);

T
Thunderbrook 已提交
97 98 99 100 101
 private:
  paddle::ps::AfsApiWrapper afs_handler_;
};
#endif

L
lxsbupt 已提交
102 103 104 105 106 107 108 109 110
/**
 * @brief Element type of the per-device `cpu_reday_channels_` channels.
 *
 * Every scalar field has a default so that a default-constructed task_info
 * never exposes indeterminate values. The struct remains an aggregate, so
 * positional brace initialization by existing callers keeps working.
 */
struct task_info {
  // Shared byte buffer; its layout is defined by the producer of the task.
  std::shared_ptr<char> build_values;
  size_t offset = 0;
  int device_id = 0;
  int multi_mf_dim = 0;
  // NOTE(review): start/end look like a range over build_values; whether the
  // end is inclusive is not visible here - confirm against the producer.
  int start = 0;
  int end = 0;
};

T
Thunderbrook 已提交
111
class PSGPUWrapper {
D
danleifeng 已提交
112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160
  class DCacheBuffer {
   public:
    DCacheBuffer() : buf_(nullptr) {}
    ~DCacheBuffer() {}
    /**
     * @Brief get data
     */
    template <typename T>
    T* mutable_data(const size_t total_bytes,
                    const paddle::platform::Place& place) {
      if (buf_ == nullptr) {
        buf_ = memory::AllocShared(place, total_bytes);
      } else if (buf_->size() < total_bytes) {
        buf_.reset();
        buf_ = memory::AllocShared(place, total_bytes);
      }
      return reinterpret_cast<T*>(buf_->ptr());
    }
    template <typename T>
    T* data() {
      return reinterpret_cast<T*>(buf_->ptr());
    }
    size_t memory_size() {
      if (buf_ == nullptr) {
        return 0;
      }
      return buf_->size();
    }
    bool IsInitialized(void) { return (buf_ != nullptr); }

   private:
    std::shared_ptr<memory::Allocation> buf_ = nullptr;
  };
  // Bundle of reusable DCacheBuffer scratch buffers plus two key counters.
  // NOTE(review): the field names suggest per-batch pull/push staging data;
  // the exact contents are defined by the code that fills them, which is not
  // visible in this header.
  struct PSDeviceData {
    DCacheBuffer keys_tensor;
    DCacheBuffer dims_tensor;
    DCacheBuffer keys_ptr_tensor;
    DCacheBuffer values_ptr_tensor;
    DCacheBuffer pull_push_tensor;

    DCacheBuffer slot_lens;
    DCacheBuffer d_slot_vector;
    DCacheBuffer keys2slot;

    // Both counters start at zero; they are presumably updated per batch by
    // the pull path - TODO confirm.
    int64_t total_key_length = 0;
    int64_t dedup_key_length = 0;
  };
  // Array with one PSDeviceData per device: allocated with new[] in
  // InitializeGPU() (sized by resource_->total_device()) and released with
  // delete[] in Finalize().
  PSDeviceData* device_caches_ = nullptr;

T
Thunderbrook 已提交
161
 public:
D
danleifeng 已提交
162
  ~PSGPUWrapper();
T
Thunderbrook 已提交
163 164 165 166 167 168

  PSGPUWrapper() {
    HeterPs_ = NULL;
    sleep_seconds_before_fail_exit_ = 300;
  }

169 170
  void PullSparse(const paddle::platform::Place& place,
                  const int table_id,
Y
yaoxuefeng 已提交
171 172 173
                  const std::vector<const uint64_t*>& keys,
                  const std::vector<float*>& values,
                  const std::vector<int64_t>& slot_lengths,
174 175 176 177
                  const std::vector<int>& slot_dim,
                  const int hidden_size);
  void PullSparse(const paddle::platform::Place& place,
                  const int table_id,
T
Thunderbrook 已提交
178 179 180 181
                  const std::vector<const uint64_t*>& keys,
                  const std::vector<float*>& values,
                  const std::vector<int64_t>& slot_lengths,
                  const int hidden_size);
182 183
  void PushSparseGrad(const paddle::platform::Place& place,
                      const int table_id,
T
Thunderbrook 已提交
184 185 186
                      const std::vector<const uint64_t*>& keys,
                      const std::vector<const float*>& grad_values,
                      const std::vector<int64_t>& slot_lengths,
187 188 189 190 191 192 193
                      const int hidden_size,
                      const int batch_size);
  void CopyKeys(const paddle::platform::Place& place,
                uint64_t** origin_keys,
                uint64_t* total_keys,
                const int64_t* gpu_len,
                int slot_num,
T
Thunderbrook 已提交
194
                int total_len);
D
danleifeng 已提交
195 196 197 198 199 200 201
  void CopyKeys(const paddle::platform::Place& place,
                uint64_t** origin_keys,
                uint64_t* total_keys,
                const int64_t* gpu_len,
                int slot_num,
                int total_len,
                int* key2slot);
T
Thunderbrook 已提交
202

L
lxsbupt 已提交
203 204
  void divide_to_device(std::shared_ptr<HeterContext> gpu_task);
  void add_slot_feature(std::shared_ptr<HeterContext> gpu_task);
205
  void BuildGPUTask(std::shared_ptr<HeterContext> gpu_task);
206 207
  void PreBuildTask(std::shared_ptr<HeterContext> gpu_task,
                    Dataset* dataset_for_pull);
208
  void BuildPull(std::shared_ptr<HeterContext> gpu_task);
L
lxsbupt 已提交
209
  void PrepareGPUTask(std::shared_ptr<HeterContext> gpu_task);
210 211 212
  void LoadIntoMemory(bool is_shuffle);
  void BeginPass();
  void EndPass();
L
lxsbupt 已提交
213 214 215 216 217
  void add_key_to_local(const std::vector<uint64_t>& keys);
  void add_key_to_gputask(std::shared_ptr<HeterContext> gpu_task);
  void resize_gputask(std::shared_ptr<HeterContext> gpu_task);
  void SparseTableToHbm();
  void HbmToSparseTable();
218
  void start_build_thread();
219
  void pre_build_thread();
L
lxsbupt 已提交
220
  void build_pull_thread();
221
  void build_task();
L
lxsbupt 已提交
222
  void DumpToMem();
223 224 225 226
  void MergePull(std::shared_ptr<HeterContext> gpu_task);
  void FilterPull(std::shared_ptr<HeterContext> gpu_task,
                  const int shard_id,
                  const int dim_id);
L
lxsbupt 已提交
227 228 229 230 231 232 233 234
  // set mode
  void SetMode(bool infer_mode) {
    infer_mode_ = infer_mode;
    if (HeterPs_ != NULL) {
      HeterPs_->set_mode(infer_mode);
    }
    VLOG(0) << "set infer mode=" << infer_mode;
  }
235 236 237 238 239 240

  void Finalize() {
    VLOG(3) << "PSGPUWrapper Begin Finalize.";
    if (s_instance_ == nullptr) {
      return;
    }
W
wangzhen38 已提交
241
#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
L
lxsbupt 已提交
242 243 244
    if (FLAGS_gpugraph_storage_mode == GpuGraphStorageMode::WHOLE_HBM) {
      this->EndPass();
    }
W
wangzhen38 已提交
245
#endif
L
lxsbupt 已提交
246 247 248
    for (size_t i = 0; i < hbm_pools_.size(); i++) {
      delete hbm_pools_[i];
    }
249 250
    data_ready_channel_->Close();
    buildcpu_ready_channel_->Close();
L
lxsbupt 已提交
251
    buildpull_ready_channel_->Close();
252
    running_ = false;
253 254
    VLOG(3) << "begin stop pre_build_threads_";
    pre_build_threads_.join();
L
lxsbupt 已提交
255 256
    VLOG(3) << "begin stop buildpull_threads_";
    buildpull_threads_.join();
257 258
    s_instance_ = nullptr;
    VLOG(3) << "PSGPUWrapper Finalize Finished.";
L
lxsbupt 已提交
259
    if (HeterPs_ != NULL) {
260
      HeterPs_->show_table_collisions();
L
lxsbupt 已提交
261 262 263
      delete HeterPs_;
      HeterPs_ = NULL;
    }
D
danleifeng 已提交
264 265 266 267
    if (device_caches_ != nullptr) {
      delete[] device_caches_;
      device_caches_ = nullptr;
    }
268 269
  }

T
Thunderbrook 已提交
270
  void InitializeGPU(const std::vector<int>& dev_ids) {
271
    if (s_instance_ != NULL && is_initialized_ == false) {
T
Thunderbrook 已提交
272
      VLOG(3) << "PSGPUWrapper Begin InitializeGPU";
273
      is_initialized_ = true;
T
Thunderbrook 已提交
274 275
      resource_ = std::make_shared<HeterPsResource>(dev_ids);
      resource_->enable_p2p();
276
      keys_tensor.resize(resource_->total_device());
D
danleifeng 已提交
277
      device_caches_ = new PSDeviceData[resource_->total_device()];
Y
yaoxuefeng 已提交
278 279 280 281
#ifdef PADDLE_WITH_GLOO
      auto gloo = paddle::framework::GlooWrapper::GetInstance();
      if (gloo->Size() > 1) {
        multi_node_ = 1;
L
lxsbupt 已提交
282 283 284 285 286 287
        resource_->set_multi_node(multi_node_);
        optimizer_config_.multi_node = true;
        VLOG(0) << "init multi node gpu server";
      } else {
        optimizer_config_.multi_node = false;
        VLOG(0) << "init single node gpu server";
Y
yaoxuefeng 已提交
288 289 290 291 292
      }
#else
      PADDLE_THROW(
          platform::errors::Unavailable("heter ps need compile with GLOO"));
#endif
F
Fan Zhang 已提交
293
#ifdef PADDLE_WITH_CUDA
294 295 296 297 298
      if (multi_node_) {
        int dev_size = dev_ids.size();
        // init inner comm
        inner_comms_.resize(dev_size);
        inter_ncclids_.resize(dev_size);
299 300
        platform::dynload::ncclCommInitAll(
            &(inner_comms_[0]), dev_size, &dev_ids[0]);
301 302 303 304 305 306 307 308 309 310
// init inter comm
#ifdef PADDLE_WITH_GLOO
        inter_comms_.resize(dev_size);
        if (gloo->Rank() == 0) {
          for (int i = 0; i < dev_size; ++i) {
            platform::dynload::ncclGetUniqueId(&inter_ncclids_[i]);
          }
        }

        PADDLE_ENFORCE_EQ(
311 312
            gloo->IsInitialized(),
            true,
313 314 315 316 317 318 319
            platform::errors::PreconditionNotMet(
                "You must initialize the gloo environment first to use it."));
        gloo::BroadcastOptions opts(gloo->GetContext());
        opts.setOutput(&inter_ncclids_[0], dev_size);
        opts.setRoot(0);
        gloo::broadcast(opts);

L
lxsbupt 已提交
320
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
321
        for (int i = 0; i < dev_size; ++i) {
L
lxsbupt 已提交
322
          platform::CUDADeviceGuard guard(dev_ids[i]);
323 324
          platform::dynload::ncclCommInitRank(
              &inter_comms_[i], gloo->Size(), inter_ncclids_[i], gloo->Rank());
325
        }
L
lxsbupt 已提交
326 327 328
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());

        rank_id_ = gloo->Rank();
329 330 331 332 333 334
        node_size_ = gloo->Size();
#else
        PADDLE_THROW(
            platform::errors::Unavailable("heter ps need compile with GLOO"));
#endif
      }
F
Fan Zhang 已提交
335
#endif
Y
yaoxuefeng 已提交
336
      heter_devices_ = dev_ids;
337 338 339 340
      data_ready_channel_->Open();
      data_ready_channel_->SetCapacity(3);
      buildcpu_ready_channel_->Open();
      buildcpu_ready_channel_->SetCapacity(3);
L
lxsbupt 已提交
341 342
      buildpull_ready_channel_->Open();
      buildpull_ready_channel_->SetCapacity(1);
343

L
lxsbupt 已提交
344 345 346 347 348
      cpu_reday_channels_.resize(dev_ids.size());
      for (size_t i = 0; i < dev_ids.size(); i++) {
        cpu_reday_channels_[i] = paddle::framework::MakeChannel<task_info>();
        cpu_reday_channels_[i]->SetCapacity(16);
      }
349 350 351
      current_task_ = nullptr;

      table_id_ = 0;
352
      device_num_ = static_cast<int>(heter_devices_.size());
353

354 355
      // start build cpu&gpu ps thread
      start_build_thread();
T
Thunderbrook 已提交
356 357
    }
  }
Y
yaoxuefeng 已提交
358

359 360 361 362 363 364
  void SetSparseSGD(float nonclk_coeff,
                    float clk_coeff,
                    float min_bound,
                    float max_bound,
                    float learning_rate,
                    float initial_g2sum,
D
danleifeng 已提交
365 366 367 368
                    float initial_range,
                    float beta1_decay_rate,
                    float beta2_decay_rate,
                    float ada_epsilon);
369 370 371 372 373
  void SetEmbedxSGD(float mf_create_thresholds,
                    float mf_learning_rate,
                    float mf_initial_g2sum,
                    float mf_initial_range,
                    float mf_min_bound,
D
danleifeng 已提交
374 375 376
                    float mf_max_bound,
                    float mf_beta1_decay_rate,
                    float mf_beta2_decay_rate,
D
danleifeng 已提交
377 378 379
                    float mf_ada_epsilon,
                    float nodeid_slot,
                    float feature_learning_rate);
D
danleifeng 已提交
380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430

#ifdef PADDLE_WITH_PSCORE
  /**
   * @brief Flattens one sparse SGD rule message into float config entries.
   *
   * The rule is selected by sgd_param.name(). Each recognised rule writes an
   * "optimizer_type" id (0..4) followed by its hyper-parameters, every key
   * being prefixed with `prefix`. An unrecognised name leaves `config`
   * untouched.
   *
   * @param config    map that receives the entries (modified in place).
   * @param sgd_param rule description read from the table accessor.
   * @param prefix    prepended to every key written.
   */
  void add_sparse_optimizer(
      std::unordered_map<std::string, float>& config,  // NOLINT
      const ::paddle::distributed::SparseCommonSGDRuleParameter& sgd_param,
      const std::string& prefix = "") {
    const std::string rule_name = sgd_param.name();
    // Writes one prefixed entry; the value is narrowed to float on the way in.
    const auto put = [&config, &prefix](const char* key, float value) {
      config[prefix + key] = value;
    };

    if (rule_name == "SparseNaiveSGDRule") {
      const auto& naive = sgd_param.naive();
      put("optimizer_type", 0);
      put("learning_rate", naive.learning_rate());
      put("initial_range", naive.initial_range());
      put("min_bound", naive.weight_bounds()[0]);
      put("max_bound", naive.weight_bounds()[1]);
      return;
    }

    const bool is_sparse_adagrad = (rule_name == "SparseAdaGradSGDRule");
    if (is_sparse_adagrad || rule_name == "StdAdaGradSGDRule") {
      // Both adagrad flavours read the same sub-message; only the id differs.
      const auto& adagrad = sgd_param.adagrad();
      put("optimizer_type", is_sparse_adagrad ? 1 : 2);
      put("learning_rate", adagrad.learning_rate());
      put("initial_range", adagrad.initial_range());
      put("initial_g2sum", adagrad.initial_g2sum());
      put("min_bound", adagrad.weight_bounds()[0]);
      put("max_bound", adagrad.weight_bounds()[1]);
      return;
    }

    const bool is_sparse_adam = (rule_name == "SparseAdamSGDRule");
    if (is_sparse_adam || rule_name == "SparseSharedAdamSGDRule") {
      // Both adam flavours read the same sub-message; only the id differs.
      const auto& adam = sgd_param.adam();
      put("optimizer_type", is_sparse_adam ? 3 : 4);
      put("learning_rate", adam.learning_rate());
      put("initial_range", adam.initial_range());
      put("beta1_decay_rate", adam.beta1_decay_rate());
      put("beta2_decay_rate", adam.beta2_decay_rate());
      put("ada_epsilon", adam.ada_epsilon());
      put("min_bound", adam.weight_bounds()[0]);
      put("max_bound", adam.weight_bounds()[1]);
    }
  }

  /**
   * @brief Configures the GPU server from a PSParameter message.
   *
   * Reads the FIRST downpour table of the server parameters, sizes the build
   * thread pools from its shard count, collects the accessor settings into a
   * float config map, configures the global accessor wrapper and then
   * delegates to InitializeGPUServer(config).
   *
   * The sub-messages are bound by const reference (they live inside the
   * by-value ps_param for the whole call), so no protobuf message is copied.
   *
   * @param ps_param parameter message; must contain at least one downpour
   *                 table, otherwise a PreconditionNotMet error is raised.
   */
  void InitializeGPUServer(paddle::distributed::PSParameter ps_param) {
    const auto& downpour_param =
        ps_param.server_param().downpour_server_param();
    // Indexing an empty repeated field is undefined, so fail loudly instead.
    PADDLE_ENFORCE_EQ(
        downpour_param.downpour_table_param_size() > 0,
        true,
        platform::errors::PreconditionNotMet(
            "PSParameter must describe at least one downpour table."));
    const auto& sparse_table = downpour_param.downpour_table_param(0);

    // set build thread_num and shard_num
    thread_keys_thread_num_ = sparse_table.shard_num();
    thread_keys_shard_num_ = sparse_table.shard_num();
    VLOG(1) << "ps_gpu build phase thread_num:" << thread_keys_thread_num_
            << " shard_num:" << thread_keys_shard_num_;

    pull_thread_pool_.resize(thread_keys_shard_num_);
    for (size_t i = 0; i < pull_thread_pool_.size(); i++) {
      pull_thread_pool_[i].reset(new ::ThreadPool(1));
    }
    hbm_thread_pool_.resize(device_num_);
    for (size_t i = 0; i < hbm_thread_pool_.size(); i++) {
      hbm_thread_pool_[i].reset(new ::ThreadPool(1));
    }
    cpu_work_pool_.resize(device_num_);
    for (size_t i = 0; i < cpu_work_pool_.size(); i++) {
      cpu_work_pool_[i].reset(new ::ThreadPool(cpu_device_thread_num_));
    }

    const auto& sparse_table_accessor = sparse_table.accessor();
    const auto& sparse_table_accessor_parameter =
        sparse_table_accessor.ctr_accessor_param();
    accessor_class_ = sparse_table_accessor.accessor_class();

    std::unordered_map<std::string, float> config;
    config["embedx_dim"] = sparse_table_accessor.embedx_dim();
    config["nonclk_coeff"] = sparse_table_accessor_parameter.nonclk_coeff();
    config["clk_coeff"] = sparse_table_accessor_parameter.click_coeff();
    config["mf_create_thresholds"] = sparse_table_accessor.embedx_threshold();

    config["nodeid_slot"] =
        sparse_table_accessor.graph_sgd_param().nodeid_slot();
    config["feature_learning_rate"] =
        sparse_table_accessor.graph_sgd_param().feature_learning_rate();

    if (accessor_class_ == "CtrDymfAccessor") {
      // optimizer config for embed_w and embedx
      add_sparse_optimizer(config, sparse_table_accessor.embed_sgd_param());
      add_sparse_optimizer(
          config, sparse_table_accessor.embedx_sgd_param(), "mf_");
    }

    fleet_config_ = config;
    GlobalAccessorFactory::GetInstance().Init(accessor_class_);
    GlobalAccessorFactory::GetInstance().GetAccessorWrapper()->Configure(
        config);
    InitializeGPUServer(config);
  }
#endif

Y
yaoxuefeng 已提交
481 482 483 484 485 486 487
  void InitializeGPUServer(std::unordered_map<std::string, float> config) {
    float nonclk_coeff = (config.find("nonclk_coeff") == config.end())
                             ? 1.0
                             : config["nonclk_coeff"];
    float clk_coeff =
        (config.find("clk_coeff") == config.end()) ? 1.0 : config["clk_coeff"];
    float min_bound = (config.find("min_bound") == config.end())
D
danleifeng 已提交
488
                          ? -10.0
Y
yaoxuefeng 已提交
489
                          : config["min_bound"];
D
danleifeng 已提交
490 491
    float max_bound =
        (config.find("max_bound") == config.end()) ? 10.0 : config["max_bound"];
Y
yaoxuefeng 已提交
492
    float learning_rate = (config.find("learning_rate") == config.end())
D
danleifeng 已提交
493
                              ? 0.05
Y
yaoxuefeng 已提交
494 495
                              : config["learning_rate"];
    float initial_g2sum = (config.find("initial_g2sum") == config.end())
D
danleifeng 已提交
496
                              ? 3.0
Y
yaoxuefeng 已提交
497 498
                              : config["initial_g2sum"];
    float initial_range = (config.find("initial_range") == config.end())
D
danleifeng 已提交
499
                              ? 1e-4
Y
yaoxuefeng 已提交
500
                              : config["initial_range"];
D
danleifeng 已提交
501 502 503 504 505 506 507 508 509
    float beta1_decay_rate = (config.find("beta1_decay_rate") == config.end())
                                 ? 0.9
                                 : config["beta1_decay_rate"];
    float beta2_decay_rate = (config.find("beta2_decay_rate") == config.end())
                                 ? 0.999
                                 : config["beta2_decay_rate"];
    float ada_epsilon = (config.find("ada_epsilon") == config.end())
                            ? 1e-8
                            : config["ada_epsilon"];
Y
yaoxuefeng 已提交
510 511 512 513 514 515
    // mf config settings
    float mf_create_thresholds =
        (config.find("mf_create_thresholds") == config.end())
            ? static_cast<float>(1.0)
            : config["mf_create_thresholds"];
    float mf_learning_rate = (config.find("mf_learning_rate") == config.end())
D
danleifeng 已提交
516
                                 ? 0.05
Y
yaoxuefeng 已提交
517 518
                                 : config["mf_learning_rate"];
    float mf_initial_g2sum = (config.find("mf_initial_g2sum") == config.end())
D
danleifeng 已提交
519
                                 ? 3.0
Y
yaoxuefeng 已提交
520 521
                                 : config["mf_initial_g2sum"];
    float mf_initial_range = (config.find("mf_initial_range") == config.end())
D
danleifeng 已提交
522
                                 ? 1e-4
Y
yaoxuefeng 已提交
523 524
                                 : config["mf_initial_range"];
    float mf_min_bound = (config.find("mf_min_bound") == config.end())
D
danleifeng 已提交
525
                             ? -10.0
Y
yaoxuefeng 已提交
526 527
                             : config["mf_min_bound"];
    float mf_max_bound = (config.find("mf_max_bound") == config.end())
D
danleifeng 已提交
528
                             ? 10.0
Y
yaoxuefeng 已提交
529
                             : config["mf_max_bound"];
D
danleifeng 已提交
530 531 532 533 534 535 536 537 538 539 540
    float mf_beta1_decay_rate =
        (config.find("mf_beta1_decay_rate") == config.end())
            ? 0.9
            : config["mf_beta1_decay_rate"];
    float mf_beta2_decay_rate =
        (config.find("mf_beta2_decay_rate") == config.end())
            ? 0.999
            : config["mf_beta2_decay_rate"];
    float mf_ada_epsilon = (config.find("mf_ada_epsilon") == config.end())
                               ? 1e-8
                               : config["mf_ada_epsilon"];
D
danleifeng 已提交
541 542 543 544 545 546 547 548 549 550

    float feature_learning_rate =
        (config.find("feature_learning_rate") == config.end())
            ? 0.05
            : config["feature_learning_rate"];

    float nodeid_slot = (config.find("nodeid_slot") == config.end())
                            ? 9008
                            : config["nodeid_slot"];

551 552 553 554 555 556
    this->SetSparseSGD(nonclk_coeff,
                       clk_coeff,
                       min_bound,
                       max_bound,
                       learning_rate,
                       initial_g2sum,
D
danleifeng 已提交
557 558 559 560
                       initial_range,
                       beta1_decay_rate,
                       beta2_decay_rate,
                       ada_epsilon);
561 562 563 564 565
    this->SetEmbedxSGD(mf_create_thresholds,
                       mf_learning_rate,
                       mf_initial_g2sum,
                       mf_initial_range,
                       mf_min_bound,
D
danleifeng 已提交
566 567 568
                       mf_max_bound,
                       mf_beta1_decay_rate,
                       mf_beta2_decay_rate,
D
danleifeng 已提交
569 570 571
                       mf_ada_epsilon,
                       nodeid_slot,
                       feature_learning_rate);
D
danleifeng 已提交
572 573 574 575

    // set optimizer type(naive,adagrad,std_adagrad,adam,share_adam)
    optimizer_type_ = (config.find("optimizer_type") == config.end())
                          ? 1
Z
zmxdream 已提交
576
                          : static_cast<int>(config["optimizer_type"]);
D
danleifeng 已提交
577 578 579 580

    VLOG(0) << "InitializeGPUServer optimizer_type_:" << optimizer_type_
            << " nodeid_slot:" << nodeid_slot
            << " feature_learning_rate:" << feature_learning_rate;
Y
yaoxuefeng 已提交
581
  }
F
Fan Zhang 已提交
582

583 584 585 586 587 588
  /**
   * @brief Stores a calendar date in year_, month_ and day_.
   *
   * The values are stored as given; no range validation is performed.
   */
  void SetDate(int year, int month, int day) {
    day_ = day;
    month_ = month;
    year_ = year;
  }

Y
yaoxuefeng 已提交
589 590
  /**
   * @brief Remembers the dataset pointer; it is stored as a raw,
   * non-owning pointer and later read by InitSlotInfo().
   */
  void SetDataset(Dataset* dataset) {
    dataset_ = dataset;
  }

T
Thunderbrook 已提交
591 592
  // PSGPUWrapper singleton
  static std::shared_ptr<PSGPUWrapper> GetInstance() {
593 594 595 596 597
    {
      std::lock_guard<std::mutex> lk(ins_mutex);
      if (NULL == s_instance_) {
        s_instance_.reset(new paddle::framework::PSGPUWrapper());
      }
T
Thunderbrook 已提交
598 599 600 601 602 603 604 605 606
    }
    return s_instance_;
  }
  /**
   * @brief Gives mutable access to the entry of local_tables_ for table_id.
   *
   * Uses operator[], so asking for an unknown table_id inserts an empty
   * vector for it and returns that.
   */
  std::vector<std::unordered_map<uint64_t, std::vector<float>>>& GetLocalTable(
      int table_id) {
    auto& table_shards = local_tables_[table_id];
    return table_shards;
  }
  void SetSlotVector(const std::vector<int>& slot_vector) {
    slot_vector_ = slot_vector;
L
lxsbupt 已提交
607
    VLOG(0) << "slot_vector size is " << slot_vector_.size();
T
Thunderbrook 已提交
608
  }
609 610 611 612
  /**
   * @brief Stores slot_num in slot_num_for_pull_feature_ and logs it.
   */
  void SetPullFeatureSlotNum(int slot_num) {
    slot_num_for_pull_feature_ = slot_num;
    VLOG(0) << "slot_num_for_pull_feature_ is "
            << slot_num_for_pull_feature_;
  }
Y
yaoxuefeng 已提交
613 614
  /**
   * @brief Copies the slot offsets into slot_offset_vector_ and prints them
   * to stdout on one line.
   *
   * NOTE(review): the "yxf set: " output looks like a leftover debug print;
   * it is kept unchanged here because it is observable output.
   */
  void SetSlotOffsetVector(const std::vector<int>& slot_offset_vector) {
    slot_offset_vector_ = slot_offset_vector;
    std::cout << "yxf set: ";
    for (size_t idx = 0; idx < slot_offset_vector_.size(); ++idx) {
      std::cout << slot_offset_vector_[idx] << " | ";
    }
    std::cout << " end " << std::endl;
  }

F
Fan Zhang 已提交
622
#ifdef PADDLE_WITH_CUDA
Y
yaoxuefeng 已提交
623 624 625
  /**
   * @brief Copies the per-slot mf dimensions into slot_mf_dim_vector_.
   *
   * In builds with assertions enabled the new vector must have the same
   * length as slot_vector_; with NDEBUG defined nothing is checked.
   */
  void SetSlotDimVector(const std::vector<int>& slot_mf_dim_vector) {
    slot_mf_dim_vector_ = slot_mf_dim_vector;
    // One dimension per slot is expected.
    assert(slot_mf_dim_vector_.size() == slot_vector_.size());
  }

  void InitSlotInfo() {
    if (slot_info_initialized_) {
      return;
    }
Z
zmxdream 已提交
632
    SlotRecordDataset* dataset = reinterpret_cast<SlotRecordDataset*>(dataset_);
Y
yaoxuefeng 已提交
633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648
    auto slots_vec = dataset->GetSlots();
    slot_offset_vector_.clear();
    for (auto& slot : slot_vector_) {
      for (size_t i = 0; i < slots_vec.size(); ++i) {
        if (std::to_string(slot) == slots_vec[i]) {
          slot_offset_vector_.push_back(i);
          break;
        }
      }
    }
    std::cout << "psgpu wrapper use slots: ";
    for (auto s : slot_offset_vector_) {
      std::cout << s << " | ";
    }
    std::cout << " end " << std::endl;
    for (size_t i = 0; i < slot_mf_dim_vector_.size(); i++) {
Y
yaoxuefeng 已提交
649 650 651 652 653 654 655 656 657 658 659 660 661 662 663
      slot_dim_map_[slot_vector_[i]] = slot_mf_dim_vector_[i];
    }

    std::unordered_set<int> dims_set;
    for (auto& it : slot_dim_map_) {
      dims_set.insert(it.second);
    }
    size_t num_of_dim = dims_set.size();
    index_dim_vec_.resize(num_of_dim);
    index_dim_vec_.assign(dims_set.begin(), dims_set.end());
    std::sort(index_dim_vec_.begin(), index_dim_vec_.end());
    std::unordered_map<int, int> dim_index_map;
    for (size_t i = 0; i < num_of_dim; i++) {
      dim_index_map[index_dim_vec_[i]] = i;
    }
664
    hbm_pools_.resize(resource_->total_device() * num_of_dim);
L
lxsbupt 已提交
665 666 667 668
    for (size_t i = 0; i < hbm_pools_.size(); i++) {
      hbm_pools_[i] = new HBMMemoryPoolFix();
    }

669
    mem_pools_.resize(resource_->total_device() * num_of_dim);
Y
yaoxuefeng 已提交
670 671 672 673 674 675 676
    max_mf_dim_ = index_dim_vec_.back();
    multi_mf_dim_ = (dim_index_map.size() >= 1) ? dim_index_map.size() : 0;
    resource_->set_multi_mf(multi_mf_dim_, max_mf_dim_);
    slot_index_vec_.resize(slot_mf_dim_vector_.size());
    for (size_t i = 0; i < slot_index_vec_.size(); i++) {
      slot_index_vec_[i] = dim_index_map[slot_mf_dim_vector_[i]];
    }
D
danleifeng 已提交
677 678

    auto accessor_wrapper_ptr =
D
danleifeng 已提交
679
        GlobalAccessorFactory::GetInstance().GetAccessorWrapper();
D
danleifeng 已提交
680 681
    val_type_size_ = accessor_wrapper_ptr->GetFeatureValueSize(max_mf_dim_);
    grad_type_size_ = accessor_wrapper_ptr->GetPushValueSize(max_mf_dim_);
D
danleifeng 已提交
682
    pull_type_size_ = accessor_wrapper_ptr->GetPullValueSize(max_mf_dim_);
D
danleifeng 已提交
683
    VLOG(0) << "InitSlotInfo: val_type_size_" << val_type_size_
D
danleifeng 已提交
684 685
            << " grad_type_size_:" << grad_type_size_
            << " pull_type_size_:" << pull_type_size_;
Y
yaoxuefeng 已提交
686
    slot_info_initialized_ = true;
Y
yaoxuefeng 已提交
687
  }
F
Fan Zhang 已提交
688
#endif
Y
yaoxuefeng 已提交
689

T
Thunderbrook 已提交
690 691
  void ShowOneTable(int index) { HeterPs_->show_one_table(index); }

T
Thunderbrook 已提交
692 693 694 695 696 697 698 699
  /**
   * @brief Returns the current value of use_afs_api_ as an int.
   */
  int UseAfsApi() {
    const int afs_flag = use_afs_api_;
    return afs_flag;
  }

#ifdef PADDLE_WITH_PSLIB
  /**
   * @brief Opens `filename` for reading through afs_handler_.
   * @return whatever afs_handler_.open_reader() yields for that file.
   */
  std::shared_ptr<paddle::ps::AfsReader> OpenReader(
      const std::string& filename) {
    std::shared_ptr<paddle::ps::AfsReader> reader =
        afs_handler_.open_reader(filename);
    return reader;
  }

Z
zmxdream 已提交
700 701 702 703 704
  /**
   * @brief Opens `filename` for writing through afs_handler_.
   * @return whatever afs_handler_.open_writer() yields for that file.
   */
  std::shared_ptr<paddle::ps::AfsWriter> OpenWriter(
      const std::string& filename) {
    std::shared_ptr<paddle::ps::AfsWriter> writer =
        afs_handler_.open_writer(filename);
    return writer;
  }

705 706 707 708
  void InitAfsApi(const std::string& fs_name,
                  const std::string& fs_user,
                  const std::string& pass_wd,
                  const std::string& conf);
T
Thunderbrook 已提交
709 710
#endif

D
danleifeng 已提交
711 712 713 714 715
#ifdef PADDLE_WITH_PSCORE
  // Records `accessor` in cpu_table_accessor_. Only the raw pointer is
  // stored; nothing is copied or checked here.
  void SetTableAccessor(paddle::distributed::ValueAccessor* accessor) {
    this->cpu_table_accessor_ = accessor;
  }
#endif
716 717 718 719
  // Picks the node rank for a key: the key is divided by device_num_ and the
  // quotient is reduced modulo node_size_. The arithmetic is done in the
  // unsigned 64-bit domain and the result is converted to int on return.
  int PartitionKeyForRank(const uint64_t& key) {
    const auto quotient = key / device_num_;
    return static_cast<int>(quotient % node_size_);
  }
D
danleifeng 已提交
720

T
Thunderbrook 已提交
721 722
 private:
  // Shared, class-wide PSGPUWrapper instance pointer.
  // NOTE(review): presumably created lazily under ins_mutex by an accessor
  // defined earlier in the class — confirm.
  static std::shared_ptr<PSGPUWrapper> s_instance_;
723
  static std::mutex ins_mutex;
Y
yaoxuefeng 已提交
724
  // Non-owning dataset pointer. Defaults to null so that a read before the
  // first assignment sees a well-defined value instead of an indeterminate
  // one.
  Dataset* dataset_ = nullptr;
T
Thunderbrook 已提交
725 726 727
#ifdef PADDLE_WITH_PSLIB
  // AFS wrapper used by OpenReader()/OpenWriter(), which call its
  // open_reader()/open_writer(). Only present in PSLIB builds.
  paddle::ps::AfsApiWrapper afs_handler_;
#endif
T
Thunderbrook 已提交
728
  // Nested table: uint64_t -> vector of (uint64_t -> vector<float>) maps.
  // NOTE(review): what the outer and inner keys identify is not visible in
  // this part of the file — confirm before relying on it.
  std::unordered_map<
729 730
      uint64_t,
      std::vector<std::unordered_map<uint64_t, std::vector<float>>>>
T
Thunderbrook 已提交
731
      local_tables_;
L
lxsbupt 已提交
732 733
  // Pointer to the HeterPs backend; ShowOneTable() calls through it.
  // Starts out null (spelled nullptr, like the other pointer defaults in
  // this class).
  HeterPsBase* HeterPs_ = nullptr;
  // std::vector<LoDTensor> keys_tensor;  // Cache for pull_sparse
734
  std::vector<phi::DenseTensor> keys_tensor;  // Cache for pull_sparse
T
Thunderbrook 已提交
735 736 737
  // Device resource handle; slot-info initialization queries
  // total_device() on it and calls set_multi_mf().
  std::shared_ptr<HeterPsResource> resource_;
  // No in-class default. NOTE(review): presumably assigned in the
  // constructor — confirm.
  int32_t sleep_seconds_before_fail_exit_;
  // NOTE(review): presumably the list of slot ids — confirm against setter.
  std::vector<int> slot_vector_;
Y
yaoxuefeng 已提交
738 739 740 741 742 743 744
  std::vector<int> slot_offset_vector_;
  // One entry per slot; each value is used as the lookup key into the
  // dim -> index map when slot_index_vec_ is filled in.
  std::vector<int> slot_mf_dim_vector_;
  std::unordered_map<int, int> slot_dim_map_;
  // Resized to slot_mf_dim_vector_.size(); entry i holds the index mapped
  // from slot_mf_dim_vector_[i].
  std::vector<int> slot_index_vec_;
  // Its last element is taken as max_mf_dim_ during slot-info
  // initialization.
  std::vector<int> index_dim_vec_;
  // Number of entries in the dim -> index map (0 when the map is empty).
  int multi_mf_dim_{0};
  // Set from index_dim_vec_.back(); also the dim passed to the accessor
  // wrapper when the value/push/pull sizes are computed.
  int max_mf_dim_{0};
745
  int slot_num_for_pull_feature_{0};
Y
yaoxuefeng 已提交
746 747
  // Result of GetFeatureValueSize(max_mf_dim_) on the accessor wrapper.
  size_t val_type_size_{0};
  // Result of GetPushValueSize(max_mf_dim_) on the accessor wrapper.
  size_t grad_type_size_{0};
D
danleifeng 已提交
748
  // Result of GetPullValueSize(max_mf_dim_) on the accessor wrapper.
  size_t pull_type_size_{0};
Y
yaoxuefeng 已提交
749 750 751 752 753 754

  // Four double-valued accumulators, all starting at 0.0.
  // NOTE(review): presumably per-stage timing counters; where they are
  // updated is not visible in this part of the file — confirm.
  double time_1 = 0.0;
  double time_2 = 0.0;
  double time_3 = 0.0;
  double time_4 = 0.0;

T
Thunderbrook 已提交
755
  int multi_node_{0};
L
lxsbupt 已提交
756
  // Rank of this node. Defaults to 0 so the member is never read in an
  // indeterminate state before it is assigned.
  int rank_id_{0};
757
  // Number of nodes; used as the modulus in PartitionKeyForRank(). Defaults
  // to 1 so that a call made before the value is configured does not take a
  // modulus by an indeterminate (possibly zero) value.
  int node_size_ = 1;
758
  // Divisor applied to the key in PartitionKeyForRank(); defaults to 8.
  int device_num_ = 8;
759
  // Table identifier. Defaults to 0 so the member is never read in an
  // indeterminate state before it is assigned.
  uint64_t table_id_{0};
D
danleifeng 已提交
760
  int gpu_graph_mode_ = 0;
F
Fan Zhang 已提交
761
#ifdef PADDLE_WITH_CUDA
762 763 764
  // NCCL communicator handles and unique ids (CUDA builds only).
  // NOTE(review): presumably inner_* is for devices within one node and
  // inter_* for cross-node communication — confirm against the setup code.
  std::vector<ncclComm_t> inner_comms_;
  std::vector<ncclComm_t> inter_comms_;
  std::vector<ncclUniqueId> inter_ncclids_;
F
Fan Zhang 已提交
765
#endif
Y
yaoxuefeng 已提交
766 767 768
  // NOTE(review): presumably the device ids this wrapper works with —
  // confirm.
  std::vector<int> heter_devices_;
  // Set of configuration key names. NOTE(review): how it is populated and
  // checked is not visible in this part of the file.
  std::unordered_set<std::string> gpu_ps_config_keys_;
  // Object pool of HeterContext instances.
  HeterObjectPool<HeterContext> gpu_task_pool_;
769
  // Two-level vector of uint64_t key sets.
  // NOTE(review): presumably indexed [thread][shard], matching
  // thread_keys_thread_num_ / thread_keys_shard_num_ — confirm.
  std::vector<std::vector<robin_hood::unordered_set<uint64_t>>> thread_keys_;
770 771
  // Three-level vector of uint64_t key sets.
  // NOTE(review): presumably thread_keys_ with an extra mf-dim level —
  // confirm the index order against the code that fills it.
  std::vector<std::vector<std::vector<robin_hood::unordered_set<uint64_t>>>>
      thread_dim_keys_;
Y
yaoxuefeng 已提交
772 773 774
  // Both default to 37. NOTE(review): presumably the thread count and shard
  // count used when collecting keys into thread_keys_ — confirm.
  int thread_keys_thread_num_ = 37;
  int thread_keys_shard_num_ = 37;
  // Default 5,000,000,000; does not fit in 32 bits, hence uint64_t.
  uint64_t max_fea_num_per_pass_ = 5000000000;
775 776 777
  // Calendar date components. Default to 0 so they are never read in an
  // indeterminate state before a date is set.
  int year_{0};
  int month_{0};
  int day_{0};
Y
yaoxuefeng 已提交
778
  // Set to true as the last step of slot-info initialization.
  bool slot_info_initialized_ = false;
L
lxsbupt 已提交
779
  bool hbm_sparse_table_initialized_ = false;
T
Thunderbrook 已提交
780
  // Value returned by UseAfsApi(); 0 by default.
  int use_afs_api_ = 0;
D
danleifeng 已提交
781 782 783 784 785 786
  // Optimizer selector; defaults to 1. NOTE(review): the meaning of each
  // value is not visible in this part of the file.
  int optimizer_type_ = 1;
  std::string accessor_class_;
  // Named float-valued configuration entries.
  std::unordered_map<std::string, float> fleet_config_;
#ifdef PADDLE_WITH_PSCORE
  // Non-owning pointer installed by SetTableAccessor(). Defaults to null so
  // a read before SetTableAccessor() is called sees a well-defined value.
  paddle::distributed::ValueAccessor* cpu_table_accessor_ = nullptr;
#endif
T
Thunderbrook 已提交
787

F
Fan Zhang 已提交
788
#ifdef PADDLE_WITH_CUDA
Y
yaoxuefeng 已提交
789
  // Resized to resource_->total_device() * num_of_dim during slot-info
  // initialization; holds raw MemoryPool pointers.
  std::vector<MemoryPool*> mem_pools_;
L
lxsbupt 已提交
790 791
  std::vector<HBMMemoryPoolFix*> hbm_pools_;  // with multiple mf dims, one
                                              // table needs as many HBM pools
                                              // as the total number of dims
F
Fan Zhang 已提交
792
#endif
Y
yaoxuefeng 已提交
793

794 795 796 797
  // Channel of (HeterContext, Dataset*) pairs, created at construction via
  // MakeChannel. NOTE(review): presumably feeds loaded data to the
  // pre-build thread — confirm against the producer/consumer code.
  std::shared_ptr<paddle::framework::ChannelObject<
      std::pair<std::shared_ptr<HeterContext>, Dataset*>>>
      data_ready_channel_ = paddle::framework::MakeChannel<
          std::pair<std::shared_ptr<HeterContext>, Dataset*>>();
798 799 800 801
  // Channel of HeterContext handles, created at construction via
  // MakeChannel. NOTE(review): presumably signals that the CPU build stage
  // of a task has finished — confirm.
  std::shared_ptr<
      paddle::framework::ChannelObject<std::shared_ptr<HeterContext>>>
      buildcpu_ready_channel_ =
          paddle::framework::MakeChannel<std::shared_ptr<HeterContext>>();
L
lxsbupt 已提交
802 803 804 805 806 807
  // Channel of HeterContext handles, created at construction via
  // MakeChannel. NOTE(review): presumably signals that the build-pull stage
  // of a task has finished — confirm.
  std::shared_ptr<
      paddle::framework::ChannelObject<std::shared_ptr<HeterContext>>>
      buildpull_ready_channel_ =
          paddle::framework::MakeChannel<std::shared_ptr<HeterContext>>();
  // Vector of task_info channels; not populated at declaration. The name is
  // spelled "reday" (sic) and is kept as-is because other code refers to it.
  std::vector<std::shared_ptr<paddle::framework::ChannelObject<task_info>>>
      cpu_reday_channels_;
808
  std::shared_ptr<HeterContext> current_task_ = nullptr;
809
  std::thread pre_build_threads_;
L
lxsbupt 已提交
810
  std::thread buildpull_threads_;
811
  bool running_ = false;
812 813 814
  // Three separate collections of thread pools.
  // NOTE(review): presumably one pool per device/worker for the pull, HBM
  // and CPU stages respectively; sizes are set elsewhere — confirm.
  std::vector<std::shared_ptr<::ThreadPool>> pull_thread_pool_;
  std::vector<std::shared_ptr<::ThreadPool>> hbm_thread_pool_;
  std::vector<std::shared_ptr<::ThreadPool>> cpu_work_pool_;
D
danleifeng 已提交
815
  OptimizerConfig optimizer_config_;
L
lxsbupt 已提交
816 817 818 819
  // Counter of gradient pushes; starts at 0.
  uint64_t grad_push_count_ = 0;
  // Inference-mode flag; off by default.
  bool infer_mode_ = false;
820
  size_t cpu_device_thread_num_ = 16;
821

T
Thunderbrook 已提交
822 823 824 825 826 827 828
 protected:
  // Class-wide initialization flag; defined outside the class body.
  // NOTE(review): where it is set is not visible in this part of the file.
  static bool is_initialized_;
};

}  // end namespace framework
}  // end namespace paddle
#endif