hashtable_kernel.cu 19.1 KB
Newer Older
1
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

T
Thunderbrook 已提交
15
#ifdef PADDLE_WITH_HETERPS
16
#include <thread>
17

18 19
#include "paddle/fluid/framework/fleet/heter_ps/hashtable.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
T
Thunderbrook 已提交
20 21 22 23

namespace paddle {
namespace framework {

24 25
#if defined(PADDLE_WITH_CUDA)

T
Thunderbrook 已提交
26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
// Binary combine functor handed to table->insert() by the insert kernels:
// it always yields its first argument (the incoming value) and ignores the
// second one (the value previously stored, presumably, for an existing key).
template <typename value_type>
struct ReplaceOp {
  __host__ __device__ value_type operator()(value_type incoming,
                                            value_type /* existing */) {
    return incoming;
  }
};

// Inserts `len` (keys[i], vals[i]) pairs into `table`, one pair per thread
// of a 1-D launch. Uses ReplaceOp as the combine op; a failed insert (the
// returned iterator equals table->end()) trips a device-side assert.
template <typename Table>
__global__ void insert_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              const typename Table::mapped_type* const vals,
                              size_t len) {
  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= len) {
    return;
  }

  ReplaceOp<typename Table::mapped_type> replace_op;
  thrust::pair<typename Table::key_type, typename Table::mapped_type> entry;
  entry.first = keys[idx];
  entry.second = vals[idx];

  auto pos = table->insert(entry, replace_op);
  assert(pos != table->end() && "error: insert fails: table is full");
}

51 52 53
// Inserts `len` keys whose values are pointers into a pre-allocated byte
// pool: keys[i] is mapped to pool + (start_index + i) * feature_value_size.
// One key per thread of a 1-D launch; uses ReplaceOp as the combine op and
// trips a device-side assert when an insert fails.
//
// start_index is size_t (it used to be int): the host wrapper passes a
// size_t, and narrowing it to int truncated offsets past INT_MAX.
template <typename Table>
__global__ void insert_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              size_t len,
                              char* pool,
                              size_t feature_value_size,
                              size_t start_index) {
  ReplaceOp<typename Table::mapped_type> op;
  thrust::pair<typename Table::key_type, typename Table::mapped_type> kv;

  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;

  if (i < len) {
    kv.first = keys[i];
    const uint64_t offset = uint64_t(start_index + i) * feature_value_size;
    // `typename` is required: Table::mapped_type is a dependent type name.
    kv.second = (typename Table::mapped_type)(pool + offset);
    auto it = table->insert(kv, op);
    assert(it != table->end() && "error: insert fails: table is full");
  }
}

T
Thunderbrook 已提交
72 73 74 75 76 77 78 79 80 81 82 83 84 85
// Looks up `len` keys, one per thread of a 1-D launch. When keys[i] is
// present, vals[i] receives the stored value; when it is absent, vals[i]
// is left untouched.
template <typename Table>
__global__ void search_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              typename Table::mapped_type* const vals,
                              size_t len) {
  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= len) {
    return;
  }
  auto hit = table->find(keys[idx]);
  if (hit != table->end()) {
    vals[idx] = hit->second;
  }
}

