ps_gpu_wrapper.h 29.8 KB
Newer Older
T
Thunderbrook 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
T
Thunderbrook 已提交
16
#ifdef PADDLE_WITH_HETERPS
T
Thunderbrook 已提交
17 18 19 20 21

#include <atomic>
#include <ctime>
#include <map>
#include <memory>
22
#include <mutex>
T
Thunderbrook 已提交
23 24 25
#include <random>
#include <string>
#include <unordered_map>
Y
yaoxuefeng 已提交
26
#include <unordered_set>
T
Thunderbrook 已提交
27
#include <vector>
28 29
#ifdef PADDLE_WITH_GLOO
#include <gloo/broadcast.h>
30

Y
yaoxuefeng 已提交
31
#include "paddle/fluid/framework/data_set.h"
32 33
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
34
#include "paddle/fluid/distributed/ps/thirdparty/round_robin.h"
F
Fan Zhang 已提交
35
#include "paddle/fluid/framework/channel.h"
T
Thunderbrook 已提交
36
#include "paddle/fluid/framework/fleet/heter_context.h"
L
lxsbupt 已提交
37
#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h"
T
Thunderbrook 已提交
38 39
#include "paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h"
F
Fan Zhang 已提交
40 41
#include "paddle/fluid/framework/heter_util.h"
#ifdef PADDLE_WITH_CUDA
Y
yaoxuefeng 已提交
42
#include "paddle/fluid/framework/fleet/heter_ps/mem_pool.h"
F
Fan Zhang 已提交
43 44 45 46 47 48
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#endif
#ifdef PADDLE_WITH_XPU_KP
#include "paddle/fluid/platform/device/xpu/enforce_xpu.h"
#endif
T
Thunderbrook 已提交
49 50 51 52 53
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h"  // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/platform/place.h"
T
Thunderbrook 已提交
54
#ifdef PADDLE_WITH_PSCORE
D
danleifeng 已提交
55 56
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h"
57
#include "paddle/fluid/distributed/ps/wrapper/fleet.h"
D
danleifeng 已提交
58
#include "paddle/fluid/distributed/the_one_ps.pb.h"
T
Thunderbrook 已提交
59
#endif
T
Thunderbrook 已提交
60
#ifdef PADDLE_WITH_PSLIB
Z
zmxdream 已提交
61
#include "afs_api.h"  // NOLINT
T
Thunderbrook 已提交
62
#endif
Y
yaoxuefeng 已提交
63 64 65
#ifdef PADDLE_WITH_PSLIB
#include "downpour_accessor.h"  // NOLINT
#endif
66
#include "paddle/fluid/framework/fleet/heter_ps/log_patch.h"
L
lxsbupt 已提交
67
DECLARE_int32(gpugraph_storage_mode);
T
Thunderbrook 已提交
68 69 70 71

namespace paddle {
namespace framework {

F
Fan Zhang 已提交
72 73
class Dataset;

T
Thunderbrook 已提交
74 75 76 77 78
#ifdef PADDLE_WITH_PSLIB
class AfsWrapper {
 public:
  AfsWrapper() {}
  virtual ~AfsWrapper() {}
79 80 81 82
  void init(const std::string& fs_name,
            const std::string& fs_user,
            const std::string& pass_wd,
            const std::string& conf);
T
Thunderbrook 已提交
83 84 85 86 87 88 89 90 91
  int remove(const std::string& path);
  int mkdir(const std::string& path);
  std::vector<std::string> list(const std::string& path);

  int exist(const std::string& path);
  int upload(const std::string& local_file, const std::string& afs_file);

  int download(const std::string& local_file, const std::string& afs_file);

92 93 94 95
  int touchz(const std::string& path);
  std::string cat(const std::string& path);
  int mv(const std::string& old_path, const std::string& dest_path);

T
Thunderbrook 已提交
96 97 98 99 100
 private:
  paddle::ps::AfsApiWrapper afs_handler_;
};
#endif

L
lxsbupt 已提交
101 102 103 104 105 106 107 108 109
// Unit of work passed through the per-device cpu_reday_channels_ created in
// PSGPUWrapper::InitializeGPU. Every scalar has a default so that a
// default-constructed task_info never carries indeterminate values; it stays
// an aggregate, so brace-initialization keeps working.
// NOTE(review): start/end presumably delimit a range of work inside
// build_values at `offset` for device `device_id` — confirm against the
// producers/consumers in the .cu/.cc files.
struct task_info {
  std::shared_ptr<char> build_values;
  size_t offset = 0;
  int device_id = 0;
  int multi_mf_dim = 0;
  int start = 0;
  int end = 0;
};

T
Thunderbrook 已提交
110
class PSGPUWrapper {
D
danleifeng 已提交
111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159
  class DCacheBuffer {
   public:
    DCacheBuffer() : buf_(nullptr) {}
    ~DCacheBuffer() {}
    /**
     * @Brief get data
     */
    template <typename T>
    T* mutable_data(const size_t total_bytes,
                    const paddle::platform::Place& place) {
      if (buf_ == nullptr) {
        buf_ = memory::AllocShared(place, total_bytes);
      } else if (buf_->size() < total_bytes) {
        buf_.reset();
        buf_ = memory::AllocShared(place, total_bytes);
      }
      return reinterpret_cast<T*>(buf_->ptr());
    }
    template <typename T>
    T* data() {
      return reinterpret_cast<T*>(buf_->ptr());
    }
    size_t memory_size() {
      if (buf_ == nullptr) {
        return 0;
      }
      return buf_->size();
    }
    bool IsInitialized(void) { return (buf_ != nullptr); }

