ps_gpu_wrapper.cc 58.2 KB
Newer Older
T
Thunderbrook 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
T
Thunderbrook 已提交
28
#ifdef PADDLE_WITH_HETERPS
Y
yaoxuefeng 已提交
29

30 31
#include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h"

T
Thunderbrook 已提交
32
#include <algorithm>
Y
yaoxuefeng 已提交
33 34
#include <deque>

D
danleifeng 已提交
35
#include "paddle/fluid/framework/data_set.h"
D
danleifeng 已提交
36
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_utils.h"
T
Thunderbrook 已提交
37
#include "paddle/fluid/platform/timer.h"
38 39 40
#if defined(PADDLE_WITH_PSCORE)
#include "paddle/fluid/distributed/ps/table/depends/feature_value.h"
#endif
T
Thunderbrook 已提交
41

D
danleifeng 已提交
42 43
DECLARE_int32(gpugraph_dedup_pull_push_mode);

T
Thunderbrook 已提交
44 45 46
namespace paddle {
namespace framework {

T
Thunderbrook 已提交
47
#ifdef PADDLE_WITH_PSLIB
48 49 50 51 52 53
void AfsWrapper::init(const std::string& fs_name,
                      const std::string& fs_user,
                      const std::string& pass_wd,
                      const std::string& conf) {
  int ret = afs_handler_.init(
      fs_name.c_str(), fs_user.c_str(), pass_wd.c_str(), conf.c_str());
T
Thunderbrook 已提交
54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
  if (ret != 0) {
    LOG(ERROR) << "AFS Init Error";
  }
}

int AfsWrapper::remove(const std::string& path) {
  return afs_handler_.remove(path);
}

int AfsWrapper::mkdir(const std::string& path) {
  return afs_handler_.mkdir(path);
}

std::vector<std::string> AfsWrapper::list(const std::string& path) {
  return afs_handler_.list(path);
}

int AfsWrapper::exist(const std::string& path) {
  return afs_handler_.exist(path);
}

int AfsWrapper::upload(const std::string& local_file,
                       const std::string& afs_file) {
  return afs_handler_.upload_file(local_file, afs_file);
}

int AfsWrapper::download(const std::string& local_file,
                         const std::string& afs_file) {
  return afs_handler_.download_file(local_file, afs_file);
}
84 85 86 87 88 89 90 91 92 93 94 95

int AfsWrapper::touchz(const std::string& path) {
  return afs_handler_.touchz(path);
}

std::string AfsWrapper::cat(const std::string& path) {
  return afs_handler_.cat(path);
}

int AfsWrapper::mv(const std::string& old_path, const std::string& dest_path) {
  return afs_handler_.mv(old_path, dest_path);
}
T
Thunderbrook 已提交
96 97
#endif

T
Thunderbrook 已提交
98 99
std::shared_ptr<PSGPUWrapper> PSGPUWrapper::s_instance_ = NULL;
bool PSGPUWrapper::is_initialized_ = false;
100
std::mutex PSGPUWrapper::ins_mutex;
T
Thunderbrook 已提交
101 102 103 104 105
#ifdef PADDLE_WITH_PSLIB
void PSGPUWrapper::InitAfsApi(const std::string& fs_name,
                              const std::string& fs_user,
                              const std::string& pass_wd,
                              const std::string& conf) {
106 107
  int ret = afs_handler_.init(
      fs_name.c_str(), fs_user.c_str(), pass_wd.c_str(), conf.c_str());
T
Thunderbrook 已提交
108
  if (ret != 0) {
109
    VLOG(0) << "AFS Init Error";
T
Thunderbrook 已提交
110 111 112 113
  }
  use_afs_api_ = 1;
}
#endif
114
void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
Y
yaoxuefeng 已提交
115
  VLOG(3) << "PSGPUWrapper::BuildGPUPSTask begin";
T
Thunderbrook 已提交
116 117
  platform::Timer timeline;
  timeline.Start();
118
  int device_num = heter_devices_.size();
Y
yaoxuefeng 已提交
119
  gpu_task->init(thread_keys_shard_num_, device_num, multi_mf_dim_);
120

Y
yaoxuefeng 已提交
121
  std::vector<std::thread> threads;
Y
yaoxuefeng 已提交
122 123 124 125 126 127 128
  // data should be in input channel

  thread_dim_keys_.resize(thread_keys_thread_num_);
  for (int i = 0; i < thread_keys_thread_num_; i++) {
    thread_dim_keys_[i].resize(thread_keys_shard_num_);
    for (int j = 0; j < thread_keys_shard_num_; j++) {
      thread_dim_keys_[i][j].resize(multi_mf_dim_);
129
    }
Y
yaoxuefeng 已提交
130
  }
Y
yaoxuefeng 已提交
131 132 133 134

  size_t total_len = 0;
  size_t len_per_thread = 0;
  int remain = 0;
Y
yaoxuefeng 已提交
135
  size_t begin = 0;
Y
yaoxuefeng 已提交
136 137 138

  std::string data_set_name = std::string(typeid(*dataset_).name());

D
danleifeng 已提交
139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173
  VLOG(0) << "gpu_graph_mode_:" << gpu_graph_mode_;
  if (!gpu_graph_mode_) {
    if (data_set_name.find("SlotRecordDataset") != std::string::npos) {
      VLOG(0) << "ps_gpu_wrapper use SlotRecordDataset";
      SlotRecordDataset* dataset = (SlotRecordDataset*)(dataset_);
      auto input_channel = dataset->GetInputChannel();
      VLOG(0) << "psgpu wrapperinputslotchannle size: "
              << input_channel->Size();
      const std::deque<SlotRecord>& vec_data = input_channel->GetData();
      total_len = vec_data.size();
      len_per_thread = total_len / thread_keys_thread_num_;
      remain = total_len % thread_keys_thread_num_;
      VLOG(0) << "total len: " << total_len;
      auto gen_dynamic_mf_func = [this](
                                     const std::deque<SlotRecord>& total_data,
                                     int begin_index,
                                     int end_index,
                                     int i) {
        for (auto iter = total_data.begin() + begin_index;
             iter != total_data.begin() + end_index;
             iter++) {
          const auto& ins = *iter;
          const auto& feasign_v = ins->slot_uint64_feasigns_.slot_values;
          const auto& slot_offset = ins->slot_uint64_feasigns_.slot_offsets;
          for (size_t slot_idx = 0; slot_idx < slot_offset_vector_.size();
               slot_idx++) {
            for (size_t j = slot_offset[slot_offset_vector_[slot_idx]];
                 j < slot_offset[slot_offset_vector_[slot_idx] + 1];
                 j++) {
              int shard_id = feasign_v[j] % thread_keys_shard_num_;
              int dim_id = slot_index_vec_[slot_idx];
              if (feasign_v[j] != 0) {
                this->thread_dim_keys_[i][shard_id][dim_id].insert(
                    feasign_v[j]);
              }
Y
yaoxuefeng 已提交
174
            }
175 176
          }
        }
D
danleifeng 已提交
177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227
      };
      for (int i = 0; i < thread_keys_thread_num_; i++) {
        threads.push_back(
            std::thread(gen_dynamic_mf_func,
                        std::ref(vec_data),
                        begin,
                        begin + len_per_thread + (i < remain ? 1 : 0),
                        i));

        begin += len_per_thread + (i < remain ? 1 : 0);
      }
      for (std::thread& t : threads) {
        t.join();
      }
      timeline.Pause();
      VLOG(0) << "GpuPs build task cost " << timeline.ElapsedSec()
              << " seconds.";
    } else {
      CHECK(data_set_name.find("MultiSlotDataset") != std::string::npos);
      VLOG(0) << "ps_gpu_wrapper use MultiSlotDataset";
      MultiSlotDataset* dataset = (MultiSlotDataset*)(dataset_);
      auto input_channel = dataset->GetInputChannel();

      const std::deque<Record>& vec_data = input_channel->GetData();
      total_len = vec_data.size();
      len_per_thread = total_len / thread_keys_thread_num_;
      remain = total_len % thread_keys_thread_num_;
      auto gen_func = [this](const std::deque<Record>& total_data,
                             int begin_index,
                             int end_index,
                             int i) {
        for (auto iter = total_data.begin() + begin_index;
             iter != total_data.begin() + end_index;
             iter++) {
          const auto& ins = *iter;
          const auto& feasign_v = ins.uint64_feasigns_;
          for (const auto feasign : feasign_v) {
            uint64_t cur_key = feasign.sign().uint64_feasign_;
            int shard_id = cur_key % thread_keys_shard_num_;
            this->thread_keys_[i][shard_id].insert(cur_key);
          }
        }
      };
      for (int i = 0; i < thread_keys_thread_num_; i++) {
        threads.push_back(
            std::thread(gen_func,
                        std::ref(vec_data),
                        begin,
                        begin + len_per_thread + (i < remain ? 1 : 0),
                        i));
        begin += len_per_thread + (i < remain ? 1 : 0);
228
      }
D
danleifeng 已提交
229 230 231 232 233 234
      for (std::thread& t : threads) {
        t.join();
      }
      timeline.Pause();
      VLOG(0) << "GpuPs build task cost " << timeline.ElapsedSec()
              << " seconds.";
Y
yaoxuefeng 已提交
235 236
    }
  } else {
D
danleifeng 已提交
237 238 239
    VLOG(0) << "PreBuild in GpuGraph mode";
    SlotRecordDataset* dataset = (SlotRecordDataset*)(dataset_);
    const std::vector<uint64_t>& vec_data = dataset->GetGpuGraphTotalKeys();
Y
yaoxuefeng 已提交
240 241 242

    total_len = vec_data.size();
    len_per_thread = total_len / thread_keys_thread_num_;
D
danleifeng 已提交
243
    VLOG(0) << "GpuGraphTotalKeys: " << total_len;
Y
yaoxuefeng 已提交
244
    remain = total_len % thread_keys_thread_num_;
D
danleifeng 已提交
245 246 247 248
    auto gen_graph_data_func = [this](const std::vector<uint64_t>& total_data,
                                      int begin_index,
                                      int end_index,
                                      int i) {
Y
yaoxuefeng 已提交
249
      for (auto iter = total_data.begin() + begin_index;
250 251
           iter != total_data.begin() + end_index;
           iter++) {
D
danleifeng 已提交
252 253 254
        uint64_t cur_key = *iter;
        int shard_id = cur_key % thread_keys_shard_num_;
        this->thread_keys_[i][shard_id].insert(cur_key);
Y
yaoxuefeng 已提交
255 256
      }
    };
D
danleifeng 已提交
257 258 259 260 261 262 263 264 265 266 267 268 269 270
    auto gen_graph_dynamic_mf_func =
        [this](const std::vector<uint64_t>& total_data,
               int begin_index,
               int end_index,
               int i) {
          for (auto iter = total_data.begin() + begin_index;
               iter != total_data.begin() + end_index;
               iter++) {
            uint64_t cur_key = *iter;
            int shard_id = cur_key % thread_keys_shard_num_;
            // TODO: feasign <-> slot <-> multi_dim
            this->thread_dim_keys_[i][shard_id][0].insert(cur_key);
          }
        };
Y
yaoxuefeng 已提交
271
    for (int i = 0; i < thread_keys_thread_num_; i++) {
D
danleifeng 已提交
272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288
      if (!multi_mf_dim_) {
        VLOG(1) << "psgpu graph wrapper genfunc";
        threads.push_back(
            std::thread(gen_graph_data_func,
                        std::ref(vec_data),
                        begin,
                        begin + len_per_thread + (i < remain ? 1 : 0),
                        i));
      } else {
        VLOG(1) << "psgpu graph wrapper genfunc with dynamic mf";
        threads.push_back(
            std::thread(gen_graph_dynamic_mf_func,
                        std::ref(vec_data),
                        begin,
                        begin + len_per_thread + (i < remain ? 1 : 0),
                        i));
      }
Y
yaoxuefeng 已提交
289 290 291 292 293
      begin += len_per_thread + (i < remain ? 1 : 0);
    }
    for (std::thread& t : threads) {
      t.join();
    }
Y
yaoxuefeng 已提交
294 295 296 297
  }