86 87 88
// Pulls dynamic-size FeatureValue records for `len` keys into the packed
// output buffer `vals`; record i starts at byte i * pull_feature_value_size.
//
// Launch layout (2-D blocks, see the char* overload of HashTable::get):
//   threadIdx.y picks the key within the block, i = blockIdx.x*blockDim.y+y;
//   threadIdx.x (k) picks which 4-byte slots of the record this thread copies:
//     k in {3, 6, 7}  -> slot k copied as int
//     other k < 8     -> slot k copied as float
//     k == 8          -> bytes 32..39 copied as one uint64_t
//     k >= 9          -> a contiguous share of the mf_dim + 1 float slots
//                        starting at slot 10, split as evenly as possible
//                        across the blockDim.x - 9 tail threads.
// Requires blockDim.x > 9. Missing non-zero keys are reported via printf.
template <typename Table>
__global__ void dy_mf_search_kernel(Table* table,
                                    const typename Table::key_type* const keys,
                                    char* vals,
                                    size_t len,
                                    size_t pull_feature_value_size) {
  const size_t i = blockIdx.x * blockDim.y + threadIdx.y;
  const size_t k = threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      uint64_t offset = i * pull_feature_value_size;
      FeatureValue* cur = (FeatureValue*)(vals + offset);
      FeatureValue& input = *(FeatureValue*)(it->second);
      char* cur_p = (char*)cur;
      char* input_p = (char*)(&input);
      // One past the last tail index j (the tail copies slots j + 1 for
      // j in [9, value_len)). Renamed from `len`, which shadowed the kernel
      // parameter of the same name.
      const int value_len = 9 + input.mf_dim + 1;
      if (k == 3 || k == 6 || k == 7) {
        *(int*)(cur_p + k * 4) = *(int*)(input_p + k * 4);
      } else if (k < 8) {
        *(float*)(cur_p + k * 4) = *(float*)(input_p + k * 4);
      } else if (k == 8) {
        *(uint64_t*)(cur_p + k * 4) = *(uint64_t*)(input_p + k * 4);
      } else {
        // The tail is shared by threads k = 9 .. blockDim.x - 1, so the
        // worker count must come from blockDim.x (k is threadIdx.x). The
        // previous code used blockDim.y, which only worked for square blocks.
        const int num_workers = static_cast<int>(blockDim.x) - 9;
        const int worker = static_cast<int>(k) - 9;
        const int len_per_thread = (value_len - 9) / num_workers;
        const int remain = (value_len - 9) % num_workers;
        // The first `remain` workers take one extra element each.
        int real_len = len_per_thread;
        int left = 0;
        if (worker < remain) {
          real_len++;
          left = 9 + worker * (len_per_thread + 1);
        } else {
          left = 9 + remain * (len_per_thread + 1) +
                 (worker - remain) * len_per_thread;
        }
        const int right = left + real_len;
        for (int j = left; j < right; j++) {
          *(float*)(cur_p + (j + 1) * 4) = *(float*)(input_p + (j + 1) * 4);
        }
      }
    } else if (keys[i] != 0) {
      // Key 0 is skipped silently (presumably a padding key — confirm).
      printf("pull miss key: %llu\n", (unsigned long long)keys[i]);
    }
  }
}
131

T
Thunderbrook 已提交
132 133
// Applies sgd.update_value() to the stored value of every key in `keys`
// that exists in the table, using grads[i] for keys[i]. Absent keys are
// skipped. One key per thread of a 1-D launch.
// NOTE(review): optimizer_config is a reference, and the host wrapper binds
// it with `*device_optimizer_config_` (a cudaMalloc'd pointer). The
// reference is therefore effectively passed as a device address. A plain
// pointer parameter would make that explicit.
template <typename Table, typename GradType, typename Sgd>
__global__ void update_kernel(Table* table,
                              const OptimizerConfig& optimizer_config,
                              const typename Table::key_type* const keys,
                              const GradType* const grads,
                              size_t len,
                              Sgd sgd) {
  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= len) {
    return;
  }
  auto slot = table->find(keys[idx]);
  if (slot != table->end()) {
    sgd.update_value(optimizer_config, (slot.getter())->second, grads[idx]);
  }
}

148 149
// Applies sgd.dy_mf_update_value() to the stored value of every key that
// exists in the table. The gradient for keys[i] is the FeaturePushValue
// starting at byte i * grad_value_size of `grads`. One key per thread of a
// 1-D launch. Missing non-zero keys are reported via printf; the message now
// ends with a newline and casts the key explicitly to match %llu.
// NOTE(review): optimizer_config is bound by the host wrapper to
// `*device_optimizer_config_` (device memory) via the reference parameter.
template <typename Table, typename Sgd>
__global__ void dy_mf_update_kernel(Table* table,
                                    const OptimizerConfig& optimizer_config,
                                    const typename Table::key_type* const keys,
                                    const char* const grads,
                                    size_t len,
                                    Sgd sgd,
                                    size_t grad_value_size) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      FeaturePushValue* cur = (FeaturePushValue*)(grads + i * grad_value_size);
      sgd.dy_mf_update_value(optimizer_config, (it.getter())->second, *cur);
    } else if (keys[i] != 0) {
      // Key 0 is skipped silently (presumably a padding key — confirm).
      printf("warning::push miss key: %llu\n", (unsigned long long)keys[i]);
    }
  }
}