   private:
    std::shared_ptr<memory::Allocation> buf_ = nullptr;
  };
  // Per-device set of reusable scratch buffers plus the key counts of the
  // current batch.
  struct PSDeviceData {
    // Key / dim / pointer staging buffers.
    DCacheBuffer keys_tensor;
    DCacheBuffer dims_tensor;
    DCacheBuffer keys_ptr_tensor;
    DCacheBuffer values_ptr_tensor;
    DCacheBuffer pull_push_tensor;

    // Slot bookkeeping buffers.
    DCacheBuffer slot_lens;
    DCacheBuffer d_slot_vector;
    DCacheBuffer keys2slot;

    int64_t total_key_length{0};
    int64_t dedup_key_length{0};
  };
  // Array with one PSDeviceData per device, allocated with new[] in
  // InitializeGPU and released with delete[] in Finalize.
  PSDeviceData* device_caches_{nullptr};

T
Thunderbrook 已提交
160
 public:
D
danleifeng 已提交
161
  ~PSGPUWrapper();
T
Thunderbrook 已提交
162 163 164 165 166 167

  PSGPUWrapper() {
    HeterPs_ = NULL;
    sleep_seconds_before_fail_exit_ = 300;
  }

168 169
  // ---- Sparse pull / push entry points (all defined out of line) ----
  // Top-level `const` on by-value parameters is irrelevant in a declaration,
  // so it is omitted here; definitions may still use it.

  // Pull overload that additionally receives a per-slot dim (`slot_dim`).
  void PullSparse(const paddle::platform::Place& place,
                  int table_id,
                  const std::vector<const uint64_t*>& keys,
                  const std::vector<float*>& values,
                  const std::vector<int64_t>& slot_lengths,
                  const std::vector<int>& slot_dim,
                  int hidden_size);

  // Pull overload without per-slot dims.
  void PullSparse(const paddle::platform::Place& place,
                  int table_id,
                  const std::vector<const uint64_t*>& keys,
                  const std::vector<float*>& values,
                  const std::vector<int64_t>& slot_lengths,
                  int hidden_size);

  void PushSparseGrad(const paddle::platform::Place& place,
                      int table_id,
                      const std::vector<const uint64_t*>& keys,
                      const std::vector<const float*>& grad_values,
                      const std::vector<int64_t>& slot_lengths,
                      int hidden_size,
                      int batch_size);

  // Presumably gathers the per-slot key arrays `origin_keys` into the
  // contiguous `total_keys` buffer — confirm against the definition.
  void CopyKeys(const paddle::platform::Place& place,
                uint64_t** origin_keys,
                uint64_t* total_keys,
                const int64_t* gpu_len,
                int slot_num,
                int total_len);

  // Same as above, additionally writing `key2slot` (assumed: the slot index
  // of each key — TODO confirm).
  void CopyKeys(const paddle::platform::Place& place,
                uint64_t** origin_keys,
                uint64_t* total_keys,
                const int64_t* gpu_len,
                int slot_num,
                int total_len,
                int* key2slot);
T
Thunderbrook 已提交
201

L
lxsbupt 已提交
202 203
  // ---- Build pipeline (all defined out of line) ----

  // Stages operating on one HeterContext task.
  void divide_to_device(std::shared_ptr<HeterContext> gpu_task);
  void add_slot_feature(std::shared_ptr<HeterContext> gpu_task);
  void BuildGPUTask(std::shared_ptr<HeterContext> gpu_task);
  void PreBuildTask(std::shared_ptr<HeterContext> gpu_task);
  void BuildPull(std::shared_ptr<HeterContext> gpu_task);
  void PrepareGPUTask(std::shared_ptr<HeterContext> gpu_task);
  void add_key_to_gputask(std::shared_ptr<HeterContext> gpu_task);
  void resize_gputask(std::shared_ptr<HeterContext> gpu_task);

  // Pass lifecycle and table movement.
  void LoadIntoMemory(bool is_shuffle);
  void BeginPass();
  void EndPass();
  void add_key_to_local(const std::vector<uint64_t>& keys);
  void SparseTableToHbm();
  void HbmToSparseTable();
  void DumpToMem();

  // Build-thread launcher (called from InitializeGPU) and, presumably, the
  // thread bodies it runs.
  void start_build_thread();
  void pre_build_thread();
  void build_pull_thread();
  void build_task();
  // set mode
  // Stores the inference flag and, when the HeterPs backend already exists,
  // forwards it there as well; always logs the new value.
  void SetMode(bool infer_mode) {
    infer_mode_ = infer_mode;
    auto* backend = HeterPs_;
    if (backend != nullptr) {
      backend->set_mode(infer_mode);
    }
    VLOG(0) << "set infer mode=" << infer_mode;
  }
229 230 231 232 233 234