  timeline.Start();

298
  threads.clear();
Y
yaoxuefeng 已提交
299
  // merge thread_keys to shard_keys
300 301
  auto merge_ins_dynamic_mf_func = [this, gpu_task](int shard_num, int dim_id) {
    for (int i = 0; i < thread_keys_thread_num_; ++i) {
302 303
      gpu_task->batch_add_keys(
          shard_num, dim_id, thread_dim_keys_[i][shard_num][dim_id]);
304 305 306
      thread_dim_keys_[i][shard_num][dim_id].clear();
    }
  };
307
  for (int i = 0; i < thread_keys_shard_num_; ++i) {
Y
yaoxuefeng 已提交
308 309
    for (int j = 0; j < multi_mf_dim_; j++) {
      threads.push_back(std::thread(merge_ins_dynamic_mf_func, i, j));
310
    }
311 312 313
  }
  for (auto& t : threads) {
    t.join();
Y
yaoxuefeng 已提交
314 315 316
  }
  timeline.Pause();

317
  VLOG(0) << "GpuPs task add keys cost " << timeline.ElapsedSec()
Y
yaoxuefeng 已提交
318 319 320 321 322
          << " seconds.";
  timeline.Start();
  gpu_task->UniqueKeys();
  timeline.Pause();

323
  VLOG(0) << "GpuPs task unique cost " << timeline.ElapsedSec() << " seconds.";
Y
yaoxuefeng 已提交
324 325
  for (int i = 0; i < thread_keys_shard_num_; i++) {
    for (int j = 0; j < multi_mf_dim_; j++) {
D
danleifeng 已提交
326 327 328
      if (i == 0 && j == multi_mf_dim_ - 1) {
        gpu_task->feature_dim_keys_[i][j].push_back(0);
      }
Y
yaoxuefeng 已提交
329 330 331 332
      VLOG(0) << "GpuPs shard: " << i << "mf dim: " << index_dim_vec_[j]
              << " key len: " << gpu_task->feature_dim_keys_[i][j].size();
      gpu_task->value_dim_ptr_[i][j].resize(
          gpu_task->feature_dim_keys_[i][j].size());
333
    }
Y
yaoxuefeng 已提交
334
  }
335 336 337 338
}

void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
  platform::Timer timeline;
T
Thunderbrook 已提交
339
  std::vector<std::future<void>> task_futures;
340 341 342 343
  int device_num = heter_devices_.size();
  auto& local_keys = gpu_task->feature_keys_;
  auto& local_ptr = gpu_task->value_ptr_;

344 345 346
  auto& local_dim_keys = gpu_task->feature_dim_keys_;
  auto& local_dim_ptr = gpu_task->value_dim_ptr_;

347 348
  auto& device_keys = gpu_task->device_keys_;
  auto& device_vals = gpu_task->device_values_;
349 350 351
  auto& device_dim_keys = gpu_task->device_dim_keys_;
  auto& device_dim_ptr = gpu_task->device_dim_ptr_;
  auto& device_dim_mutex = gpu_task->dim_mutex_;
Y
yaoxuefeng 已提交
352 353 354 355

  for (size_t dev = 0; dev < device_dim_keys.size(); dev++) {
    device_dim_keys[dev].resize(multi_mf_dim_);
    device_dim_ptr[dev].resize(multi_mf_dim_);
356
  }
Y
yaoxuefeng 已提交
357

T
Thunderbrook 已提交
358
  // auto& device_mutex = gpu_task->mutex_;
359 360 361 362 363 364

  std::vector<std::thread> threads(thread_keys_shard_num_);
#ifdef PADDLE_WITH_PSLIB
  auto fleet_ptr = FleetWrapper::GetInstance();
#endif
#ifdef PADDLE_WITH_PSCORE
365
  auto fleet_ptr = paddle::distributed::FleetWrapper::GetInstance();
366
#endif
367

368
#if (defined PADDLE_WITH_PSLIB) && (defined PADDLE_WITH_HETERPS)
369 370 371 372 373 374 375 376 377 378 379
  // get day_id: day nums from 1970
  struct std::tm b;
  b.tm_year = year_ - 1900;
  b.tm_mon = month_ - 1;
  b.tm_mday = day_;
  b.tm_min = b.tm_hour = b.tm_sec = 0;
  std::time_t seconds_from_1970 = std::mktime(&b);
  int day_id = seconds_from_1970 / 86400;
  fleet_ptr->pslib_ptr_->_worker_ptr->set_day_id(table_id_, day_id);
#endif

380
  timeline.Start();
381

382 383 384 385 386
  auto ptl_dynamic_mf_func =
      [this, &local_dim_keys, &local_dim_ptr, &fleet_ptr](int i, int j) {
        size_t key_size = local_dim_keys[i][j].size();
        int32_t status = -1;
        int32_t cnt = 0;
387
#ifdef PADDLE_WITH_PSLIB
388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414
        while (true) {
          auto tt = fleet_ptr->pslib_ptr_->_worker_ptr->pull_sparse_ptr(
              i,
              reinterpret_cast<char**>(local_dim_ptr[i][j].data()),
              this->table_id_,
              local_dim_keys[i][j].data(),
              key_size);
          bool flag = true;

          tt.wait();

          try {
            status = tt.get();
          } catch (const std::future_error& e) {
            VLOG(0) << "Caught a future_error with code" << e.code()
                    << ", Message:" << e.what();
          }
          if (status != 0) {
            VLOG(0) << "fleet pull sparse failed, status[" << status << "]";
            sleep(sleep_seconds_before_fail_exit_);
            flag = false;
            cnt++;
          }
          if (cnt > 3) {
            VLOG(0) << "fleet pull sparse failed, retry 3 times";
            exit(-1);
          }
415

416 417 418 419
          if (flag) {
            break;
          }
        }
420 421
#endif
#ifdef PADDLE_WITH_PSCORE
422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447
        while (true) {
          auto tt = fleet_ptr->worker_ptr_->PullSparsePtr(
              reinterpret_cast<char**>(local_dim_ptr[i][j].data()),
              this->table_id_,
              local_dim_keys[i][j].data(),
              key_size);
          bool flag = true;

          tt.wait();

          try {
            status = tt.get();
          } catch (const std::future_error& e) {
            VLOG(0) << "Caught a future_error with code" << e.code()
                    << ", Message:" << e.what();
          }
          if (status != 0) {
            VLOG(0) << "fleet pull sparse failed, status[" << status << "]";
            sleep(sleep_seconds_before_fail_exit_);
            flag = false;
            cnt++;
          }
          if (cnt > 3) {
            VLOG(0) << "fleet pull sparse failed, retry 3 times";
            exit(-1);
          }
448

449 450 451 452
          if (flag) {
            break;
          }
        }
453
#endif
454 455 456 457 458 459 460 461 462
        if (status != 0) {
          LOG(ERROR) << "fleet pull sparse failed, status[" << status << "]";
          sleep(300);
          exit(-1);
        } else {
          VLOG(0) << "FleetWrapper Pull sparse to local done with table size: "
                  << local_dim_keys[i][j].size();
        }
      };
Y
yaoxuefeng 已提交
463 464 465 466 467 468