T
Thunderbrook 已提交
170 171 172
// Builds the table container with `capacity` slots, then allocates a device
// buffer for the optimizer config and seeds it from host_optimizer_config_.
// Finally creates the reader/writer lock.
template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::HashTable(size_t capacity) {
  container_ = new TableContainer<KeyType, ValType>(capacity);

  // Device mirror of host_optimizer_config_, passed to the update kernels.
  const size_t config_bytes = sizeof(OptimizerConfig);
  cudaMalloc((void**)&device_optimizer_config_, config_bytes);
  cudaMemcpy((void*)device_optimizer_config_,
             &host_optimizer_config_,
             config_bytes,
             cudaMemcpyHostToDevice);

  rwlock_.reset(new phi::RWLock);
}

// Releases the table container and the device copy of the optimizer config.
// The config is cudaMalloc'd in the constructor; it used to be leaked here.
template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::~HashTable() {
  delete container_;
  container_ = nullptr;
  // cudaFree(nullptr) is a no-op, so this is safe even if the allocation
  // in the constructor did not succeed.
  cudaFree((void*)device_optimizer_config_);
  device_optimizer_config_ = nullptr;
}

Z
zmxdream 已提交
186 187 188 189
// Applies `optimizer_config` through host_optimizer_config_.set_sparse_sgd(),
// then re-uploads the whole host config to its device copy with cudaMemcpy.
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::set_sparse_sgd(
    const OptimizerConfig& optimizer_config) {
  host_optimizer_config_.set_sparse_sgd(optimizer_config);

  const size_t config_bytes = sizeof(OptimizerConfig);
  cudaMemcpy((void*)device_optimizer_config_,
             &host_optimizer_config_,
             config_bytes,
             cudaMemcpyHostToDevice);
}

// Applies `optimizer_config` through host_optimizer_config_.set_embedx_sgd(),
// then re-uploads the whole host config to its device copy with cudaMemcpy.
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::set_embedx_sgd(
    const OptimizerConfig& optimizer_config) {
  host_optimizer_config_.set_embedx_sgd(optimizer_config);

  const size_t config_bytes = sizeof(OptimizerConfig);
  cudaMemcpy((void*)device_optimizer_config_,
             &host_optimizer_config_,
             config_bytes,
             cudaMemcpyHostToDevice);
}

T
Thunderbrook 已提交
206 207 208 209 210 211
// Debug helper: delegates printing of the table to the container.
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::show() {
  auto* table = container_;
  table->print();
}

// Enqueues a lookup of `len` device keys on `stream`. Found values are
// written to d_vals[i]; entries for missing keys are left unchanged.
// Uses a 1-D launch with BLOCK_SIZE_ threads per block; no-op for len == 0.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys,
                                      ValType* d_vals,
                                      size_t len,
                                      StreamType stream) {
  if (len != 0) {
    const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
    search_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
        container_, d_keys, d_vals, len);
  }
}

225
// Enqueues a dynamic-size pull on `stream`. Each found key's record is
// copied into the packed byte buffer d_vals (pull_feature_value_size_ bytes
// per key). Uses 32x32 blocks: 32 keys per block along y and 32 copy lanes
// per key along x, as dy_mf_search_kernel expects. No-op for len == 0.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys,
                                      char* d_vals,
                                      size_t len,
                                      StreamType stream) {
  if (len != 0) {
    constexpr int kLanesPerKey = 32;
    constexpr int kKeysPerBlock = 32;
    const int grid_size = (len - 1) / kKeysPerBlock + 1;
    dy_mf_search_kernel<<<dim3(grid_size),
                          dim3(kLanesPerKey, kKeysPerBlock),
                          0,
                          stream>>>(
        container_, d_keys, d_vals, len, pull_feature_value_size_);
  }
}