  void Finalize() {
    VLOG(3) << "PSGPUWrapper Begin Finalize.";
    if (s_instance_ == nullptr) {
      return;
    }
L
lxsbupt 已提交
235 236 237 238 239 240
    if (FLAGS_gpugraph_storage_mode == GpuGraphStorageMode::WHOLE_HBM) {
      this->EndPass();
    }
    for (size_t i = 0; i < hbm_pools_.size(); i++) {
      delete hbm_pools_[i];
    }
241 242
    data_ready_channel_->Close();
    buildcpu_ready_channel_->Close();
L
lxsbupt 已提交
243
    buildpull_ready_channel_->Close();
244 245
    gpu_free_channel_->Close();
    running_ = false;
246 247
    VLOG(3) << "begin stop pre_build_threads_";
    pre_build_threads_.join();
L
lxsbupt 已提交
248 249
    VLOG(3) << "begin stop buildpull_threads_";
    buildpull_threads_.join();
250 251
    s_instance_ = nullptr;
    VLOG(3) << "PSGPUWrapper Finalize Finished.";
D
danleifeng 已提交
252
    HeterPs_->show_table_collisions();
L
lxsbupt 已提交
253 254 255 256
    if (HeterPs_ != NULL) {
      delete HeterPs_;
      HeterPs_ = NULL;
    }
D
danleifeng 已提交
257 258 259 260
    if (device_caches_ != nullptr) {
      delete[] device_caches_;
      device_caches_ = nullptr;
    }
261 262
  }

T
Thunderbrook 已提交
263
  void InitializeGPU(const std::vector<int>& dev_ids) {
264
    if (s_instance_ != NULL && is_initialized_ == false) {
T
Thunderbrook 已提交
265
      VLOG(3) << "PSGPUWrapper Begin InitializeGPU";
266
      is_initialized_ = true;
T
Thunderbrook 已提交
267 268
      resource_ = std::make_shared<HeterPsResource>(dev_ids);
      resource_->enable_p2p();
269
      keys_tensor.resize(resource_->total_device());
D
danleifeng 已提交
270
      device_caches_ = new PSDeviceData[resource_->total_device()];
Y
yaoxuefeng 已提交
271 272 273 274
#ifdef PADDLE_WITH_GLOO
      auto gloo = paddle::framework::GlooWrapper::GetInstance();
      if (gloo->Size() > 1) {
        multi_node_ = 1;
L
lxsbupt 已提交
275 276 277 278 279 280
        resource_->set_multi_node(multi_node_);
        optimizer_config_.multi_node = true;
        VLOG(0) << "init multi node gpu server";
      } else {
        optimizer_config_.multi_node = false;
        VLOG(0) << "init single node gpu server";
Y
yaoxuefeng 已提交
281 282 283 284 285
      }
#else
      PADDLE_THROW(
          platform::errors::Unavailable("heter ps need compile with GLOO"));
#endif
F
Fan Zhang 已提交
286
#ifdef PADDLE_WITH_CUDA
287 288 289 290 291
      if (multi_node_) {
        int dev_size = dev_ids.size();
        // init inner comm
        inner_comms_.resize(dev_size);
        inter_ncclids_.resize(dev_size);
292 293
        platform::dynload::ncclCommInitAll(
            &(inner_comms_[0]), dev_size, &dev_ids[0]);
294 295 296 297 298 299 300 301 302 303
// init inter comm
#ifdef PADDLE_WITH_GLOO
        inter_comms_.resize(dev_size);
        if (gloo->Rank() == 0) {
          for (int i = 0; i < dev_size; ++i) {
            platform::dynload::ncclGetUniqueId(&inter_ncclids_[i]);
          }
        }

        PADDLE_ENFORCE_EQ(
304 305
            gloo->IsInitialized(),
            true,
306 307 308 309 310 311 312
            platform::errors::PreconditionNotMet(
                "You must initialize the gloo environment first to use it."));
        gloo::BroadcastOptions opts(gloo->GetContext());
        opts.setOutput(&inter_ncclids_[0], dev_size);
        opts.setRoot(0);
        gloo::broadcast(opts);

L
lxsbupt 已提交
313
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
314
        for (int i = 0; i < dev_size; ++i) {
L
lxsbupt 已提交
315
          platform::CUDADeviceGuard guard(dev_ids[i]);
316 317
          platform::dynload::ncclCommInitRank(
              &inter_comms_[i], gloo->Size(), inter_ncclids_[i], gloo->Rank());
318
        }
L
lxsbupt 已提交
319 320 321
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());

        rank_id_ = gloo->Rank();
322 323 324 325 326 327
        node_size_ = gloo->Size();
#else
        PADDLE_THROW(
            platform::errors::Unavailable("heter ps need compile with GLOO"));
#endif
      }
F
Fan Zhang 已提交
328
#endif
Y
yaoxuefeng 已提交
329
      heter_devices_ = dev_ids;
330 331 332 333
      data_ready_channel_->Open();
      data_ready_channel_->SetCapacity(3);
      buildcpu_ready_channel_->Open();
      buildcpu_ready_channel_->SetCapacity(3);
L
lxsbupt 已提交
334 335
      buildpull_ready_channel_->Open();
      buildpull_ready_channel_->SetCapacity(1);
336 337 338
      gpu_free_channel_->Open();
      gpu_free_channel_->SetCapacity(1);

L
lxsbupt 已提交
339 340 341 342 343 344
      cpu_reday_channels_.resize(dev_ids.size());
      for (size_t i = 0; i < dev_ids.size(); i++) {
        cpu_reday_channels_[i] = paddle::framework::MakeChannel<task_info>();
        cpu_reday_channels_[i]->SetCapacity(16);
      }

345 346 347 348
      current_task_ = nullptr;
      gpu_free_channel_->Put(current_task_);

      table_id_ = 0;
349

350 351
      // start build cpu&gpu ps thread
      start_build_thread();
T
Thunderbrook 已提交
352 353
    }
  }
Y
yaoxuefeng 已提交
354

355 356 357 358 359 360
  // Optimizer hyper-parameters for the embed_w part (defined out of line).
  // InitializeGPUServer(config) fills these from the flat config map, falling
  // back to its own defaults for missing keys.
  void SetSparseSGD(float nonclk_coeff,
                    float clk_coeff,
                    float min_bound,
                    float max_bound,
                    float learning_rate,
                    float initial_g2sum,
                    float initial_range,
                    float beta1_decay_rate,
                    float beta2_decay_rate,
                    float ada_epsilon);