  threads.resize(thread_keys_shard_num_ * multi_mf_dim_);
  for (int i = 0; i < thread_keys_shard_num_; i++) {
    for (int j = 0; j < multi_mf_dim_; j++) {
      task_futures.emplace_back(
          pull_thread_pool_[i]->enqueue(ptl_dynamic_mf_func, i, j));
469
    }
470
  }
Y
yaoxuefeng 已提交
471 472
  for (auto& f : task_futures) {
    f.wait();
473
  }
Y
yaoxuefeng 已提交
474
  task_futures.clear();
475
  timeline.Pause();
T
Thunderbrook 已提交
476
  VLOG(0) << "pull sparse from CpuPS into GpuPS cost " << timeline.ElapsedSec()
477
          << " seconds.";
Y
yaoxuefeng 已提交
478 479 480 481 482 483 484 485
  if (multi_node_) {
    auto gloo_wrapper = paddle::framework::GlooWrapper::GetInstance();
    if (!gloo_wrapper->IsInitialized()) {
      VLOG(0) << "GLOO is not inited";
      gloo_wrapper->Init();
    }
    gloo_wrapper->Barrier();
  }
486 487

  timeline.Start();
Y
yaoxuefeng 已提交
488 489 490
  std::vector<std::vector<std::pair<uint64_t, char*>>> pass_values;

  bool record_status = false;
T
Thunderbrook 已提交
491 492
  auto& device_task_keys = gpu_task->device_task_keys_;
  auto& device_task_ptrs = gpu_task->device_task_ptr_;
493 494 495 496 497
  auto build_pull_dynamic_mf_func = [this,
                                     device_num,
                                     &local_dim_keys,
                                     &local_dim_ptr,
                                     &device_dim_keys,
Y
yaoxuefeng 已提交
498 499
                                     &device_dim_ptr,
                                     &device_dim_mutex](int i, int j) {
500
    std::vector<std::vector<FeatureKey>> task_keys(device_num);
501
#ifdef PADDLE_WITH_PSLIB
502 503
    std::vector<std::vector<paddle::ps::DownpourFixedFeatureValue*>> task_ptrs(
        device_num);
504 505 506 507 508 509
#endif

#ifdef PADDLE_WITH_PSCORE
    std::vector<std::vector<paddle::distributed::FixedFeatureValue*>> task_ptrs(
        device_num);
#endif
510 511 512 513 514
    for (size_t k = 0; k < local_dim_keys[i][j].size(); k++) {
      int shard = local_dim_keys[i][j][k] % device_num;
      task_keys[shard].push_back(local_dim_keys[i][j][k]);
      task_ptrs[shard].push_back(local_dim_ptr[i][j][k]);
    }
Y
yaoxuefeng 已提交
515
    // allocate local keys to devices
516
    for (int dev = 0; dev < device_num; dev++) {
Y
yaoxuefeng 已提交
517 518 519 520 521 522 523 524
      device_dim_mutex[dev][j]->lock();
      int len = task_keys[dev].size();
      int cur = device_dim_keys[dev][j].size();
      device_dim_keys[dev][j].resize(device_dim_keys[dev][j].size() + len);
      device_dim_ptr[dev][j].resize(device_dim_ptr[dev][j].size() + len);
      for (int k = 0; k < len; ++k) {
        device_dim_keys[dev][j][cur + k] = task_keys[dev][k];
        device_dim_ptr[dev][j][cur + k] = task_ptrs[dev][k];
525
      }
Y
yaoxuefeng 已提交
526
      device_dim_mutex[dev][j]->unlock();
527 528
    }
  };
529 530 531 532 533 534 535
  auto build_func = [device_num,
                     record_status,
                     &pass_values,
                     &local_keys,
                     &local_ptr,
                     &device_task_keys,
                     &device_task_ptrs](int i) {
T
Thunderbrook 已提交
536
    auto& task_keys = device_task_keys[i];
T
Thunderbrook 已提交
537
#ifdef PADDLE_WITH_PSLIB
T
Thunderbrook 已提交
538
    auto& task_ptrs = device_task_ptrs[i];
T
Thunderbrook 已提交
539 540 541
#endif

#ifdef PADDLE_WITH_PSCORE
T
Thunderbrook 已提交
542
    auto& task_ptrs = device_task_ptrs[i];
T
Thunderbrook 已提交
543
#endif
544 545 546 547 548 549

    for (size_t j = 0; j < local_keys[i].size(); j++) {
      int shard = local_keys[i][j] % device_num;
      task_keys[shard].push_back(local_keys[i][j]);
      task_ptrs[shard].push_back(local_ptr[i][j]);
    }
550
#ifdef PADDLE_WITH_PSLIB
Y
yaoxuefeng 已提交
551 552 553 554 555 556 557 558 559 560 561 562 563 564 565
    if (record_status) {
      size_t local_keys_size = local_keys.size();
      size_t pass_values_size = pass_values.size();
      for (size_t j = 0; j < pass_values_size; j += local_keys_size) {
        auto& shard_values = pass_values[j];
        for (size_t pair_idx = 0; pair_idx < pass_values[j].size();
             pair_idx++) {
          auto& cur_pair = shard_values[pair_idx];
          int shard = cur_pair.first % device_num;
          task_keys[shard].push_back(cur_pair.first);
          task_ptrs[shard].push_back(
              (paddle::ps::DownpourFixedFeatureValue*)cur_pair.second);
        }
      }
    }
566
#endif
T
Thunderbrook 已提交
567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583
  };
  if (!multi_mf_dim_) {
    for (int i = 0; i < thread_keys_shard_num_; i++) {
      task_futures.emplace_back(hbm_thread_pool_[i]->enqueue(build_func, i));
    }
    for (auto& f : task_futures) {
      f.wait();
    }
    task_futures.clear();
    VLOG(0) << "GpuPs build hbmps done";
  }
  std::vector<std::vector<int>> prefix_sum;
  prefix_sum.resize(device_num);
  for (int i = 0; i < device_num; i++) {
    prefix_sum[i].resize(thread_keys_shard_num_ + 1);
    prefix_sum[i][0] = 0;
  }
584 585 586 587
  auto calc_prefix_func = [this,
                           &prefix_sum,
                           &device_keys,
                           &device_vals,
T
Thunderbrook 已提交
588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608
                           &device_task_keys](int device_num) {
    for (int j = 0; j < thread_keys_shard_num_; j++) {
      prefix_sum[device_num][j + 1] =
          prefix_sum[device_num][j] + device_task_keys[j][device_num].size();
    }
    device_keys[device_num].resize(
        prefix_sum[device_num][thread_keys_shard_num_]);
    device_vals[device_num].resize(
        prefix_sum[device_num][thread_keys_shard_num_]);
  };
  if (!multi_mf_dim_) {
    for (int i = 0; i < device_num; i++) {
      task_futures.emplace_back(
          hbm_thread_pool_[i]->enqueue(calc_prefix_func, i));
    }
    for (auto& f : task_futures) {
      f.wait();
    }
    task_futures.clear();
  }
  VLOG(0) << "prefix done";
609 610 611 612 613
  auto prepare_dev_value_func = [device_num,
                                 &prefix_sum,
                                 &device_keys,
                                 &device_vals,
                                 &device_task_keys,
T
Thunderbrook 已提交
614
                                 &device_task_ptrs](int dev, int shard_id) {
D
danleifeng 已提交
615
  // auto& task_keys = device_task_keys[shard_id];
T
Thunderbrook 已提交
616 617 618 619
#ifdef PADDLE_WITH_PSLIB
    auto& task_ptrs = device_task_ptrs[shard_id];
#endif

D
danleifeng 已提交
620 621 622
    // #ifdef PADDLE_WITH_PSCORE
    //     auto& task_ptrs = device_task_ptrs[shard_id];
    // #endif
623

D
danleifeng 已提交
624 625
    // int len = prefix_sum[dev][shard_id + 1] - prefix_sum[dev][shard_id];
    // int cur = prefix_sum[dev][shard_id];
T
Thunderbrook 已提交
626
#ifdef PADDLE_WITH_PSLIB
T
Thunderbrook 已提交
627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649
    for (int j = 0; j < len; ++j) {
      device_keys[dev][cur + j] = task_keys[dev][j];
      float* ptr_val = task_ptrs[dev][j]->data();
      FeatureValue& val = device_vals[dev][cur + j];
      size_t dim = task_ptrs[dev][j]->size();

      val.delta_score = ptr_val[1];
      val.show = ptr_val[2];
      val.clk = ptr_val[3];
      val.slot = ptr_val[6];
      val.lr = ptr_val[4];
      val.lr_g2sum = ptr_val[5];
      val.cpu_ptr = (uint64_t)(task_ptrs[dev][j]);

      if (dim > 7) {
        val.mf_size = MF_DIM + 1;
        for (int x = 0; x < val.mf_size; x++) {
          val.mf[x] = ptr_val[x + 7];
        }
      } else {
        val.mf_size = 0;
        for (int x = 0; x < MF_DIM + 1; x++) {
          val.mf[x] = 0;
Y
yaoxuefeng 已提交
650 651
        }
      }
T
Thunderbrook 已提交
652
    }
T
Thunderbrook 已提交
653
#endif
T
Thunderbrook 已提交
654
    VLOG(3) << "GpuPs build hbmps done";
Y
yaoxuefeng 已提交
655
  };
656

T
Thunderbrook 已提交
657
  if (multi_mf_dim_) {
658 659 660
    for (int i = 0; i < thread_keys_shard_num_; i++) {
      for (int j = 0; j < multi_mf_dim_; j++) {
        threads[i * multi_mf_dim_ + j] =
Y
yaoxuefeng 已提交
661
            std::thread(build_pull_dynamic_mf_func, i, j);
662 663
      }
    }
T
Thunderbrook 已提交
664 665 666 667 668 669 670 671 672 673 674 675 676 677
    for (std::thread& t : threads) {
      t.join();
    }
  } else {
    for (int i = 0; i < thread_keys_shard_num_; i++) {
      for (int j = 0; j < device_num; j++) {
        task_futures.emplace_back(
            hbm_thread_pool_[i]->enqueue(prepare_dev_value_func, j, i));
      }
    }
    for (auto& f : task_futures) {
      f.wait();
    }
    task_futures.clear();
Y
yaoxuefeng 已提交
678 679
  }
  timeline.Pause();
T
Thunderbrook 已提交
680
  VLOG(0) << "GpuPs prepare for build hbm cost " << timeline.ElapsedSec()
681
          << " seconds.";
Y
yaoxuefeng 已提交
682 683
}