T
Thunderbrook 已提交
241
// Enqueues insertion of `len` (d_keys[i], d_vals[i]) pairs on `stream`.
// Uses a 1-D launch with BLOCK_SIZE_ threads per block; no-op for len == 0.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
                                         const ValType* d_vals,
                                         size_t len,
                                         StreamType stream) {
  if (len != 0) {
    const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
    insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
        container_, d_keys, d_vals, len);
  }
}

255
// Enqueues insertion of `len` keys on `stream`. Key i maps to the pool slot
// pool + (start_index + i) * feature_value_size. Does nothing when there
// are no keys or no pool.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
                                         size_t len,
                                         char* pool,
                                         size_t feature_value_size,
                                         size_t start_index,
                                         StreamType stream) {
  const bool nothing_to_do = (len == 0) || (pool == NULL);
  if (nothing_to_do) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_, d_keys, len, pool, feature_value_size, start_index);
}

T
Thunderbrook 已提交
274
// Writes GPU-side feature values back into the CPU-side parameter objects
// that each value's cpu_ptr points at.
//
// Steps:
//   1. Prefetch the table storage to the host (cudaCpuDeviceId) on `stream`.
//   2. Synchronize `stream`, so host threads never read slots that earlier
//      GPU work on that stream may still be writing. This sync is new; the
//      previous code read the table right after the prefetch call.
//   3. Have 8 host threads walk disjoint slices of the slot array. Slots
//      whose key is numeric_limits<KeyType>::max() are treated as unused and
//      skipped.
// Slice bounds are size_t (they used to be int) so large tables do not
// overflow. devid is currently unused (the prefetch back is disabled below).
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::dump_to_cpu(int devid, StreamType stream) {
  container_->prefetch(cudaCpuDeviceId, stream);
  // The prefetch takes a stream and is presumably asynchronous. Wait for it,
  // and for any earlier work on the stream, before touching the memory.
  cudaStreamSynchronize(stream);

  const size_t num = container_->size();
  KeyType unuse_key = std::numeric_limits<KeyType>::max();
  thrust::pair<KeyType, ValType>* kv = container_->data();

  auto dump_func = [unuse_key, kv](size_t left, size_t right) {
    for (size_t i = left; i < right; i++) {
      if (kv[i].first == unuse_key) {
        continue;
      }
      ValType& gpu_val = kv[i].second;
#ifdef PADDLE_WITH_PSLIB
      auto* downpour_value =
          (paddle::ps::DownpourFixedFeatureValue*)(gpu_val.cpu_ptr);
      int downpour_value_size = downpour_value->size();
      if (gpu_val.mf_size > 0 && downpour_value_size == 7) {
        downpour_value->resize(gpu_val.mf_size + downpour_value_size);
      }
      float* cpu_val = downpour_value->data();
      // cpu_val[0] = 0;
      cpu_val[1] = gpu_val.delta_score;
      cpu_val[2] = gpu_val.show;
      cpu_val[3] = gpu_val.clk;
      cpu_val[4] = gpu_val.lr;
      cpu_val[5] = gpu_val.lr_g2sum;
      cpu_val[6] = gpu_val.slot;
      if (gpu_val.mf_size > 0) {
        for (int x = 0; x < gpu_val.mf_size; x++) {
          cpu_val[x + 7] = gpu_val.mf[x];
        }
      }
#endif
#ifdef PADDLE_WITH_PSCORE
      auto* downpour_value =
          (paddle::distributed::FixedFeatureValue*)(gpu_val.cpu_ptr);
      int downpour_value_size = downpour_value->size();
      if (gpu_val.mf_size > 0 && downpour_value_size == 7) {
        downpour_value->resize(gpu_val.mf_size + downpour_value_size);
      }
      float* cpu_val = downpour_value->data();
      // cpu_val[0] = 0;
      cpu_val[2] = gpu_val.delta_score;
      cpu_val[3] = gpu_val.show;
      cpu_val[4] = gpu_val.clk;
      cpu_val[5] = gpu_val.lr;
      cpu_val[6] = gpu_val.lr_g2sum;
      cpu_val[0] = gpu_val.slot;
      if (gpu_val.mf_size > 0) {
        for (int x = 0; x < gpu_val.mf_size; x++) {
          cpu_val[x + 7] = gpu_val.mf[x];
        }
      }
#endif
    }
  };

  // Split [0, num) into kThreadNum contiguous slices; the first `remain`
  // slices get one extra slot each.
  constexpr size_t kThreadNum = 8;
  const size_t len_per_thread = num / kThreadNum;
  const size_t remain = num % kThreadNum;

  std::vector<std::thread> threads;
  threads.reserve(kThreadNum);
  size_t begin = 0;
  for (size_t t = 0; t < kThreadNum; t++) {
    const size_t end = begin + len_per_thread + (t < remain ? 1 : 0);
    threads.emplace_back(dump_func, begin, end);
    begin = end;
  }
  for (std::thread& t : threads) {
    t.join();
  }

  // container_->prefetch(devid, stream);
}