  // Optimizer hyper-parameters for the embedx ("mf_") part, plus the graph
  // settings nodeid_slot and feature_learning_rate (defined out of line).
  // Also filled by InitializeGPUServer(config).
  void SetEmbedxSGD(float mf_create_thresholds,
                    float mf_learning_rate,
                    float mf_initial_g2sum,
                    float mf_initial_range,
                    float mf_min_bound,
                    float mf_max_bound,
                    float mf_beta1_decay_rate,
                    float mf_beta2_decay_rate,
                    float mf_ada_epsilon,
                    float nodeid_slot,
                    float feature_learning_rate);
D
danleifeng 已提交
376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426

#ifdef PADDLE_WITH_PSCORE
  // Flattens one sparse SGD rule from the PS proto into `config`. Every key is
  // prefixed with `prefix`; InitializeGPUServer passes "" for embed_w and
  // "mf_" for embedx.
  // optimizer_type codes:
  //   0 SparseNaiveSGDRule, 1 SparseAdaGradSGDRule, 2 StdAdaGradSGDRule,
  //   3 SparseAdamSGDRule,  4 SparseSharedAdamSGDRule.
  // An unknown rule name is logged and leaves `config` unchanged.
  // Throws InvalidArgument if the rule's weight_bounds has fewer than 2
  // entries.
  void add_sparse_optimizer(
      std::unordered_map<std::string, float>& config,  // NOLINT
      const ::paddle::distributed::SparseCommonSGDRuleParameter& sgd_param,
      const std::string& prefix = "") {
    const std::string& optimizer_name = sgd_param.name();

    // Keys every rule provides. weight_bounds is read as [min, max]; indexing
    // a shorter repeated field would read out of range, so check it first.
    auto set_common = [&](int optimizer_type, const auto& rule) {
      PADDLE_ENFORCE_GE(
          rule.weight_bounds_size(),
          2,
          platform::errors::InvalidArgument(
              "Sparse optimizer %s needs 2 weight_bounds (min, max), got %d.",
              optimizer_name,
              rule.weight_bounds_size()));
      config[prefix + "optimizer_type"] = optimizer_type;
      config[prefix + "learning_rate"] = rule.learning_rate();
      config[prefix + "initial_range"] = rule.initial_range();
      config[prefix + "min_bound"] = rule.weight_bounds()[0];
      config[prefix + "max_bound"] = rule.weight_bounds()[1];
    };
    // Both AdaGrad variants read the adagrad() sub-message.
    auto set_adagrad = [&](int optimizer_type) {
      const auto& adagrad = sgd_param.adagrad();
      set_common(optimizer_type, adagrad);
      config[prefix + "initial_g2sum"] = adagrad.initial_g2sum();
    };
    // Both Adam variants read the adam() sub-message.
    auto set_adam = [&](int optimizer_type) {
      const auto& adam = sgd_param.adam();
      set_common(optimizer_type, adam);
      config[prefix + "beta1_decay_rate"] = adam.beta1_decay_rate();
      config[prefix + "beta2_decay_rate"] = adam.beta2_decay_rate();
      config[prefix + "ada_epsilon"] = adam.ada_epsilon();
    };

    if (optimizer_name == "SparseNaiveSGDRule") {
      set_common(0, sgd_param.naive());
    } else if (optimizer_name == "SparseAdaGradSGDRule") {
      set_adagrad(1);
    } else if (optimizer_name == "StdAdaGradSGDRule") {
      set_adagrad(2);
    } else if (optimizer_name == "SparseAdamSGDRule") {
      set_adam(3);
    } else if (optimizer_name == "SparseSharedAdamSGDRule") {
      set_adam(4);
    } else {
      // Previously ignored silently; the caller would then fall back to the
      // default optimizer without any trace.
      LOG(WARNING) << "add_sparse_optimizer: unsupported sparse optimizer \""
                   << optimizer_name << "\" (prefix \"" << prefix
                   << "\"); no optimizer keys were set.";
    }
  }

