/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

T
Thunderbrook 已提交
15
#ifdef PADDLE_WITH_HETERPS
16
#include <thread>
17

18 19
#include "paddle/fluid/framework/fleet/heter_ps/hashtable.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
T
Thunderbrook 已提交
20 21 22 23

namespace paddle {
namespace framework {

24 25
#if defined(PADDLE_WITH_CUDA)

T
Thunderbrook 已提交
26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
// Binary functor passed to the table's insert(): it always yields its first
// argument (`new_value`) and ignores the second (`old_value`).
template <typename value_type>
struct ReplaceOp {
  __host__ __device__ value_type operator()(value_type new_value,
                                            value_type old_value) {
    (void)old_value;  // intentionally unused
    return new_value;
  }
};

// Inserts (keys[i], vals[i]) into `table`, one pair per thread.
//
// Launch layout: 1-D grid of 1-D blocks; threads whose flat index is >= len
// do nothing. ReplaceOp is used as the merge functor for insert().
template <typename Table>
__global__ void insert_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              const typename Table::mapped_type* const vals,
                              size_t len) {
  using KeyT = typename Table::key_type;
  using ValT = typename Table::mapped_type;

  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= len) {
    return;
  }

  ReplaceOp<ValT> replace;
  thrust::pair<KeyT, ValT> entry;
  entry.first = keys[idx];
  entry.second = vals[idx];

  auto pos = table->insert(entry, replace);
  assert(pos != table->end() && "error: insert fails: table is full");
}

51 52 53
// Inserts keys[i] into `table`, mapping each key to an address inside `pool`
// instead of to a value read from an array.
//
// The stored value for key i is
//   pool + (start_index + i) * feature_value_size
// cast to Table::mapped_type.
//
// Launch layout: 1-D grid of 1-D blocks; threads whose flat index is >= len
// do nothing.
//
// `start_index` is size_t to match the host-side HashTable::insert() wrapper,
// which passes a size_t; taking it as int would narrow large start indices.
template <typename Table>
__global__ void insert_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              size_t len, char* pool, size_t feature_value_size,
                              size_t start_index) {
  ReplaceOp<typename Table::mapped_type> op;
  thrust::pair<typename Table::key_type, typename Table::mapped_type> kv;

  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;

  if (i < len) {
    kv.first = keys[i];
    const uint64_t offset = uint64_t(start_index + i) * feature_value_size;
    // `typename` is required: mapped_type is a dependent type name.
    kv.second = (typename Table::mapped_type)(pool + offset);
    auto it = table->insert(kv, op);
    assert(it != table->end() && "error: insert fails: table is full");
  }
}

T
Thunderbrook 已提交
70 71 72 73 74 75 76 77 78 79 80 81 82 83
// Looks up keys[i] in `table` and, when present, copies the mapped value to
// vals[i]. vals[i] is left untouched for keys that are not found.
//
// Launch layout: 1-D grid of 1-D blocks; threads whose flat index is >= len
// do nothing.
template <typename Table>
__global__ void search_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              typename Table::mapped_type* const vals,
                              size_t len) {
  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= len) {
    return;
  }

  auto pos = table->find(keys[idx]);
  if (pos == table->end()) {
    return;
  }
  vals[idx] = pos->second;
}