684
void PSGPUWrapper::BuildGPUTask(std::shared_ptr<HeterContext> gpu_task) {
685
  int device_num = heter_devices_.size();
Y
yaoxuefeng 已提交
686 687
  platform::Timer timeline;
  timeline.Start();
T
Thunderbrook 已提交
688

689
  std::vector<size_t> feature_keys_count(device_num);
T
Thunderbrook 已提交
690
  size_t size_max = 0;
Y
yaoxuefeng 已提交
691 692 693 694 695 696 697

  for (int i = 0; i < device_num; i++) {
    for (int j = 0; j < multi_mf_dim_; j++) {
      feature_keys_count[i] += gpu_task->device_dim_ptr_[i][j].size();
      VLOG(1) << i << " card with dynamic mf dim: " << index_dim_vec_[j]
              << " dim index: " << j << " contains feasign nums: "
              << gpu_task->device_dim_ptr_[i][j].size();
698
    }
Y
yaoxuefeng 已提交
699 700 701
    VLOG(1) << i << " card with dynamic mf contains feasign nums total: "
            << feature_keys_count[i];
    size_max = std::max(size_max, feature_keys_count[i]);
T
Thunderbrook 已提交
702
  }
Y
yaoxuefeng 已提交
703

T
Thunderbrook 已提交
704
  if (HeterPs_) {
705 706
    delete HeterPs_;
    HeterPs_ = nullptr;
T
Thunderbrook 已提交
707
  }
708
  if (size_max <= 0) {
709
    VLOG(0) << "Skip build gpu ps cause feasign nums = " << size_max;
710 711
    return;
  }
712
  std::vector<std::thread> threads(device_num);
D
danleifeng 已提交
713
  auto accessor_wrapper_ptr =
D
danleifeng 已提交
714
      GlobalAccessorFactory::GetInstance().GetAccessorWrapper();
D
danleifeng 已提交
715 716
  HeterPs_ = HeterPsBase::get_instance(
      size_max, resource_, fleet_config_, accessor_class_, optimizer_type_);
F
Fan Zhang 已提交
717
#ifdef PADDLE_WITH_CUDA
718
  HeterPs_->set_nccl_comm_and_size(inner_comms_, inter_comms_, node_size_);
D
danleifeng 已提交
719 720
  HeterPs_->set_sparse_sgd(optimizer_config_);
  HeterPs_->set_embedx_sgd(optimizer_config_);
F
Fan Zhang 已提交
721
#endif
Z
zmxdream 已提交
722

D
danleifeng 已提交
723 724
  auto build_dymf_mem_pool = [this, &gpu_task, &accessor_wrapper_ptr](int i,
                                                                      int j) {
Y
yaoxuefeng 已提交
725 726
    this->HeterPs_->set_multi_mf_dim(multi_mf_dim_, max_mf_dim_);
    int mf_dim = this->index_dim_vec_[j];
D
danleifeng 已提交
727 728 729
    VLOG(0) << "building table: " << i << "with mf dim: " << mf_dim
            << " feature_value_size:"
            << accessor_wrapper_ptr->GetFeatureValueSize(mf_dim);
Y
yaoxuefeng 已提交
730
    size_t feature_value_size =
D
danleifeng 已提交
731
        accessor_wrapper_ptr->GetFeatureValueSize(mf_dim);
Y
yaoxuefeng 已提交
732 733 734 735 736 737
    auto& device_dim_keys = gpu_task->device_dim_keys_[i][j];
    auto& device_dim_ptrs = gpu_task->device_dim_ptr_[i][j];
    size_t len = device_dim_keys.size();
    CHECK(len == device_dim_ptrs.size());
    this->mem_pools_[i * this->multi_mf_dim_ + j] =
        new MemoryPool(len, feature_value_size);
Z
zmxdream 已提交
738
  };
D
danleifeng 已提交
739 740
  auto build_dymf_hbm_pool = [this, &gpu_task, &accessor_wrapper_ptr](int i,
                                                                      int j) {
Z
zmxdream 已提交
741 742 743 744
    auto& device_dim_keys = gpu_task->device_dim_keys_[i][j];
    size_t len = device_dim_keys.size();
    int mf_dim = this->index_dim_vec_[j];
    size_t feature_value_size =
D
danleifeng 已提交
745
        accessor_wrapper_ptr->GetFeatureValueSize(mf_dim);
Z
zmxdream 已提交
746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767

    auto& mem_pool = this->mem_pools_[i * this->multi_mf_dim_ + j];
    platform::CUDADeviceGuard guard(resource_->dev_id(i));
    this->hbm_pools_[i * this->multi_mf_dim_ + j] = new HBMMemoryPool(mem_pool);
    auto& cur_pool = this->hbm_pools_[i * this->multi_mf_dim_ + j];

    this->HeterPs_->build_ps(i,
                             device_dim_keys.data(),
                             cur_pool->mem(),
                             len,
                             feature_value_size,
                             500000,
                             2);
    if (device_dim_keys.size() > 0) {
      VLOG(3) << "show table: " << i
              << " table kv size: " << device_dim_keys.size()
              << "dim: " << mf_dim << " len: " << len;
      HeterPs_->show_one_table(i);
    }
    delete mem_pool;
  };
  int thread_num = 16;
D
danleifeng 已提交
768 769 770 771
  auto build_dynamic_mf_func = [this,
                                &gpu_task,
                                thread_num,
                                &accessor_wrapper_ptr](int i, int j, int z) {
Z
zmxdream 已提交
772 773 774 775 776 777 778 779 780
    // this->HeterPs_->set_multi_mf_dim(multi_mf_dim_, max_mf_dim_);
    int mf_dim = this->index_dim_vec_[j];
    VLOG(0) << "building table: " << i << "with mf dim: " << mf_dim;
    auto& device_dim_keys = gpu_task->device_dim_keys_[i][j];
    auto& device_dim_ptrs = gpu_task->device_dim_ptr_[i][j];
    size_t len = device_dim_keys.size();
    CHECK(len == device_dim_ptrs.size());
    // this->mem_pools_[i * this->multi_mf_dim_ + j] =
    //    new MemoryPool(len, feature_value_size);
Y
yaoxuefeng 已提交
781
    auto& mem_pool = this->mem_pools_[i * this->multi_mf_dim_ + j];
Z
zmxdream 已提交
782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800

    // ============ add for multi-thread ================
    size_t len_per_thread = len / thread_num;
    size_t remain = len % thread_num;
    size_t left = 0, right = 0;

    size_t real_len = len_per_thread;
    if ((size_t)z < remain) real_len++;

    if ((size_t)z < remain) {
      left = z * (len_per_thread + 1);
      right = left + real_len;
    } else {
      left = remain * (len_per_thread + 1) + (z - remain) * len_per_thread;
      right = left + real_len;
    }
    // ============ add for multi-thread ================

    for (size_t k = left; k < right; k++) {
D
danleifeng 已提交
801 802
#ifdef PADDLE_WITH_PSLIB
      float* val = (float*)(mem_pool->mem_address(k));
Y
yaoxuefeng 已提交
803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818
      float* ptr_val = device_dim_ptrs[k]->data();
      size_t dim = device_dim_ptrs[k]->size();
      val->delta_score =
          ptr_val[paddle::ps::DownpourCtrDymfAccessor::
                      DownpourCtrDymfFeatureValue::delta_score_index()];
      val->show = ptr_val[paddle::ps::DownpourCtrDymfAccessor::
                              DownpourCtrDymfFeatureValue::show_index()];
      val->clk = ptr_val[paddle::ps::DownpourCtrDymfAccessor::
                             DownpourCtrDymfFeatureValue::click_index()];
      val->slot = int(ptr_val[paddle::ps::DownpourCtrDymfAccessor::
                                  DownpourCtrDymfFeatureValue::slot_index()]);
      val->lr = ptr_val[paddle::ps::DownpourCtrDymfAccessor::
                            DownpourCtrDymfFeatureValue::embed_w_index()];
      val->lr_g2sum =
          ptr_val[paddle::ps::DownpourCtrDymfAccessor::
                      DownpourCtrDymfFeatureValue::embed_g2sum_index()];
Y
yaoxuefeng 已提交
819
      // TODO(xuefeng) set mf_dim while using DownpourCtrDymfAccessor
Y
yaoxuefeng 已提交
820 821 822 823 824 825 826 827 828 829 830 831 832 833
      ptr_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue::
                  mf_dim_index()] = float(mf_dim);
      val->mf_dim = mf_dim;
      if (dim > 8) {  // CpuPS alreay expand as mf_dim
        val->mf_size = mf_dim + 1;
        for (int x = 0; x < val->mf_dim + 1; x++) {
          val->mf[x] = ptr_val[x + 8];
        }
      } else {
        val->mf_size = 0;
        for (int x = 0; x < val->mf_dim + 1; x++) {
          val->mf[x] = 0;
        }
      }
D
danleifeng 已提交
834 835 836 837 838 839
#endif
#ifdef PADDLE_WITH_PSCORE
      void* val = mem_pool->mem_address(k);
      accessor_wrapper_ptr->BuildFill(
          val, device_dim_ptrs[k], cpu_table_accessor_, mf_dim);
#endif
Y
yaoxuefeng 已提交
840
    }
Z
zmxdream 已提交
841
  };