  // Configures the GPU PS from `ps_param`. Only downpour_table_param(0) is
  // read:
  //  - its shard_num sizes the build thread pools;
  //  - its accessor selects and configures the global accessor;
  //  - its SGD rules (for CtrDymfAccessor) and graph settings become the flat
  //    config handed to InitializeGPUServer(config).
  // Throws InvalidArgument if there is no downpour table.
  void InitializeGPUServer(paddle::distributed::PSParameter ps_param) {
    const auto& downpour_param =
        ps_param.server_param().downpour_server_param();
    PADDLE_ENFORCE_GT(
        downpour_param.downpour_table_param_size(),
        0,
        platform::errors::InvalidArgument(
            "InitializeGPUServer needs at least one downpour_table_param."));
    // Read-only views into ps_param; the protos do not need to be copied.
    const auto& sparse_table = downpour_param.downpour_table_param(0);
    // set build thread_num and shard_num
    thread_keys_thread_num_ = sparse_table.shard_num();
    thread_keys_shard_num_ = sparse_table.shard_num();
    VLOG(1) << "ps_gpu build phase thread_num:" << thread_keys_thread_num_
            << " shard_num:" << thread_keys_shard_num_;

    pull_thread_pool_.resize(thread_keys_shard_num_);
    for (size_t i = 0; i < pull_thread_pool_.size(); i++) {
      pull_thread_pool_[i].reset(new ::ThreadPool(1));
    }
    hbm_thread_pool_.resize(thread_keys_shard_num_);
    for (size_t i = 0; i < hbm_thread_pool_.size(); i++) {
      hbm_thread_pool_[i].reset(new ::ThreadPool(1));
    }

    cpu_work_pool_.resize(thread_keys_shard_num_);
    // Bound the loop by cpu_work_pool_ itself. It was bounded by
    // hbm_thread_pool_.size(), which only matched because both are resized
    // to the same value.
    for (size_t i = 0; i < cpu_work_pool_.size(); i++) {
      cpu_work_pool_[i].reset(new ::ThreadPool(16));
    }

    const auto& sparse_table_accessor = sparse_table.accessor();
    const auto& sparse_table_accessor_parameter =
        sparse_table_accessor.ctr_accessor_param();
    accessor_class_ = sparse_table_accessor.accessor_class();

    std::unordered_map<std::string, float> config;
    config["embedx_dim"] = sparse_table_accessor.embedx_dim();
    config["nonclk_coeff"] = sparse_table_accessor_parameter.nonclk_coeff();
    config["clk_coeff"] = sparse_table_accessor_parameter.click_coeff();
    config["mf_create_thresholds"] = sparse_table_accessor.embedx_threshold();

    config["nodeid_slot"] =
        sparse_table_accessor.graph_sgd_param().nodeid_slot();
    config["feature_learning_rate"] =
        sparse_table_accessor.graph_sgd_param().feature_learning_rate();

    if (accessor_class_ == "CtrDymfAccessor") {
      // optimizer config for embed_w and embedx
      add_sparse_optimizer(config, sparse_table_accessor.embed_sgd_param());
      add_sparse_optimizer(
          config, sparse_table_accessor.embedx_sgd_param(), "mf_");
    }

    fleet_config_ = config;
    GlobalAccessorFactory::GetInstance().Init(accessor_class_);
    GlobalAccessorFactory::GetInstance().GetAccessorWrapper()->Configure(
        config);
    InitializeGPUServer(config);
  }
#endif

Y
yaoxuefeng 已提交
478 479 480 481 482 483 484
  void InitializeGPUServer(std::unordered_map<std::string, float> config) {
    float nonclk_coeff = (config.find("nonclk_coeff") == config.end())
                             ? 1.0
                             : config["nonclk_coeff"];
    float clk_coeff =
        (config.find("clk_coeff") == config.end()) ? 1.0 : config["clk_coeff"];
    float min_bound = (config.find("min_bound") == config.end())
D
danleifeng 已提交
485
                          ? -10.0
Y
yaoxuefeng 已提交
486
                          : config["min_bound"];
D
danleifeng 已提交
487 488
    float max_bound =
        (config.find("max_bound") == config.end()) ? 10.0 : config["max_bound"];
Y
yaoxuefeng 已提交
489
    float learning_rate = (config.find("learning_rate") == config.end())
D
danleifeng 已提交
490
                              ? 0.05
Y
yaoxuefeng 已提交
491 492
                              : config["learning_rate"];
    float initial_g2sum = (config.find("initial_g2sum") == config.end())
D
danleifeng 已提交
493
                              ? 3.0
Y
yaoxuefeng 已提交
494 495
                              : config["initial_g2sum"];
    float initial_range = (config.find("initial_range") == config.end())
D
danleifeng 已提交
496
                              ? 1e-4
Y
yaoxuefeng 已提交
497
                              : config["initial_range"];
D
danleifeng 已提交
498 499 500 501 502 503 504 505 506
    float beta1_decay_rate = (config.find("beta1_decay_rate") == config.end())
                                 ? 0.9
                                 : config["beta1_decay_rate"];
    float beta2_decay_rate = (config.find("beta2_decay_rate") == config.end())
                                 ? 0.999
                                 : config["beta2_decay_rate"];
    float ada_epsilon = (config.find("ada_epsilon") == config.end())
                            ? 1e-8
                            : config["ada_epsilon"];
Y
yaoxuefeng 已提交
507 508 509 510 511 512
    // mf config settings
    float mf_create_thresholds =
        (config.find("mf_create_thresholds") == config.end())
            ? static_cast<float>(1.0)
            : config["mf_create_thresholds"];
    float mf_learning_rate = (config.find("mf_learning_rate") == config.end())
D
danleifeng 已提交
513
                                 ? 0.05
Y
yaoxuefeng 已提交
514 515
                                 : config["mf_learning_rate"];
    float mf_initial_g2sum = (config.find("mf_initial_g2sum") == config.end())
D
danleifeng 已提交
516
                                 ? 3.0
Y
yaoxuefeng 已提交
517 518
                                 : config["mf_initial_g2sum"];
    float mf_initial_range = (config.find("mf_initial_range") == config.end())
D
danleifeng 已提交
519
                                 ? 1e-4
Y
yaoxuefeng 已提交
520 521
                                 : config["mf_initial_range"];
    float mf_min_bound = (config.find("mf_min_bound") == config.end())
D
danleifeng 已提交
522
                             ? -10.0
Y
yaoxuefeng 已提交
523 524
                             : config["mf_min_bound"];
    float mf_max_bound = (config.find("mf_max_bound") == config.end())
D
danleifeng 已提交
525
                             ? 10.0
Y
yaoxuefeng 已提交
526
                             : config["mf_max_bound"];
D
danleifeng 已提交
527 528 529 530 531 532 533 534 535 536 537
    float mf_beta1_decay_rate =
        (config.find("mf_beta1_decay_rate") == config.end())
            ? 0.9
            : config["mf_beta1_decay_rate"];
    float mf_beta2_decay_rate =
        (config.find("mf_beta2_decay_rate") == config.end())
            ? 0.999
            : config["mf_beta2_decay_rate"];
    float mf_ada_epsilon = (config.find("mf_ada_epsilon") == config.end())
                               ? 1e-8
                               : config["mf_ada_epsilon"];
D
danleifeng 已提交
538 539 540 541 542 543 544 545 546 547

    float feature_learning_rate =
        (config.find("feature_learning_rate") == config.end())
            ? 0.05
            : config["feature_learning_rate"];

    float nodeid_slot = (config.find("nodeid_slot") == config.end())
                            ? 9008
                            : config["nodeid_slot"];

548 549 550 551 552 553
    this->SetSparseSGD(nonclk_coeff,
                       clk_coeff,
                       min_bound,
                       max_bound,
                       learning_rate,
                       initial_g2sum,
D
danleifeng 已提交
554 555 556 557
                       initial_range,
                       beta1_decay_rate,
                       beta2_decay_rate,
                       ada_epsilon);
558 559 560 561 562
    this->SetEmbedxSGD(mf_create_thresholds,
                       mf_learning_rate,
                       mf_initial_g2sum,
                       mf_initial_range,
                       mf_min_bound,
D
danleifeng 已提交
563 564 565
                       mf_max_bound,
                       mf_beta1_decay_rate,
                       mf_beta2_decay_rate,
D
danleifeng 已提交
566 567 568
                       mf_ada_epsilon,
                       nodeid_slot,
                       feature_learning_rate);
D
danleifeng 已提交
569 570 571 572

    // set optimizer type(naive,adagrad,std_adagrad,adam,share_adam)
    optimizer_type_ = (config.find("optimizer_type") == config.end())
                          ? 1
Z
zmxdream 已提交
573
                          : static_cast<int>(config["optimizer_type"]);
D
danleifeng 已提交
574 575 576 577

    VLOG(0) << "InitializeGPUServer optimizer_type_:" << optimizer_type_
            << " nodeid_slot:" << nodeid_slot
            << " feature_learning_rate:" << feature_learning_rate;
Y
yaoxuefeng 已提交
578
  }
F
Fan Zhang 已提交
579

580 581 582 583 584 585
  // Stores the given calendar date in year_ / month_ / day_.
  void SetDate(int year, int month, int day) {
    day_ = day;
    month_ = month;
    year_ = year;
  }

Y
yaoxuefeng 已提交
586 587
  // Remembers the dataset that InitSlotInfo() later reads. Only the raw
  // pointer is stored; ownership presumably stays with the caller.
  void SetDataset(Dataset* dataset) {
    dataset_ = dataset;
  }

T
Thunderbrook 已提交
588 589
  // PSGPUWrapper singleton
  static std::shared_ptr<PSGPUWrapper> GetInstance() {
590 591 592 593 594
    {
      std::lock_guard<std::mutex> lk(ins_mutex);
      if (NULL == s_instance_) {
        s_instance_.reset(new paddle::framework::PSGPUWrapper());
      }
T
Thunderbrook 已提交
595 596 597 598 599 600 601 602 603
    }
    return s_instance_;
  }
  // Returns the local tables stored under `table_id`. The lookup uses
  // unordered_map::operator[], so the first access inserts an empty entry.
  std::vector<std::unordered_map<uint64_t, std::vector<float>>>& GetLocalTable(
      int table_id) {
    auto& tables = local_tables_[table_id];
    return tables;
  }
  void SetSlotVector(const std::vector<int>& slot_vector) {
    slot_vector_ = slot_vector;
L
lxsbupt 已提交
604
    VLOG(0) << "slot_vector size is " << slot_vector_.size();
T
Thunderbrook 已提交
605 606
  }

Y
yaoxuefeng 已提交
607 608
  void SetSlotOffsetVector(const std::vector<int>& slot_offset_vector) {
    slot_offset_vector_ = slot_offset_vector;
Y
yaoxuefeng 已提交
609 610 611 612 613
    std::cout << "yxf set: ";
    for (auto s : slot_offset_vector_) {
      std::cout << s << " | ";
    }
    std::cout << " end " << std::endl;
Y
yaoxuefeng 已提交
614 615
  }

F
Fan Zhang 已提交
616
#ifdef PADDLE_WITH_CUDA
Y
yaoxuefeng 已提交
617 618 619
  void SetSlotDimVector(const std::vector<int>& slot_mf_dim_vector) {
    slot_mf_dim_vector_ = slot_mf_dim_vector;
    assert(slot_mf_dim_vector_.size() == slot_vector_.size());
Y
yaoxuefeng 已提交
620 621 622 623 624 625
  }