84 85 86
// Looks up keys[i] in `table` and writes a FeatureValue record for it into the
// byte buffer `vals` at offset i * pull_feature_value_size.
//
// On a hit, the scalar fields and the first (mf_dim + 1) entries of `mf` are
// copied from the FeatureValue the table entry points to. On a miss, the
// record is filled with defaults (slot = -1, mf_dim = 8, everything else 0),
// and a warning is printed when the key is non-zero.
//
// Launch layout: 1-D grid of 1-D blocks; threads whose flat index is >= len
// do nothing.
template <typename Table>
__global__ void dy_mf_search_kernel(Table* table,
                                    const typename Table::key_type* const keys,
                                    char* vals, size_t len,
                                    size_t pull_feature_value_size) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= len) {
    return;
  }

  // Output slot for this key; identical for the hit and the miss path.
  FeatureValue* cur = (FeatureValue*)(vals + i * pull_feature_value_size);

  auto it = table->find(keys[i]);
  if (it != table->end()) {
    FeatureValue& input = *(FeatureValue*)(it->second);
    cur->slot = input.slot;
    cur->show = input.show;
    cur->clk = input.clk;
    cur->mf_dim = input.mf_dim;
    cur->lr = input.lr;
    cur->mf_size = input.mf_size;
    cur->cpu_ptr = input.cpu_ptr;
    cur->delta_score = input.delta_score;
    cur->lr_g2sum = input.lr_g2sum;
    for (int j = 0; j < cur->mf_dim + 1; ++j) {
      cur->mf[j] = input.mf[j];
    }
  } else {
    if (keys[i] != 0) {
      // Keys are 64-bit in every instantiation in this file; "%d" would
      // print a truncated/garbage value.
      printf("warning::pull miss key: %llu", (unsigned long long)keys[i]);
    }
    cur->delta_score = 0;
    cur->show = 0;
    cur->clk = 0;
    cur->slot = -1;
    cur->lr = 0;
    cur->lr_g2sum = 0;
    cur->mf_size = 0;
    cur->mf_dim = 8;
    // Previously a bare `cur->cpu_ptr;` expression (no effect), which left
    // the field holding whatever was in the output buffer.
    cur->cpu_ptr = 0;
    for (int j = 0; j < cur->mf_dim + 1; j++) {
      cur->mf[j] = 0;
    }
  }
}
130

T
Thunderbrook 已提交
131 132
// Applies one gradient per key: for every keys[i] found in `table`, calls
// sgd.update_value(optimizer_config, <mapped value>, grads[i]). Keys that are
// not in the table are silently skipped.
//
// Launch layout: 1-D grid of 1-D blocks; threads whose flat index is >= len
// do nothing.
//
// NOTE(review): `optimizer_config` is taken by reference; the host wrapper in
// this file passes `*device_optimizer_config_` — confirm this is the intended
// way to hand the device-resident config to the kernel.
template <typename Table, typename GradType, typename Sgd>
__global__ void update_kernel(Table* table,
                              const OptimizerConfig& optimizer_config,
                              const typename Table::key_type* const keys,
                              const GradType* const grads, size_t len,
                              Sgd sgd) {
  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= len) {
    return;
  }

  auto pos = table->find(keys[idx]);
  if (pos == table->end()) {
    return;
  }
  sgd.update_value(optimizer_config, (pos.getter())->second, grads[idx]);
}

146 147
// Applies one variable-size gradient record per key. The gradient for key i is
// the FeaturePushValue located at byte offset i * grad_value_size in `grads`.
// For keys found in `table`, calls
//   sgd.dy_mf_update_value(optimizer_config, <mapped value>, <gradient>).
// For non-zero keys that are not found, a warning is printed.
//
// Launch layout: 1-D grid of 1-D blocks; threads whose flat index is >= len
// do nothing.
template <typename Table, typename Sgd>
__global__ void dy_mf_update_kernel(Table* table,
                                    const OptimizerConfig& optimizer_config,
                                    const typename Table::key_type* const keys,
                                    const char* const grads, size_t len,
                                    Sgd sgd, size_t grad_value_size) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      FeaturePushValue* cur = (FeaturePushValue*)(grads + i * grad_value_size);
      sgd.dy_mf_update_value(optimizer_config, (it.getter())->second, *cur);
    } else {
      if (keys[i] != 0) {
        // Keys are 64-bit in every instantiation in this file; "%d" would
        // print a truncated/garbage value.
        printf("warning::push miss key: %llu", (unsigned long long)keys[i]);
      }
    }
  }
}