Y
yaoxuefeng 已提交
842

Z
zmxdream 已提交
843 844 845 846 847 848
  threads.resize(device_num * multi_mf_dim_);
  for (int i = 0; i < device_num; i++) {
    for (int j = 0; j < multi_mf_dim_; j++) {
      threads[i + j * device_num] = std::thread(build_dymf_mem_pool, i, j);
    }
  }
Y
yaoxuefeng 已提交
849

Z
zmxdream 已提交
850 851 852 853
  for (std::thread& t : threads) {
    t.join();
  }
  threads.clear();
Y
yaoxuefeng 已提交
854

Z
zmxdream 已提交
855 856 857 858 859 860 861 862
  // multi-thread process
  threads.resize(device_num * multi_mf_dim_ * thread_num);
  for (int i = 0; i < device_num; i++) {
    for (int j = 0; j < multi_mf_dim_; j++) {
      for (int k = 0; k < thread_num; k++) {
        threads[(i + j * device_num) * thread_num + k] =
            std::thread(build_dynamic_mf_func, i, j, k);
      }
Y
yaoxuefeng 已提交
863
    }
Z
zmxdream 已提交
864 865 866 867 868
  }
  for (std::thread& t : threads) {
    t.join();
  }
  threads.clear();
Y
yaoxuefeng 已提交
869 870 871
  threads.resize(device_num * multi_mf_dim_);
  for (int i = 0; i < device_num; i++) {
    for (int j = 0; j < multi_mf_dim_; j++) {
Z
zmxdream 已提交
872
      threads[i + j * device_num] = std::thread(build_dymf_hbm_pool, i, j);
Y
yaoxuefeng 已提交
873
    }
Y
yaoxuefeng 已提交
874 875 876
  }
  for (std::thread& t : threads) {
    t.join();
T
Thunderbrook 已提交
877
  }
Z
zmxdream 已提交
878 879
  threads.clear();

T
Thunderbrook 已提交
880
  timeline.Pause();
881
  VLOG(0) << "GpuPs build table total costs: " << timeline.ElapsedSec()
T
Thunderbrook 已提交
882
          << " s.";
883 884 885 886 887 888 889 890 891 892 893 894 895 896
}

void PSGPUWrapper::LoadIntoMemory(bool is_shuffle) {
  platform::Timer timer;
  VLOG(3) << "Begin LoadIntoMemory(), dataset[" << dataset_ << "]";
  timer.Start();
  dataset_->LoadIntoMemory();
  timer.Pause();
  VLOG(0) << "LoadIntoMemory cost: " << timer.ElapsedSec() << "s";

  // local shuffle
  if (is_shuffle) {
    dataset_->LocalShuffle();
  }
Y
yaoxuefeng 已提交
897
  InitSlotInfo();
D
danleifeng 已提交
898
  gpu_graph_mode_ = dataset_->GetGpuGraphMode();
899 900
  std::shared_ptr<HeterContext> gpu_task = gpu_task_pool_.Get();
  gpu_task->Reset();
Y
yaoxuefeng 已提交
901

902
  data_ready_channel_->Put(gpu_task);
Y
yaoxuefeng 已提交
903

904 905 906 907 908
  VLOG(3) << "End LoadIntoMemory(), dataset[" << dataset_ << "]";
}

void PSGPUWrapper::start_build_thread() {
  running_ = true;
909
  VLOG(3) << "start build CPU ps thread.";
910
  pre_build_threads_ = std::thread([this] { pre_build_thread(); });
911 912
}

913 914
void PSGPUWrapper::pre_build_thread() {
  // prebuild: process load_data
915 916 917 918 919
  while (running_) {
    std::shared_ptr<HeterContext> gpu_task = nullptr;
    if (!data_ready_channel_->Get(gpu_task)) {
      continue;
    }
920
    VLOG(3) << "thread PreBuildTask start.";
921 922 923
    platform::Timer timer;
    timer.Start();
    // build cpu ps data process
924
    PreBuildTask(gpu_task);
925
    timer.Pause();
926
    VLOG(0) << "thread PreBuildTask end, cost time: " << timer.ElapsedSec()
T
Thunderbrook 已提交
927
            << " s";
928 929 930 931 932
    buildcpu_ready_channel_->Put(gpu_task);
  }
  VLOG(3) << "build cpu thread end";
}

933 934 935 936 937 938 939 940 941 942
void PSGPUWrapper::build_task() {
  // build_task: build_pull + build_gputask
  std::shared_ptr<HeterContext> gpu_task = nullptr;
  // train end, gpu free
  if (!gpu_free_channel_->Get(gpu_task)) {
    return;
  }
  // ins and pre_build end
  if (!buildcpu_ready_channel_->Get(gpu_task)) {
    return;
943
  }
944

945
  VLOG(0) << "BuildPull start.";
946 947 948 949 950
  platform::Timer timer;
  timer.Start();
  BuildPull(gpu_task);
  BuildGPUTask(gpu_task);
  timer.Pause();
951
  VLOG(0) << "BuildPull + BuildGPUTask end, cost time: " << timer.ElapsedSec()
952 953 954
          << "s";

  current_task_ = gpu_task;
955 956 957 958 959 960 961 962 963
}

void PSGPUWrapper::BeginPass() {
  platform::Timer timer;
  timer.Start();
  if (current_task_) {
    PADDLE_THROW(
        platform::errors::Fatal("[BeginPass] current task is not ended."));
  }
964

D
danleifeng 已提交
965
  debug_gpu_memory_info("befor build task");
966
  build_task();
D
danleifeng 已提交
967
  debug_gpu_memory_info("after build task");
968
  timer.Pause();
969 970 971 972 973

  if (current_task_ == nullptr) {
    PADDLE_THROW(platform::errors::Fatal(
        "[BeginPass] after build_task, current task is not null."));
  }
D
danleifeng 已提交
974 975 976 977 978 979 980
  if (FLAGS_gpugraph_dedup_pull_push_mode) {
    VLOG(0) << "BeginPass end, cost time: " << timer.ElapsedSec()
            << "s, enable pull push dedup mode="
            << FLAGS_gpugraph_dedup_pull_push_mode;
  } else {
    VLOG(0) << "BeginPass end, cost time: " << timer.ElapsedSec() << "s";
  }
981 982 983 984 985 986 987 988 989 990 991
}