  void InitSlotInfo() {
    if (slot_info_initialized_) {
      return;
    }
Z
zmxdream 已提交
626
    SlotRecordDataset* dataset = reinterpret_cast<SlotRecordDataset*>(dataset_);
Y
yaoxuefeng 已提交
627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642
    auto slots_vec = dataset->GetSlots();
    slot_offset_vector_.clear();
    for (auto& slot : slot_vector_) {
      for (size_t i = 0; i < slots_vec.size(); ++i) {
        if (std::to_string(slot) == slots_vec[i]) {
          slot_offset_vector_.push_back(i);
          break;
        }
      }
    }
    std::cout << "psgpu wrapper use slots: ";
    for (auto s : slot_offset_vector_) {
      std::cout << s << " | ";
    }
    std::cout << " end " << std::endl;
    for (size_t i = 0; i < slot_mf_dim_vector_.size(); i++) {
Y
yaoxuefeng 已提交
643 644 645 646 647 648 649 650 651 652 653 654 655 656 657
      slot_dim_map_[slot_vector_[i]] = slot_mf_dim_vector_[i];
    }

    std::unordered_set<int> dims_set;
    for (auto& it : slot_dim_map_) {
      dims_set.insert(it.second);
    }
    size_t num_of_dim = dims_set.size();
    index_dim_vec_.resize(num_of_dim);
    index_dim_vec_.assign(dims_set.begin(), dims_set.end());
    std::sort(index_dim_vec_.begin(), index_dim_vec_.end());
    std::unordered_map<int, int> dim_index_map;
    for (size_t i = 0; i < num_of_dim; i++) {
      dim_index_map[index_dim_vec_[i]] = i;
    }
658
    hbm_pools_.resize(resource_->total_device() * num_of_dim);
L
lxsbupt 已提交
659 660 661 662
    for (size_t i = 0; i < hbm_pools_.size(); i++) {
      hbm_pools_[i] = new HBMMemoryPoolFix();
    }

663
    mem_pools_.resize(resource_->total_device() * num_of_dim);
Y
yaoxuefeng 已提交
664 665 666 667 668 669 670
    max_mf_dim_ = index_dim_vec_.back();
    multi_mf_dim_ = (dim_index_map.size() >= 1) ? dim_index_map.size() : 0;
    resource_->set_multi_mf(multi_mf_dim_, max_mf_dim_);
    slot_index_vec_.resize(slot_mf_dim_vector_.size());
    for (size_t i = 0; i < slot_index_vec_.size(); i++) {
      slot_index_vec_[i] = dim_index_map[slot_mf_dim_vector_[i]];
    }
D
danleifeng 已提交
671 672

    auto accessor_wrapper_ptr =
D
danleifeng 已提交
673
        GlobalAccessorFactory::GetInstance().GetAccessorWrapper();
D
danleifeng 已提交
674 675
    val_type_size_ = accessor_wrapper_ptr->GetFeatureValueSize(max_mf_dim_);
    grad_type_size_ = accessor_wrapper_ptr->GetPushValueSize(max_mf_dim_);
D
danleifeng 已提交
676
    pull_type_size_ = accessor_wrapper_ptr->GetPullValueSize(max_mf_dim_);
D
danleifeng 已提交
677
    VLOG(0) << "InitSlotInfo: val_type_size_" << val_type_size_
D
danleifeng 已提交
678 679
            << " grad_type_size_:" << grad_type_size_
            << " pull_type_size_:" << pull_type_size_;
Y
yaoxuefeng 已提交
680
    slot_info_initialized_ = true;
Y
yaoxuefeng 已提交
681
  }
F
Fan Zhang 已提交
682
#endif
Y
yaoxuefeng 已提交
683

T
Thunderbrook 已提交
684 685
  void ShowOneTable(int index) { HeterPs_->show_one_table(index); }

T
Thunderbrook 已提交
686 687 688 689 690 691 692 693
  // Returns the use_afs_api_ flag.
  int UseAfsApi() {
    const int use_afs = use_afs_api_;
    return use_afs;
  }

#ifdef PADDLE_WITH_PSLIB
  // Opens `filename` for reading through afs_handler_.open_reader().
  std::shared_ptr<paddle::ps::AfsReader> OpenReader(
      const std::string& filename) {
    auto reader = afs_handler_.open_reader(filename);
    return reader;
  }