T
Thunderbrook 已提交
166 167 168
// Builds a table with room for `capacity` entries, allocates the device-side
// copy of the optimizer config and uploads the current host config into it,
// then creates the reader/writer lock.
//
// NOTE(review): the cudaMalloc/cudaMemcpy return codes are not checked.
template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::HashTable(size_t capacity) {
  const size_t config_bytes = sizeof(OptimizerConfig);

  container_ = new TableContainer<KeyType, ValType>(capacity);

  cudaMalloc((void**)&device_optimizer_config_, config_bytes);
  cudaMemcpy((void*)device_optimizer_config_, &host_optimizer_config_,
             config_bytes, cudaMemcpyHostToDevice);

  rwlock_.reset(new phi::RWLock);
}

// Releases the table container and the device-side optimizer config that the
// constructor allocated with cudaMalloc (previously leaked).
template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::~HashTable() {
  delete container_;
  cudaFree(device_optimizer_config_);
}

Z
zmxdream 已提交
180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195
// Updates the sparse-SGD part of the host optimizer config from
// `optimizer_config`, then re-uploads the whole config to the device copy
// with a blocking cudaMemcpy.
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::set_sparse_sgd(
    const OptimizerConfig& optimizer_config) {
  host_optimizer_config_.set_sparse_sgd(optimizer_config);

  const size_t config_bytes = sizeof(OptimizerConfig);
  cudaMemcpy((void*)device_optimizer_config_, &host_optimizer_config_,
             config_bytes, cudaMemcpyHostToDevice);
}

// Updates the embedx-SGD part of the host optimizer config from
// `optimizer_config`, then re-uploads the whole config to the device copy
// with a blocking cudaMemcpy.
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::set_embedx_sgd(
    const OptimizerConfig& optimizer_config) {
  host_optimizer_config_.set_embedx_sgd(optimizer_config);

  const size_t config_bytes = sizeof(OptimizerConfig);
  cudaMemcpy((void*)device_optimizer_config_, &host_optimizer_config_,
             config_bytes, cudaMemcpyHostToDevice);
}

T
Thunderbrook 已提交
196 197 198 199 200 201
// Debug helper: forwards to the underlying container's print().
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::show() {
  this->container_->print();
}

// Looks up `len` device-resident keys and writes the values that are found to
// `d_vals` (see search_kernel). The kernel is enqueued on `stream` with one
// thread per key; this function does not wait for it to finish.
// A zero `len` is a no-op.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys, ValType* d_vals,
                                      size_t len, StreamType stream) {
  if (len != 0) {
    // Ceil-divide len by the block size.
    const int num_blocks = (len - 1) / BLOCK_SIZE_ + 1;
    search_kernel<<<num_blocks, BLOCK_SIZE_, 0, stream>>>(container_, d_keys,
                                                          d_vals, len);
  }
}

213
// Byte-buffer variant of get(): for each of the `len` device-resident keys,
// dy_mf_search_kernel writes one record at stride pull_feature_value_size_
// into `d_vals`. The kernel is enqueued on `stream` with one thread per key;
// this function does not wait for it to finish. A zero `len` is a no-op.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys, char* d_vals,
                                      size_t len, StreamType stream) {
  if (len != 0) {
    // Ceil-divide len by the block size.
    const int num_blocks = (len - 1) / BLOCK_SIZE_ + 1;
    dy_mf_search_kernel<<<num_blocks, BLOCK_SIZE_, 0, stream>>>(
        container_, d_keys, d_vals, len, pull_feature_value_size_);
  }
}

T
Thunderbrook 已提交
225
// Inserts `len` (key, value) pairs read from the device arrays `d_keys` and
// `d_vals`. The kernel is enqueued on `stream` with one thread per pair; this
// function does not wait for it to finish. A zero `len` is a no-op.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
                                         const ValType* d_vals, size_t len,
                                         StreamType stream) {
  if (len != 0) {
    // Ceil-divide len by the block size.
    const int num_blocks = (len - 1) / BLOCK_SIZE_ + 1;
    insert_kernel<<<num_blocks, BLOCK_SIZE_, 0, stream>>>(container_, d_keys,
                                                          d_vals, len);
  }
}