void PSGPUWrapper::EndPass() {
  if (!current_task_) {
    PADDLE_THROW(
        platform::errors::Fatal("[EndPass] current task has been ended."));
  }
  platform::Timer timer;
  timer.Start();
  size_t keysize_max = 0;
  // in case of feasign_num = 0, skip dump_to_cpu
Y
yaoxuefeng 已提交
992

993
  for (size_t i = 0; i < heter_devices_.size(); i++) {
Y
yaoxuefeng 已提交
994 995 996 997 998
    for (int j = 0; j < multi_mf_dim_; j++) {
      keysize_max =
          std::max(keysize_max, current_task_->device_dim_keys_[i][j].size());
    }
  }
999
  int thread_num = 8;
D
danleifeng 已提交
1000
  auto accessor_wrapper_ptr =
D
danleifeng 已提交
1001
      GlobalAccessorFactory::GetInstance().GetAccessorWrapper();
D
danleifeng 已提交
1002 1003
  auto dump_pool_to_cpu_func = [this, thread_num, &accessor_wrapper_ptr](
                                   int i, int j, int z) {
Y
yaoxuefeng 已提交
1004 1005 1006 1007
    PADDLE_ENFORCE_GPU_SUCCESS(cudaSetDevice(this->resource_->dev_id(i)));
    auto& hbm_pool = this->hbm_pools_[i * this->multi_mf_dim_ + j];
    auto& device_keys = this->current_task_->device_dim_keys_[i][j];
    size_t len = device_keys.size();
1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021
    // ====== multi-thread process feasign================
    int len_per_thread = len / thread_num;
    int remain = len % thread_num;
    int left = -1, right = -1;
    int real_len = len_per_thread;
    if (z < remain) real_len++;
    if (z < remain) {
      left = z * (len_per_thread + 1);
      right = left + real_len;
    } else {
      left = remain * (len_per_thread + 1) + (z - remain) * len_per_thread;
      right = left + real_len;
    }
    // ============ multi-thread process feasign============
Y
yaoxuefeng 已提交
1022 1023
    int mf_dim = this->index_dim_vec_[j];
    size_t feature_value_size =
D
danleifeng 已提交
1024 1025 1026 1027
        accessor_wrapper_ptr->GetFeatureValueSize(mf_dim);
    VLOG(0) << "dump pool to cpu table: " << i << "with mf dim: " << mf_dim
            << " key_len :" << len
            << " feature_value_size:" << feature_value_size;
1028 1029
    char* test_build_values = (char*)malloc(feature_value_size * real_len);
    uint64_t offset = left * feature_value_size;
1030 1031 1032 1033
    cudaMemcpy(test_build_values,
               hbm_pool->mem() + offset,
               feature_value_size * real_len,
               cudaMemcpyDeviceToHost);
Y
yaoxuefeng 已提交
1034 1035
    CHECK(len == hbm_pool->capacity());
    uint64_t unuse_key = std::numeric_limits<uint64_t>::max();
1036
    for (int i = left; i < right; ++i) {
Y
yaoxuefeng 已提交
1037 1038 1039
      if (device_keys[i] == unuse_key) {
        continue;
      }
1040
      size_t local_offset = (i - left) * feature_value_size;
D
danleifeng 已提交
1041
      float* gpu_val = (float*)(test_build_values + local_offset);
1042
#ifdef PADDLE_WITH_PSLIB
D
danleifeng 已提交
1043
      // TODO: PSLIB DumpFill
D
danleifeng 已提交
1044 1045 1046 1047
#endif
#ifdef PADDLE_WITH_PSCORE
      accessor_wrapper_ptr->DumpFill(gpu_val, cpu_table_accessor_, mf_dim);
#endif
Y
yaoxuefeng 已提交
1048 1049 1050 1051 1052 1053
    }
    free(test_build_values);
  };
  if (multi_mf_dim_) {
    VLOG(0) << "psgpu wrapper dump pool: multi_mf_dim_: " << multi_mf_dim_;
    size_t device_num = heter_devices_.size();
1054
    std::vector<std::thread> threads(device_num * multi_mf_dim_ * thread_num);
Y
yaoxuefeng 已提交
1055 1056
    for (size_t i = 0; i < device_num; i++) {
      for (int j = 0; j < multi_mf_dim_; j++) {
1057 1058 1059 1060
        for (int k = 0; k < thread_num; k++) {
          threads[(i + j * device_num) * thread_num + k] =
              std::thread(dump_pool_to_cpu_func, i, j, k);
        }
Y
yaoxuefeng 已提交
1061 1062 1063 1064 1065
      }
    }
    for (std::thread& t : threads) {
      t.join();
    }
1066 1067 1068 1069
  }
  if (keysize_max != 0) {
    HeterPs_->end_pass();
  }
1070

Y
yaoxuefeng 已提交
1071 1072 1073
  for (size_t i = 0; i < hbm_pools_.size(); i++) {
    delete hbm_pools_[i];
  }
1074
  gpu_task_pool_.Push(current_task_);
1075 1076 1077
  current_task_ = nullptr;
  gpu_free_channel_->Put(current_task_);
  timer.Pause();
Y
yaoxuefeng 已提交
1078
  VLOG(1) << "EndPass end, cost time: " << timer.ElapsedSec() << "s";
T
Thunderbrook 已提交
1079 1080 1081 1082 1083 1084 1085 1086
}

void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
                              const int table_id,
                              const std::vector<const uint64_t*>& keys,
                              const std::vector<float*>& values,
                              const std::vector<int64_t>& slot_lengths,
                              const int hidden_size) {
  // Legacy overload without per-slot dims. It performs no pull: every
  // argument is ignored and it only logs a warning that points callers to
  // the pull_gpups_sparse op.
  static constexpr const char* kUnusedWarning =
      "Warning:: recommand use pull_gpups_sparse op instead. "
      "This PullSparse is not used.";
  VLOG(0) << kUnusedWarning;
}