  // Opens `filename` for writing through afs_handler_.open_writer().
  std::shared_ptr<paddle::ps::AfsWriter> OpenWriter(
      const std::string& filename) {
    auto writer = afs_handler_.open_writer(filename);
    return writer;
  }

  // Defined out of line. The parameters mirror AfsWrapper::init.
  void InitAfsApi(const std::string& fs_name,
                  const std::string& fs_user,
                  const std::string& pass_wd,
                  const std::string& conf);
T
Thunderbrook 已提交
703 704
#endif

D
danleifeng 已提交
705 706 707 708 709 710
#ifdef PADDLE_WITH_PSCORE
  // Stores the CPU table's value accessor as a raw pointer; ownership
  // presumably stays with the caller.
  void SetTableAccessor(paddle::distributed::ValueAccessor* accessor) {
    cpu_table_accessor_ = accessor;
  }
#endif

T
Thunderbrook 已提交
711 712
 private:
  static std::shared_ptr<PSGPUWrapper> s_instance_;
713
  static std::mutex ins_mutex;
Y
yaoxuefeng 已提交
714
  Dataset* dataset_;
T
Thunderbrook 已提交
715 716 717
#ifdef PADDLE_WITH_PSLIB
  paddle::ps::AfsApiWrapper afs_handler_;
#endif
T
Thunderbrook 已提交
718
  std::unordered_map<
719 720
      uint64_t,
      std::vector<std::unordered_map<uint64_t, std::vector<float>>>>
T
Thunderbrook 已提交
721
      local_tables_;
L
lxsbupt 已提交
722 723
  HeterPsBase* HeterPs_ = NULL;
  // std::vector<LoDTensor> keys_tensor;  // Cache for pull_sparse
724
  std::vector<phi::DenseTensor> keys_tensor;  // Cache for pull_sparse
T
Thunderbrook 已提交
725 726 727
  std::shared_ptr<HeterPsResource> resource_;
  int32_t sleep_seconds_before_fail_exit_;
  std::vector<int> slot_vector_;
Y
yaoxuefeng 已提交
728 729 730 731 732 733 734 735 736
  std::vector<int> slot_offset_vector_;
  std::vector<int> slot_mf_dim_vector_;
  std::unordered_map<int, int> slot_dim_map_;
  std::vector<int> slot_index_vec_;
  std::vector<int> index_dim_vec_;
  int multi_mf_dim_{0};
  int max_mf_dim_{0};
  size_t val_type_size_{0};
  size_t grad_type_size_{0};
D
danleifeng 已提交
737
  size_t pull_type_size_{0};
Y
yaoxuefeng 已提交
738 739 740 741 742 743