238
// Pool variant of insert(): maps each of the `len` keys to an address inside
// `pool`, computed by the kernel from `start_index`, the key's position and
// `feature_value_size`. The kernel is enqueued on `stream` with one thread
// per key; this function does not wait for it to finish.
// Nothing is launched when `len` is zero or `pool` is NULL.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys, size_t len,
                                         char* pool, size_t feature_value_size,
                                         size_t start_index,
                                         StreamType stream) {
  if (len == 0 || pool == NULL) {
    return;
  }
  // Ceil-divide len by the block size.
  const int num_blocks = (len - 1) / BLOCK_SIZE_ + 1;
  insert_kernel<<<num_blocks, BLOCK_SIZE_, 0, stream>>>(
      container_, d_keys, len, pool, feature_value_size, start_index);
}

T
Thunderbrook 已提交
255
// Copies every occupied slot of the table back into the CPU-side feature
// value that the slot's `cpu_ptr` refers to.
//
// Steps:
//   1. Ask the container to prefetch its storage to the CPU on `stream`.
//   2. Split the container's size() slots into 8 contiguous ranges and scan
//      them with 8 host threads. Slots whose key equals
//      numeric_limits<KeyType>::max() are skipped.
//   3. For each remaining slot, write the scalar fields and `mf_size` mf
//      entries into the CPU value; the field layout depends on whether
//      PADDLE_WITH_PSLIB or PADDLE_WITH_PSCORE is defined. A CPU value of
//      size 7 is first grown by mf_size when mf_size > 0.
//
// Slot indices are size_t: size() is stored as size_t, and int indices would
// truncate/overflow on tables with more than INT_MAX slots.
//
// `devid` is currently unused (the prefetch back to the device is disabled).
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::dump_to_cpu(int devid, StreamType stream) {
  container_->prefetch(cudaCpuDeviceId, stream);

  const size_t num = container_->size();
  KeyType unuse_key = std::numeric_limits<KeyType>::max();
  thrust::pair<KeyType, ValType>* kv = container_->data();

  const size_t thread_num = 8;
  const size_t len_per_thread = num / thread_num;
  const size_t remain = num % thread_num;

  // Dumps the half-open slot range [left, right).
  auto dump_func = [unuse_key, kv](size_t left, size_t right) {
    for (size_t i = left; i < right; i++) {
      if (kv[i].first == unuse_key) {
        continue;
      }
      ValType& gpu_val = kv[i].second;
#ifdef PADDLE_WITH_PSLIB
      auto* downpour_value =
          (paddle::ps::DownpourFixedFeatureValue*)(gpu_val.cpu_ptr);
      int downpour_value_size = downpour_value->size();
      if (gpu_val.mf_size > 0 && downpour_value_size == 7) {
        downpour_value->resize(gpu_val.mf_size + downpour_value_size);
      }
      float* cpu_val = downpour_value->data();
      // cpu_val[0] = 0;
      cpu_val[1] = gpu_val.delta_score;
      cpu_val[2] = gpu_val.show;
      cpu_val[3] = gpu_val.clk;
      cpu_val[4] = gpu_val.lr;
      cpu_val[5] = gpu_val.lr_g2sum;
      cpu_val[6] = gpu_val.slot;
      if (gpu_val.mf_size > 0) {
        for (int x = 0; x < gpu_val.mf_size; x++) {
          cpu_val[x + 7] = gpu_val.mf[x];
        }
      }
#endif
#ifdef PADDLE_WITH_PSCORE
      auto* downpour_value =
          (paddle::distributed::FixedFeatureValue*)(gpu_val.cpu_ptr);
      int downpour_value_size = downpour_value->size();
      if (gpu_val.mf_size > 0 && downpour_value_size == 7) {
        downpour_value->resize(gpu_val.mf_size + downpour_value_size);
      }
      float* cpu_val = downpour_value->data();
      // cpu_val[0] = 0;
      cpu_val[2] = gpu_val.delta_score;
      cpu_val[3] = gpu_val.show;
      cpu_val[4] = gpu_val.clk;
      cpu_val[5] = gpu_val.lr;
      cpu_val[6] = gpu_val.lr_g2sum;
      cpu_val[0] = gpu_val.slot;
      if (gpu_val.mf_size > 0) {
        for (int x = 0; x < gpu_val.mf_size; x++) {
          cpu_val[x + 7] = gpu_val.mf[x];
        }
      }
#endif
    }
  };

  // The first `remain` threads take one extra slot so that all `num` slots
  // are covered.
  std::vector<std::thread> threads;
  threads.reserve(thread_num);
  size_t begin = 0;
  for (size_t i = 0; i < thread_num; i++) {
    const size_t end = begin + len_per_thread + (i < remain ? 1 : 0);
    threads.emplace_back(dump_func, begin, end);
    begin = end;
  }
  for (std::thread& t : threads) {
    t.join();
  }

  // container_->prefetch(devid, stream);
}