T
Thunderbrook 已提交
351
// Enqueues an optimizer step on `stream`. For every key in d_keys that is
// present in the table, `sgd` is applied with the matching gradient in
// d_grads. The device copy of the optimizer config is bound to the kernel's
// reference parameter. Uses a 1-D launch of BLOCK_SIZE_ threads per block;
// no-op for len == 0.
template <typename KeyType, typename ValType>
template <typename GradType, typename Sgd, typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                         const GradType* d_grads,
                                         size_t len,
                                         Sgd sgd,
                                         StreamType stream) {
  if (len != 0) {
    const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
    update_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
        container_, *device_optimizer_config_, d_keys, d_grads, len, sgd);
  }
}

366
// Enqueues a dynamic-size optimizer step on `stream`. The gradient for key i
// is read from d_grads at byte offset i * push_grad_value_size_. Uses a 1-D
// launch of BLOCK_SIZE_ threads per block; no-op for len == 0.
template <typename KeyType, typename ValType>
template <typename Sgd, typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                         const char* d_grads,
                                         size_t len,
                                         Sgd sgd,
                                         StreamType stream) {
  if (len != 0) {
    const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
    dy_mf_update_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
        container_,
        *device_optimizer_config_,
        d_keys,
        d_grads,
        len,
        sgd,
        push_grad_value_size_);
  }
}

387
// ---------------------------------------------------------------------------
// Explicit instantiations, grouped per (key, value) table type. Each group
// lists the class instantiation followed by the member templates it needs,
// all with cudaStream_t streams.
// ---------------------------------------------------------------------------

// --- unsigned long -> FeatureValue (fixed-size values) ---------------------
template class HashTable<unsigned long, paddle::framework::FeatureValue>;
template void HashTable<unsigned long, paddle::framework::FeatureValue>::get<
    cudaStream_t>(const unsigned long*,
                  paddle::framework::FeatureValue*,
                  size_t,
                  cudaStream_t);
template void HashTable<unsigned long, paddle::framework::FeatureValue>::insert<
    cudaStream_t>(const unsigned long*,
                  const paddle::framework::FeatureValue*,
                  size_t,
                  cudaStream_t);
template void HashTable<unsigned long, paddle::framework::FeatureValue>::
    dump_to_cpu<cudaStream_t>(int, cudaStream_t);
template void HashTable<unsigned long, paddle::framework::FeatureValue>::update<
    paddle::framework::FeaturePushValue,
    Optimizer<paddle::framework::FeatureValue,
              paddle::framework::FeaturePushValue>,
    cudaStream_t>(const unsigned long*,
                  const paddle::framework::FeaturePushValue*,
                  size_t,
                  Optimizer<paddle::framework::FeatureValue,
                            paddle::framework::FeaturePushValue>,
                  cudaStream_t);