void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
                              const int table_id,
                              const std::vector<const uint64_t*>& keys,
                              const std::vector<float*>& values,
                              const std::vector<int64_t>& slot_lengths,
                              const std::vector<int>& slot_dim,
                              const int hidden_size) {
  // Pulls sparse feature values for a batch of per-slot key arrays from the
  // GPU/XPU parameter server and copies them into the per-slot outputs.
  //
  // keys / values are expected to be indexed by slot, with slot_lengths[s]
  // entries in slot s (prefix sums of slot_lengths are built below).
  // slot_dim is copied to the device and passed to CopyForPull. table_id is
  // not read here.
  //
  // On CUDA with FLAGS_gpugraph_dedup_pull_push_mode > 0, keys are
  // deduplicated and only unique keys are pulled. The intermediate buffers
  // are kept in device_caches_ so PushSparseGrad can reuse them.
  // Throws Unimplemented for CPUPlace and PreconditionNotMet for other
  // unsupported places.
  VLOG(3) << "Begine Gpu Ps PullSparse";
  platform::Timer all_timer;
  platform::Timer pull_gpups_timer;
  all_timer.Start();

  auto accessor_wrapper_ptr =
      GlobalAccessorFactory::GetInstance().GetAccessorWrapper();
  size_t feature_value_size =
      accessor_wrapper_ptr->GetPullValueSize(max_mf_dim_);
  VLOG(3) << "PullSparse max_dim:" << max_mf_dim_
          << " pull_feature_value_size:" << pull_type_size_;

  if (platform::is_cpu_place(place)) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Warning:: CPUPlace is not supported in GpuPs now."));
  } else if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
    int device_id = place.GetDeviceId();
    int devid_2_index = HeterPs_->get_index_by_devid(device_id);
    if (FLAGS_gpugraph_dedup_pull_push_mode > 0) {
      auto& dev = device_caches_[devid_2_index];
      int slot_num = static_cast<int>(slot_lengths.size());
      // Slot offsets: slot_lengths_lod[s] is the first key index of slot s.
      std::vector<int64_t> slot_lengths_lod;
      slot_lengths_lod.reserve(slot_num + 1);
      slot_lengths_lod.push_back(0);

      int64_t total_length = 0;
      for (int i = 0; i < slot_num; ++i) {
        total_length += slot_lengths[i];
        slot_lengths_lod.push_back(total_length);
      }
      dev.total_key_length = total_length;
      VLOG(3) << "[" << device_id << "]Begin copy keys, key_num["
              << total_length << "] dedup mode";

      auto stream = dynamic_cast<platform::CUDADeviceContext*>(
                        platform::DeviceContextPool::Instance().Get(place))
                        ->stream();

      // Layout: [keys | merged keys | sorted keys], total_length each.
      uint64_t* total_keys = dev.keys_tensor.mutable_data<uint64_t>(
          (total_length * 3) * sizeof(uint64_t), place);

      int* gpu_slot_dims = dev.dims_tensor.mutable_data<int>(
          slot_dim.size() * sizeof(int), place);
      uint64_t** gpu_keys = dev.keys_ptr_tensor.mutable_data<uint64_t*>(
          keys.size() * sizeof(uint64_t*), place);

      int64_t* slot_lens = dev.slot_lens.mutable_data<int64_t>(
          (slot_num + 1) * sizeof(int64_t), place);
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemcpyAsync(gpu_keys,
                          keys.data(),
                          keys.size() * sizeof(uint64_t*),
                          cudaMemcpyHostToDevice,
                          stream));
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemcpyAsync(slot_lens,
                          slot_lengths_lod.data(),
                          slot_lengths_lod.size() * sizeof(int64_t),
                          cudaMemcpyHostToDevice,
                          stream));

      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemcpyAsync(gpu_slot_dims,
                          slot_dim.data(),
                          slot_dim.size() * sizeof(int),
                          cudaMemcpyHostToDevice,
                          stream));
      float** gpu_values = dev.values_ptr_tensor.mutable_data<float*>(
          values.size() * sizeof(float*), place);
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemcpyAsync(gpu_values,
                          values.data(),
                          values.size() * sizeof(float*),
                          cudaMemcpyHostToDevice,
                          stream));

      // Layout (total_length ints/uint32 each):
      // [key2slot | restore_idx | sorted_idx | offset | merged_cnts].
      int* key2slot = dev.keys2slot.mutable_data<int>(
          (total_length * 5) * sizeof(int), place);

      this->CopyKeys(place,
                     gpu_keys,
                     total_keys,
                     slot_lens,
                     slot_num,
                     static_cast<int>(total_length),
                     key2slot);

      uint32_t* d_restore_idx =
          reinterpret_cast<uint32_t*>(&key2slot[total_length]);
      uint32_t* d_sorted_idx =
          reinterpret_cast<uint32_t*>(&d_restore_idx[total_length]);
      uint32_t* d_offset =
          reinterpret_cast<uint32_t*>(&d_sorted_idx[total_length]);
      uint32_t* d_merged_cnts =
          reinterpret_cast<uint32_t*>(&d_offset[total_length]);
      uint64_t* d_merged_keys =
          reinterpret_cast<uint64_t*>(&total_keys[total_length]);
      uint64_t* d_sorted_keys =
          reinterpret_cast<uint64_t*>(&d_merged_keys[total_length]);

      int dedup_size = HeterPs_->dedup_keys_and_fillidx(
          devid_2_index,
          static_cast<int>(total_length),
          total_keys,     // input
          d_merged_keys,  // output
          d_sorted_keys,  // sort keys
          d_restore_idx,  // pull fill idx
          d_sorted_idx,   // sort old idx
          d_offset,       // offset
          d_merged_cnts,
          FLAGS_gpugraph_dedup_pull_push_mode & 0x02);

      PADDLE_ENFORCE_GT(dedup_size,
                        0,
                        platform::errors::PreconditionNotMet(
                            "dedup keys need more than zero failed in BoxPS."));
      dev.dedup_key_length = dedup_size;

      int64_t total_bytes = dedup_size * feature_value_size;
      float* total_values_gpu =
          dev.pull_push_tensor.mutable_data<float>(total_bytes, place);
      pull_gpups_timer.Start();
      HeterPs_->pull_sparse(
          devid_2_index, d_merged_keys, total_values_gpu, dedup_size);

      // values.size() not sure equal slot_num
      accessor_wrapper_ptr->CopyForPull(place,
                                        total_keys,
                                        gpu_values,
                                        total_values_gpu,
                                        slot_lens,
                                        key2slot,
                                        max_mf_dim_ + 3,
                                        total_length,
                                        gpu_slot_dims,
                                        d_restore_idx,
                                        feature_value_size);
    } else {
      size_t total_length =
          std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
      auto buf = memory::Alloc(place, total_length * feature_value_size);
      float* total_values_gpu = reinterpret_cast<float*>(buf->ptr());
      VLOG(3) << "Begin copy keys, key_num[" << total_length << "]";
      LoDTensor& total_keys_tensor = keys_tensor[devid_2_index];
      uint64_t* total_keys =
          reinterpret_cast<uint64_t*>(total_keys_tensor.mutable_data<int64_t>(
              {int64_t(total_length), 1}, place));
      // construct slot_level lod info
      auto slot_lengths_lod = slot_lengths;
      for (size_t i = 1; i < slot_lengths_lod.size(); i++) {
        slot_lengths_lod[i] += slot_lengths_lod[i - 1];
      }
      auto buf_key = memory::Alloc(place, keys.size() * sizeof(uint64_t*));
      auto buf_length =
          memory::Alloc(place, slot_lengths.size() * sizeof(int64_t));
      uint64_t** gpu_keys = reinterpret_cast<uint64_t**>(buf_key->ptr());
      int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
      PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpy(gpu_keys,
                                            keys.data(),
                                            keys.size() * sizeof(uint64_t*),
                                            cudaMemcpyHostToDevice));
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemcpy(gpu_len,
                     slot_lengths_lod.data(),
                     slot_lengths.size() * sizeof(int64_t),
                     cudaMemcpyHostToDevice));

      auto buf_dim = memory::Alloc(place, slot_dim.size() * sizeof(int));
      int* gpu_dim = reinterpret_cast<int*>(buf_dim->ptr());
      PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpy(gpu_dim,
                                            slot_dim.data(),
                                            slot_dim.size() * sizeof(int),
                                            cudaMemcpyHostToDevice));

      this->CopyKeys(place,
                     gpu_keys,
                     total_keys,
                     gpu_len,
                     static_cast<int>(slot_lengths.size()),
                     static_cast<int>(total_length));
      VLOG(3) << "Begin call PullSparseGPU in GPUPS, dev: " << devid_2_index
              << " len: " << total_length;

      pull_gpups_timer.Start();
      HeterPs_->pull_sparse(
          devid_2_index, total_keys, total_values_gpu, total_length);

      VLOG(3) << "Begin Copy result to tensor, total_length[" << total_length
              << "]";

      accessor_wrapper_ptr->CopyForPull(place,
                                        gpu_keys,
                                        values,
                                        total_values_gpu,
                                        gpu_len,
                                        static_cast<int>(slot_lengths.size()),
                                        hidden_size,
                                        total_length,
                                        gpu_dim,
                                        feature_value_size);
    }
    pull_gpups_timer.Pause();
#endif
  } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU_KP
    VLOG(3) << "Begine Xpu Ps PullSparse";
    size_t total_length =
        std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
    // Scope-owned device buffer, released when this branch ends. That is the
    // same lifetime as buf_key/buf_length below, which CopyForPull also
    // reads. A bare xpu_malloc here would need a matching xpu_free.
    auto buf_value = memory::Alloc(place, total_length * feature_value_size);
    FeatureValue* total_values_gpu =
        reinterpret_cast<FeatureValue*>(buf_value->ptr());
    VLOG(3) << "Begin copy keys, key_num[" << total_length << "]";
    int device_id = place.GetDeviceId();
    int devid_2_index = HeterPs_->get_index_by_devid(device_id);
    LoDTensor& total_keys_tensor = keys_tensor[devid_2_index];
    uint64_t* total_keys =
        reinterpret_cast<uint64_t*>(total_keys_tensor.mutable_data<int64_t>(
            {int64_t(total_length), 1}, place));

    // construct slot_level lod info
    auto slot_lengths_lod = slot_lengths;
    for (size_t i = 1; i < slot_lengths_lod.size(); i++) {
      slot_lengths_lod[i] += slot_lengths_lod[i - 1];
    }

    auto buf_key = memory::Alloc(place, keys.size() * sizeof(uint64_t*));
    auto buf_length =
        memory::Alloc(place, slot_lengths.size() * sizeof(int64_t));
    uint64_t** xpu_keys = reinterpret_cast<uint64_t**>(buf_key->ptr());
    int64_t* xpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
    PADDLE_ENFORCE_XPU_SUCCESS(xpu_memcpy(xpu_keys,
                                          keys.data(),
                                          keys.size() * sizeof(uint64_t*),
                                          XPU_HOST_TO_DEVICE));
    PADDLE_ENFORCE_XPU_SUCCESS(xpu_memcpy(xpu_len,
                                          slot_lengths_lod.data(),
                                          slot_lengths.size() * sizeof(int64_t),
                                          XPU_HOST_TO_DEVICE));

    this->CopyKeys(place,
                   xpu_keys,
                   total_keys,
                   xpu_len,
                   static_cast<int>(slot_lengths.size()),
                   static_cast<int>(total_length));
    VLOG(3) << "Begin call PullSparseGPU in GPUPS, dev: " << devid_2_index
            << " len: " << total_length;
    pull_gpups_timer.Start();
    HeterPs_->pull_sparse(devid_2_index,
                          total_keys,
                          total_values_gpu,
                          static_cast<int>(total_length));
    pull_gpups_timer.Pause();

    VLOG(3) << "Begin Copy result to tensor, total_length[" << total_length
            << "]";
    accessor_wrapper_ptr->CopyForPull(place,
                                      xpu_keys,
                                      values,
                                      total_values_gpu,
                                      xpu_len,
                                      static_cast<int>(slot_lengths.size()),
                                      hidden_size,
                                      total_length,
                                      feature_value_size);
#endif
  } else {
    PADDLE_THROW(platform::errors::PreconditionNotMet(
        "GpuPs/XpuPs: PullSparse Only Support CUDAPlace or XPUPlace Now."));
  }
  all_timer.Pause();
  VLOG(3) << "GpuPs PullSparse total costs: " << all_timer.ElapsedSec()
          << " s, of which GPUPS costs: " << pull_gpups_timer.ElapsedSec()
          << " s";
  VLOG(3) << "End PullSparse";
}

void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place,
                                  const int table_id,
                                  const std::vector<const uint64_t*>& keys,
                                  const std::vector<const float*>& grad_values,
                                  const std::vector<int64_t>& slot_lengths,