T
Thunderbrook 已提交
332
// Applies `len` gradients from the device array `d_grads` to the entries for
// `d_keys` via update_kernel, using the device-side optimizer config and the
// given `sgd` functor. The kernel is enqueued on `stream` with one thread per
// key; this function does not wait for it to finish. A zero `len` is a no-op.
template <typename KeyType, typename ValType>
template <typename GradType, typename Sgd, typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                         const GradType* d_grads, size_t len,
                                         Sgd sgd, StreamType stream) {
  if (len != 0) {
    // Ceil-divide len by the block size.
    const int num_blocks = (len - 1) / BLOCK_SIZE_ + 1;
    update_kernel<<<num_blocks, BLOCK_SIZE_, 0, stream>>>(
        container_, *device_optimizer_config_, d_keys, d_grads, len, sgd);
  }
}

345
// Byte-buffer variant of update(): gradients are records laid out at stride
// push_grad_value_size_ in `d_grads` and are applied by dy_mf_update_kernel.
// The kernel is enqueued on `stream` with one thread per key; this function
// does not wait for it to finish. A zero `len` is a no-op.
template <typename KeyType, typename ValType>
template <typename Sgd, typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                         const char* d_grads, size_t len,
                                         Sgd sgd, StreamType stream) {
  if (len != 0) {
    // Ceil-divide len by the block size.
    const int num_blocks = (len - 1) / BLOCK_SIZE_ + 1;
    dy_mf_update_kernel<<<num_blocks, BLOCK_SIZE_, 0, stream>>>(
        container_, *device_optimizer_config_, d_keys, d_grads, len, sgd,
        push_grad_value_size_);
  }
}

359
// Explicit instantiations of the HashTable class template.

// Feature-value tables (by value and by pointer).
template class HashTable<unsigned long, FeatureValue>;
template class HashTable<unsigned long, FeatureValue*>;

// Tables keyed by unsigned long.
template class HashTable<unsigned long, int>;
template class HashTable<unsigned long, unsigned long>;
template class HashTable<unsigned long, long>;
template class HashTable<unsigned long, long*>;

// Tables keyed by long.
template class HashTable<long, int>;
template class HashTable<long, long>;
template class HashTable<long, unsigned long>;
template class HashTable<long, unsigned int>;

370 371 372 373
// Explicit instantiations of HashTable::get for cudaStream_t.

// Feature-value tables.
template void HashTable<unsigned long, FeatureValue>::get<cudaStream_t>(
    const unsigned long* d_keys, FeatureValue* d_vals, size_t len,
    cudaStream_t stream);

template void HashTable<unsigned long, FeatureValue*>::get<cudaStream_t>(
    const unsigned long* d_keys, char* d_vals, size_t len, cudaStream_t stream);

// Tables keyed by unsigned long.
template void HashTable<unsigned long, int>::get<cudaStream_t>(
    const unsigned long* d_keys, int* d_vals, size_t len, cudaStream_t stream);

template void HashTable<unsigned long, long>::get<cudaStream_t>(
    const unsigned long* d_keys, long* d_vals, size_t len, cudaStream_t stream);

// Tables keyed by long.
template void HashTable<long, int>::get<cudaStream_t>(
    const long* d_keys, int* d_vals, size_t len, cudaStream_t stream);