// Intentionally not instantiated (kept for reference): the char* get/update
// overloads for this value type.
// template void
// HashTable<unsigned long, paddle::framework::FeatureValue>::get<cudaStream_t>(
//     const unsigned long* d_keys, char* d_vals, size_t len,
//     cudaStream_t stream);
// template void HashTable<unsigned long,
//                         paddle::framework::FeatureValue>::update<
//     Optimizer<paddle::framework::FeatureValue,
//               paddle::framework::FeaturePushValue>,
//     cudaStream_t>(const unsigned long* d_keys, const char* d_grads,
//                   size_t len,
//                   Optimizer<paddle::framework::FeatureValue,
//                             paddle::framework::FeaturePushValue> sgd,
//                   cudaStream_t stream);

// --- unsigned long -> FeatureValue* (values live in a byte pool) -----------
template class HashTable<unsigned long, paddle::framework::FeatureValue*>;
template void
HashTable<unsigned long, paddle::framework::FeatureValue*>::get<cudaStream_t>(
    const unsigned long*, char*, size_t, cudaStream_t);
template void HashTable<unsigned long, paddle::framework::FeatureValue*>::
    insert<cudaStream_t>(const unsigned long*,
                         size_t,
                         char*,
                         size_t,
                         size_t,
                         cudaStream_t);
template void HashTable<unsigned long, paddle::framework::FeatureValue*>::
    update<Optimizer<paddle::framework::FeatureValue,
                     paddle::framework::FeaturePushValue>,
           cudaStream_t>(const unsigned long*,
                         const char*,
                         size_t,
                         Optimizer<paddle::framework::FeatureValue,
                                   paddle::framework::FeaturePushValue>,
                         cudaStream_t);

// --- unsigned long -> integral / pointer values ----------------------------
template class HashTable<unsigned long, int>;
template void HashTable<unsigned long, int>::get<cudaStream_t>(
    const unsigned long*, int*, size_t, cudaStream_t);
template void HashTable<unsigned long, int>::insert<cudaStream_t>(
    const unsigned long*, const int*, size_t, cudaStream_t);

template class HashTable<unsigned long, long>;
template void HashTable<unsigned long, long>::get<cudaStream_t>(
    const unsigned long*, long*, size_t, cudaStream_t);
template void HashTable<unsigned long, long>::insert<cudaStream_t>(
    const unsigned long*, const long*, size_t, cudaStream_t);

template class HashTable<unsigned long, unsigned long>;
template class HashTable<unsigned long, long*>;

// --- long -> integral values -----------------------------------------------
template class HashTable<long, int>;
template void HashTable<long, int>::get<cudaStream_t>(const long*,
                                                      int*,
                                                      size_t,
                                                      cudaStream_t);
template void HashTable<long, int>::insert<cudaStream_t>(const long*,
                                                         const int*,
                                                         size_t,
                                                         cudaStream_t);

template class HashTable<long, long>;
template void HashTable<long, long>::get<cudaStream_t>(const long*,
                                                       long*,
                                                       size_t,
                                                       cudaStream_t);
template void HashTable<long, long>::insert<cudaStream_t>(const long*,
                                                          const long*,
                                                          size_t,
                                                          cudaStream_t);

template class HashTable<long, unsigned long>;
template void HashTable<long, unsigned long>::get<cudaStream_t>(
    const long*, unsigned long*, size_t, cudaStream_t);
template void HashTable<long, unsigned long>::insert<cudaStream_t>(
    const long*, const unsigned long*, size_t, cudaStream_t);

template class HashTable<long, unsigned int>;
template void HashTable<long, unsigned int>::get<cudaStream_t>(
    const long*, unsigned int*, size_t, cudaStream_t);
template void HashTable<long, unsigned int>::insert<cudaStream_t>(
    const long*, const unsigned int*, size_t, cudaStream_t);

#endif
T
Thunderbrook 已提交
512 513 514
}  // end namespace framework
}  // end namespace paddle
#endif