  double time_1 = 0.0;
  double time_2 = 0.0;
  double time_3 = 0.0;
  double time_4 = 0.0;

T
Thunderbrook 已提交
744
  int multi_node_{0};
L
lxsbupt 已提交
745
  int rank_id_;
746
  int node_size_;
747
  uint64_t table_id_;
D
danleifeng 已提交
748
  int gpu_graph_mode_ = 0;
F
Fan Zhang 已提交
749
#ifdef PADDLE_WITH_CUDA
750 751 752
  std::vector<ncclComm_t> inner_comms_;
  std::vector<ncclComm_t> inter_comms_;
  std::vector<ncclUniqueId> inter_ncclids_;
F
Fan Zhang 已提交
753
#endif
Y
yaoxuefeng 已提交
754 755 756
  std::vector<int> heter_devices_;
  std::unordered_set<std::string> gpu_ps_config_keys_;
  HeterObjectPool<HeterContext> gpu_task_pool_;
757
  std::vector<std::vector<robin_hood::unordered_set<uint64_t>>> thread_keys_;
758 759
  std::vector<std::vector<std::vector<robin_hood::unordered_set<uint64_t>>>>
      thread_dim_keys_;
Y
yaoxuefeng 已提交
760 761 762
  int thread_keys_thread_num_ = 37;
  int thread_keys_shard_num_ = 37;
  uint64_t max_fea_num_per_pass_ = 5000000000;
763 764 765
  int year_;
  int month_;
  int day_;
Y
yaoxuefeng 已提交
766
  bool slot_info_initialized_ = false;
L
lxsbupt 已提交
767
  bool hbm_sparse_table_initialized_ = false;
T
Thunderbrook 已提交
768
  int use_afs_api_ = 0;
D
danleifeng 已提交
769 770 771 772 773 774
  int optimizer_type_ = 1;
  std::string accessor_class_;
  std::unordered_map<std::string, float> fleet_config_;
#ifdef PADDLE_WITH_PSCORE
  paddle::distributed::ValueAccessor* cpu_table_accessor_;
#endif
T
Thunderbrook 已提交
775

F
Fan Zhang 已提交
776
#ifdef PADDLE_WITH_CUDA
Y
yaoxuefeng 已提交
777
  std::vector<MemoryPool*> mem_pools_;
L
lxsbupt 已提交
778 779
  std::vector<HBMMemoryPoolFix*> hbm_pools_;  // in multi mfdim, one table need
                                              // hbm pools of totol dims number
F
Fan Zhang 已提交
780
#endif
Y
yaoxuefeng 已提交
781

782 783 784 785 786 787 788 789 790 791 792 793
  std::shared_ptr<
      paddle::framework::ChannelObject<std::shared_ptr<HeterContext>>>
      data_ready_channel_ =
          paddle::framework::MakeChannel<std::shared_ptr<HeterContext>>();
  std::shared_ptr<
      paddle::framework::ChannelObject<std::shared_ptr<HeterContext>>>
      buildcpu_ready_channel_ =
          paddle::framework::MakeChannel<std::shared_ptr<HeterContext>>();
  std::shared_ptr<
      paddle::framework::ChannelObject<std::shared_ptr<HeterContext>>>
      gpu_free_channel_ =
          paddle::framework::MakeChannel<std::shared_ptr<HeterContext>>();
L
lxsbupt 已提交
794 795 796 797 798 799
  std::shared_ptr<
      paddle::framework::ChannelObject<std::shared_ptr<HeterContext>>>
      buildpull_ready_channel_ =
          paddle::framework::MakeChannel<std::shared_ptr<HeterContext>>();
  std::vector<std::shared_ptr<paddle::framework::ChannelObject<task_info>>>
      cpu_reday_channels_;
800
  std::shared_ptr<HeterContext> current_task_ = nullptr;
801
  std::thread pre_build_threads_;
L
lxsbupt 已提交
802
  std::thread buildpull_threads_;
803
  bool running_ = false;
Y
yaoxuefeng 已提交
804
  std::vector<std::shared_ptr<ThreadPool>> pull_thread_pool_;
T
Thunderbrook 已提交
805
  std::vector<std::shared_ptr<ThreadPool>> hbm_thread_pool_;
L
lxsbupt 已提交
806
  std::vector<std::shared_ptr<ThreadPool>> cpu_work_pool_;
D
danleifeng 已提交
807
  OptimizerConfig optimizer_config_;
L
lxsbupt 已提交
808 809 810 811
  // gradient push count
  uint64_t grad_push_count_ = 0;
  // infer mode
  bool infer_mode_ = false;
812

T
Thunderbrook 已提交
813 814 815 816 817 818 819
 protected:
  static bool is_initialized_;
};

}  // end namespace framework
}  // end namespace paddle
#endif