template void HashTable<long, long>::get<cudaStream_t>(
    const long* d_keys, long* d_vals, size_t len, cudaStream_t stream);

template void HashTable<long, unsigned long>::get<cudaStream_t>(
    const long* d_keys, unsigned long* d_vals, size_t len, cudaStream_t stream);

template void HashTable<long, unsigned int>::get<cudaStream_t>(
    const long* d_keys, unsigned int* d_vals, size_t len, cudaStream_t stream);

// Disabled instantiation, kept for reference:
// template void
// HashTable<unsigned long, paddle::framework::FeatureValue>::get<cudaStream_t>(
//    const unsigned long* d_keys, char* d_vals, size_t len, cudaStream_t
//    stream);

399 400 401 402
// Explicit instantiations of HashTable::insert for cudaStream_t.

// Feature-value tables: by-value insert and pool-backed insert.
template void HashTable<unsigned long, FeatureValue>::insert<cudaStream_t>(
    const unsigned long* d_keys, const FeatureValue* d_vals, size_t len,
    cudaStream_t stream);

template void HashTable<unsigned long, FeatureValue*>::insert<cudaStream_t>(
    const unsigned long* d_keys, size_t len, char* pool,
    size_t feature_value_size, size_t start_index, cudaStream_t stream);

// Tables keyed by unsigned long.
template void HashTable<unsigned long, int>::insert<cudaStream_t>(
    const unsigned long* d_keys, const int* d_vals, size_t len,
    cudaStream_t stream);

template void HashTable<unsigned long, long>::insert<cudaStream_t>(
    const unsigned long* d_keys, const long* d_vals, size_t len,
    cudaStream_t stream);

// Tables keyed by long.
template void HashTable<long, int>::insert<cudaStream_t>(
    const long* d_keys, const int* d_vals, size_t len, cudaStream_t stream);

template void HashTable<long, long>::insert<cudaStream_t>(
    const long* d_keys, const long* d_vals, size_t len, cudaStream_t stream);

template void HashTable<long, unsigned long>::insert<cudaStream_t>(
    const long* d_keys, const unsigned long* d_vals, size_t len,
    cudaStream_t stream);

template void HashTable<long, unsigned int>::insert<cudaStream_t>(
    const long* d_keys, const unsigned int* d_vals, size_t len,
    cudaStream_t stream);
432 433 434 435 436 437 438 439 440 441

// Explicit instantiations of HashTable::dump_to_cpu and HashTable::update for
// cudaStream_t.

template void HashTable<unsigned long, FeatureValue>::dump_to_cpu<cudaStream_t>(
    int devid, cudaStream_t stream);

// Typed-gradient update on the by-value feature table.
template void HashTable<unsigned long, FeatureValue>::update<
    FeaturePushValue, Optimizer<FeatureValue, FeaturePushValue>, cudaStream_t>(
    const unsigned long* d_keys, const FeaturePushValue* d_grads, size_t len,
    Optimizer<FeatureValue, FeaturePushValue> sgd, cudaStream_t stream);

// Byte-buffer-gradient update on the pointer feature table.
template void HashTable<unsigned long, FeatureValue*>::update<
    Optimizer<FeatureValue, FeaturePushValue>, cudaStream_t>(
    const unsigned long* d_keys, const char* d_grads, size_t len,
    Optimizer<FeatureValue, FeaturePushValue> sgd, cudaStream_t stream);

// Disabled instantiation, kept for reference:
// template void HashTable<unsigned long,
// paddle::framework::FeatureValue>::update<
//    Optimizer<paddle::framework::FeatureValue,
//              paddle::framework::FeaturePushValue>,
//    cudaStream_t>(const unsigned long* d_keys, const char* d_grads, size_t
//    len,
//                  Optimizer<paddle::framework::FeatureValue,
//                            paddle::framework::FeaturePushValue>
//                      sgd,
//                  cudaStream_t stream);

#endif
T
Thunderbrook 已提交
470 471 472
}  // end namespace framework
}  // end namespace paddle
#endif