1380 1381
                                  const int hidden_size,
                                  const int batch_size) {
T
Thunderbrook 已提交
1382 1383 1384
  platform::Timer all_timer;
  platform::Timer push_gpups_timer;
  all_timer.Start();
D
danleifeng 已提交
1385
  auto accessor_wrapper_ptr =
D
danleifeng 已提交
1386
      GlobalAccessorFactory::GetInstance().GetAccessorWrapper();
D
danleifeng 已提交
1387
  size_t grad_value_size = accessor_wrapper_ptr->GetPushValueSize(max_mf_dim_);
D
danleifeng 已提交
1388

T
Thunderbrook 已提交
1389 1390 1391 1392
  if (platform::is_cpu_place(place)) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Warning:: CPUPlace is not supported in GPUPS now."));
  } else if (platform::is_gpu_place(place)) {
F
Fan Zhang 已提交
1393
#ifdef PADDLE_WITH_CUDA
1394
    int device_id = place.GetDeviceId();
T
Thunderbrook 已提交
1395
    int devid_2_index = HeterPs_->get_index_by_devid(device_id);
D
danleifeng 已提交
1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416
    if (FLAGS_gpugraph_dedup_pull_push_mode > 0) {
      auto& dev = device_caches_[devid_2_index];
      int64_t total_length = dev.total_key_length;
      VLOG(3) << "Begin push sparse, key_num[" << total_length
              << "] dedup mode, device:" << device_id << ", index"
              << devid_2_index;
      auto stream = dynamic_cast<platform::CUDADeviceContext*>(
                        platform::DeviceContextPool::Instance().Get(place))
                        ->stream();
      uint64_t* total_keys = dev.keys_tensor.data<uint64_t>();
      int* slot_dims = dev.dims_tensor.data<int>();
      int slot_num = static_cast<int>(slot_lengths.size());
      if (!dev.d_slot_vector.IsInitialized()) {
        int* buf_slot_vector =
            dev.d_slot_vector.mutable_data<int>(slot_num * sizeof(int), place);
        cudaMemcpyAsync(buf_slot_vector,
                        slot_vector_.data(),
                        slot_num * sizeof(int),
                        cudaMemcpyHostToDevice,
                        stream);
      }
T
Thunderbrook 已提交
1417

D
danleifeng 已提交
1418 1419 1420 1421 1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514
      const int64_t* slot_lens = dev.slot_lens.data<int64_t>();
      const int* d_slot_vector = dev.d_slot_vector.data<int>();
      const int* key2slot = dev.keys2slot.data<int>();
      float** gpu_values = dev.values_ptr_tensor.data<float*>();
      cudaMemcpyAsync(gpu_values,
                      grad_values.data(),
                      grad_values.size() * sizeof(float*),
                      cudaMemcpyHostToDevice,
                      stream);

      uint64_t* d_merged_keys = &total_keys[total_length];

      int64_t dedup_size = dev.dedup_key_length;
      int64_t total_bytes = dedup_size * grad_value_size;
      float* total_grad_values_gpu =
          dev.pull_push_tensor.mutable_data<float>(total_bytes, place);
      // dedup rate more than 3
      if (total_length > dedup_size * 3) {
        const uint32_t* d_restore_idx =
            reinterpret_cast<const uint32_t*>(&key2slot[total_length]);
        accessor_wrapper_ptr->CopyForPush(place,
                                          total_keys,
                                          gpu_values,
                                          total_grad_values_gpu,
                                          d_slot_vector,
                                          slot_lens,
                                          max_mf_dim_ + 3,
                                          total_length,
                                          dedup_size,
                                          batch_size,
                                          slot_dims,
                                          key2slot,
                                          d_restore_idx,
                                          grad_value_size);
      } else {
        const uint32_t* d_sorted_idx =
            reinterpret_cast<const uint32_t*>(&key2slot[total_length * 2]);
        const uint32_t* d_offset =
            reinterpret_cast<const uint32_t*>(&d_sorted_idx[total_length]);
        const uint32_t* d_merged_cnts =
            reinterpret_cast<const uint32_t*>(&d_offset[total_length]);
        accessor_wrapper_ptr->CopyForPush(place,
                                          d_merged_keys,
                                          gpu_values,
                                          total_grad_values_gpu,
                                          d_slot_vector,
                                          slot_lens,
                                          max_mf_dim_ + 3,
                                          total_length,
                                          dedup_size,
                                          batch_size,
                                          slot_dims,
                                          key2slot,
                                          d_sorted_idx,
                                          d_offset,
                                          d_merged_cnts,
                                          grad_value_size);
      }

      push_gpups_timer.Start();
      HeterPs_->push_sparse(devid_2_index,
                            d_merged_keys,
                            total_grad_values_gpu,
                            static_cast<int>(dedup_size));
    } else {
      int64_t total_length =
          std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
      VLOG(3) << "Begin GPUPS PushSparseGrad";

      auto buf = memory::Alloc(place, total_length * grad_value_size);
      VLOG(3) << "Push Sparse Max mf dimention: " << max_mf_dim_
              << "grad_value_size:" << grad_value_size;
      float* total_grad_values_gpu = reinterpret_cast<float*>(buf->ptr());

      LoDTensor& total_keys_tensor = keys_tensor[devid_2_index];
      uint64_t* total_keys =
          reinterpret_cast<uint64_t*>(total_keys_tensor.data<int64_t>());
      VLOG(3) << "Begin copy grad tensor to gpups struct";

      accessor_wrapper_ptr->CopyForPush(place,
                                        grad_values,
                                        total_grad_values_gpu,
                                        slot_lengths,
                                        total_length,
                                        batch_size,
                                        grad_value_size,
                                        slot_vector_,
                                        slot_mf_dim_vector_);

      VLOG(3) << "Begin call PushSparseGPU in GPUPS, dev: " << devid_2_index
              << " len: " << total_length;
      push_gpups_timer.Start();
      HeterPs_->push_sparse(devid_2_index,
                            total_keys,
                            total_grad_values_gpu,
                            static_cast<int>(total_length));
    }
T
Thunderbrook 已提交
1515
    push_gpups_timer.Pause();
F
Fan Zhang 已提交
1516
#endif
F
Fan Zhang 已提交
1517
  } else if (platform::is_xpu_place(place)) {
F
Fan Zhang 已提交
1518
#ifdef PADDLE_WITH_XPU_KP
F
Fan Zhang 已提交
1519 1520
    int device_id = place.GetDeviceId();
    int devid_2_index = HeterPs_->get_index_by_devid(device_id);
D
danleifeng 已提交
1521 1522 1523 1524 1525 1526 1527 1528 1529
    int64_t total_length =
        std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
    VLOG(3) << "Begin GPUPS PushSparseGrad";

    auto buf = memory::Alloc(place, total_length * grad_value_size);
    VLOG(3) << "Push Sparse Max mf dimention: " << max_mf_dim_
            << "grad_value_size:" << grad_value_size;
    float* total_grad_values_gpu = reinterpret_cast<float*>(buf->ptr());
    LoDTensor& total_keys_tensor = keys_tensor[devid_2_index];
F
Fan Zhang 已提交
1530
    uint64_t* total_keys =
D
danleifeng 已提交
1531
        reinterpret_cast<uint64_t*>(total_keys_tensor.data<int64_t>());
F
Fan Zhang 已提交
1532
    VLOG(3) << "Begin copy grad tensor to xpups struct";
D
danleifeng 已提交
1533 1534 1535 1536 1537 1538 1539 1540
    accessor_wrapper_ptr->CopyForPush(place,
                                      grad_values,
                                      total_grad_values_gpu,
                                      slot_lengths,
                                      hidden_size,
                                      total_length,
                                      batch_size,
                                      slot_vector_);
F
Fan Zhang 已提交
1541 1542 1543 1544

    VLOG(3) << "Begin call PushSparseXPU in XPUPS, dev: " << devid_2_index
            << " len: " << total_length;
    push_gpups_timer.Start();
1545 1546 1547
    HeterPs_->push_sparse(devid_2_index,
                          total_keys,
                          total_grad_values_gpu,
F
Fan Zhang 已提交
1548 1549
                          static_cast<int>(total_length));
    push_gpups_timer.Pause();
F
Fan Zhang 已提交
1550
#endif
T
Thunderbrook 已提交
1551 1552 1553 1554 1555
  } else {
    PADDLE_THROW(platform::errors::PreconditionNotMet(
        "GPUPS: PushSparseGrad Only Support CUDAPlace Now."));
  }
  all_timer.Pause();
Y
yaoxuefeng 已提交
1556 1557
  time_3 += all_timer.ElapsedSec();
  time_4 += push_gpups_timer.ElapsedSec();
1558
  VLOG(3) << "PushSparseGrad total cost: " << all_timer.ElapsedSec()
T
Thunderbrook 已提交
1559 1560 1561 1562 1563
          << " s, of which GPUPS cost: " << push_gpups_timer.ElapsedSec()
          << " s";
  VLOG(3) << "End PushSparseGrad";
}

D
danleifeng 已提交
1564
}  // namespace framework
T
Thunderbrook 已提交
1565 1566
}  // end namespace paddle
#endif