/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_HETERPS
#include <thread>

#include "paddle/fluid/framework/fleet/heter_ps/hashtable.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"

namespace paddle {
namespace framework {

#if defined(PADDLE_WITH_CUDA)

// Merge functor passed as `op` to Table::insert() by the insert kernels.
// It discards the previously stored value and yields the incoming one, so
// (presumably, depending on how the table applies `op`) inserting an existing
// key overwrites that key's value.
template <typename value_type>
struct ReplaceOp {
  __host__ __device__ value_type operator()(value_type new_value,
                                            value_type old_value) {
    static_cast<void>(old_value);  // intentionally unused
    return new_value;
  }
};

// Inserts `len` (keys[i], vals[i]) pairs into `table`, one pair per thread.
//
// Launch: 1-D grid providing at least `len` threads; surplus threads exit.
// ReplaceOp is passed as the merge op, so an existing key presumably has its
// value overwritten. A device-side assert (active only without NDEBUG) fires
// when insert() returns end(), i.e. the table could not take the key.
template <typename Table>
__global__ void insert_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              const typename Table::mapped_type* const vals,
                              size_t len) {
  ReplaceOp<typename Table::mapped_type> op;
  thrust::pair<typename Table::key_type, typename Table::mapped_type> kv;

  // Widen before multiplying: blockIdx.x * blockDim.x is otherwise evaluated
  // in 32-bit unsigned arithmetic and wraps once the grid covers > 2^32 items.
  const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < len) {
    kv.first = keys[i];
    kv.second = vals[i];
    auto it = table->insert(kv, op);
    assert(it != table->end() && "error: insert fails: table is full");
  }
}

// Inserts `len` keys whose mapped values are pointers into a caller-owned
// byte pool: key i maps to pool + (start_index + i) * feature_value_size.
// One key per thread; surplus threads exit.
//
// `start_index` is size_t to match the host wrapper, which passes a size_t;
// the previous `int` parameter silently truncated large pool offsets.
// A device-side assert (active only without NDEBUG) fires when insert()
// returns end(), i.e. the table could not take the key.
template <typename Table>
__global__ void insert_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              size_t len,
                              char* pool,
                              size_t feature_value_size,
                              size_t start_index) {
  ReplaceOp<typename Table::mapped_type> op;
  thrust::pair<typename Table::key_type, typename Table::mapped_type> kv;

  // Widen before multiplying so the global index cannot wrap at 2^32.
  const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  if (i < len) {
    kv.first = keys[i];
    const uint64_t offset = uint64_t(start_index + i) * feature_value_size;
    // mapped_type is a dependent type, so it must be named with `typename`.
    kv.second = reinterpret_cast<typename Table::mapped_type>(pool + offset);
    auto it = table->insert(kv, op);
    assert(it != table->end() && "error: insert fails: table is full");
  }
}

T
Thunderbrook 已提交
72 73 74 75 76 77 78 79 80 81 82 83 84 85
// Looks up `len` keys in `table`, one key per thread. For every key found,
// vals[i] receives the mapped value. Misses leave vals[i] untouched, so a
// caller that needs to tell hits from misses must pre-fill `vals`.
//
// Launch: 1-D grid providing at least `len` threads; surplus threads exit.
template <typename Table>
__global__ void search_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              typename Table::mapped_type* const vals,
                              size_t len) {
  // Widen before multiplying so the global index cannot wrap at 2^32.
  const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      vals[i] = it->second;
    }
  }
}

// Variable-length pull, one key per thread. For each key found, copies the
// FeatureValue that the table entry points to into the i-th
// `pull_feature_value_size`-byte slot of `vals`, including the first
// mf_dim + 1 elements of `mf`.
//
// Misses leave the output slot untouched. A warning is printed for every
// missing key except 0 (NOTE(review): 0 is presumably a padding key).
// Assumes pull_feature_value_size can hold a FeatureValue with mf_dim + 1 mf
// elements -- TODO confirm against the caller.
template <typename Table>
__global__ void dy_mf_search_kernel(Table* table,
                                    const typename Table::key_type* const keys,
                                    char* vals,
                                    size_t len,
                                    size_t pull_feature_value_size) {
  // Widen before multiplying so the global index cannot wrap at 2^32.
  const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i >= len) {
    return;
  }

  auto it = table->find(keys[i]);
  if (it != table->end()) {
    uint64_t offset = i * pull_feature_value_size;
    FeatureValue* cur = (FeatureValue*)(vals + offset);
    FeatureValue& input = *(FeatureValue*)(it->second);
    cur->slot = input.slot;
    cur->show = input.show;
    cur->clk = input.clk;
    cur->mf_dim = input.mf_dim;
    cur->lr = input.lr;
    cur->mf_size = input.mf_size;
    cur->cpu_ptr = input.cpu_ptr;
    cur->delta_score = input.delta_score;
    cur->lr_g2sum = input.lr_g2sum;
    for (int j = 0; j < cur->mf_dim + 1; ++j) {
      cur->mf[j] = input.mf[j];
    }
  } else if (keys[i] != 0) {
    // %llu requires an unsigned long long argument, so cast explicitly
    // whatever integer type key_type is. The trailing newline keeps
    // consecutive warnings from different threads on separate lines.
    printf("warning::pull miss key: %llu\n",
           static_cast<unsigned long long>(keys[i]));
  }
}
// Fixed-size push, one key per thread: for each key found, applies grads[i]
// to the stored value via sgd.update_value(). Keys absent from the table are
// skipped silently.
//
// NOTE(review): `optimizer_config` is a reference parameter that the host
// binds to *device_optimizer_config_ (device memory). This presumably relies
// on nvcc passing reference kernel arguments as addresses. Confirm before
// changing how the argument is produced.
template <typename Table, typename GradType, typename Sgd>
__global__ void update_kernel(Table* table,
                              const OptimizerConfig& optimizer_config,
                              const typename Table::key_type* const keys,
                              const GradType* const grads,
                              size_t len,
                              Sgd sgd) {
  // Widen before multiplying so the global index cannot wrap at 2^32.
  const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      sgd.update_value(optimizer_config, (it.getter())->second, grads[i]);
    }
  }
}

// Variable-length push, one key per thread. Gradient i is the
// FeaturePushValue at byte offset i * grad_value_size of `grads`. For each key
// found it is applied via sgd.dy_mf_update_value(). A warning is printed for
// every missing key except 0 (NOTE(review): 0 is presumably a padding key).
//
// NOTE(review): `optimizer_config` is bound on the host to
// *device_optimizer_config_ (device memory); see update_kernel.
template <typename Table, typename Sgd>
__global__ void dy_mf_update_kernel(Table* table,
                                    const OptimizerConfig& optimizer_config,
                                    const typename Table::key_type* const keys,
                                    const char* const grads,
                                    size_t len,
                                    Sgd sgd,
                                    size_t grad_value_size) {
  // Widen before multiplying so the global index cannot wrap at 2^32.
  const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i >= len) {
    return;
  }

  auto it = table->find(keys[i]);
  if (it != table->end()) {
    FeaturePushValue* cur = (FeaturePushValue*)(grads + i * grad_value_size);
    sgd.dy_mf_update_value(optimizer_config, (it.getter())->second, *cur);
  } else if (keys[i] != 0) {
    // Explicit cast keeps %llu well-formed; the newline separates warnings.
    printf("warning::push miss key: %llu\n",
           static_cast<unsigned long long>(keys[i]));
  }
}

template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::HashTable(size_t capacity) {
  // Table storage, constructed with the requested capacity.
  container_ = new TableContainer<KeyType, ValType>(capacity);

  // Device-side copy of the optimizer configuration (the update kernels are
  // handed *device_optimizer_config_), seeded from the host copy.
  // NOTE(review): CUDA return codes are not checked here.
  const size_t config_bytes = sizeof(OptimizerConfig);
  cudaMalloc((void**)&device_optimizer_config_, config_bytes);
  cudaMemcpy((void*)device_optimizer_config_,
             &host_optimizer_config_,
             config_bytes,
             cudaMemcpyHostToDevice);

  rwlock_.reset(new phi::RWLock);
}

template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::~HashTable() {
  // Release what the constructor allocated: the table container, then the
  // device buffer holding the optimizer configuration.
  delete container_;
  cudaFree(device_optimizer_config_);
}

template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::set_sparse_sgd(
    const OptimizerConfig& optimizer_config) {
  // Apply the sparse-SGD settings to the host config, then upload the whole
  // struct so the device copy read by the update kernels matches it.
  host_optimizer_config_.set_sparse_sgd(optimizer_config);
  void* const device_dst = (void*)device_optimizer_config_;
  const void* const host_src = &host_optimizer_config_;
  cudaMemcpy(
      device_dst, host_src, sizeof(OptimizerConfig), cudaMemcpyHostToDevice);
}

template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::set_embedx_sgd(
    const OptimizerConfig& optimizer_config) {
  // Apply the embedx-SGD settings to the host config, then upload the whole
  // struct so the device copy read by the update kernels matches it.
  host_optimizer_config_.set_embedx_sgd(optimizer_config);
  void* const device_dst = (void*)device_optimizer_config_;
  const void* const host_src = &host_optimizer_config_;
  cudaMemcpy(
      device_dst, host_src, sizeof(OptimizerConfig), cudaMemcpyHostToDevice);
}

template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::show() {
  // Delegates to the container's own print() routine.
  this->container_->print();
}

template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys,
                                      ValType* d_vals,
                                      size_t len,
                                      StreamType stream) {
  // Fixed-size lookup: enqueue search_kernel on `stream` with one thread per
  // key. The launch is asynchronous, and keys not found leave their d_vals
  // entry untouched. An empty batch launches nothing.
  if (len > 0) {
    const int blocks = (len - 1) / BLOCK_SIZE_ + 1;
    search_kernel<<<blocks, BLOCK_SIZE_, 0, stream>>>(
        container_, d_keys, d_vals, len);
  }
}

template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys,
                                      char* d_vals,
                                      size_t len,
                                      StreamType stream) {
  // Variable-length lookup: each found key's value is copied into its own
  // pull_feature_value_size_-byte slot of d_vals by dy_mf_search_kernel,
  // enqueued asynchronously on `stream`.
  if (len > 0) {
    const int blocks = (len - 1) / BLOCK_SIZE_ + 1;
    dy_mf_search_kernel<<<blocks, BLOCK_SIZE_, 0, stream>>>(
        container_, d_keys, d_vals, len, pull_feature_value_size_);
  }
}

template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
                                         const ValType* d_vals,
                                         size_t len,
                                         StreamType stream) {
  // Enqueues an insert of (d_keys[i], d_vals[i]) pairs on `stream`, one
  // thread per pair. Empty batches are a no-op.
  if (len == 0) return;

  const int blocks = (len - 1) / BLOCK_SIZE_ + 1;
  insert_kernel<<<blocks, BLOCK_SIZE_, 0, stream>>>(
      container_, d_keys, d_vals, len);
}

template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
                                         size_t len,
                                         char* pool,
                                         size_t feature_value_size,
                                         size_t start_index,
                                         StreamType stream) {
  // Pool-backed insert: key i is mapped to the address
  // pool + (start_index + i) * feature_value_size. Nothing is launched for an
  // empty batch or a null pool.
  if (len == 0 || pool == NULL) {
    return;
  }

  const int blocks = (len - 1) / BLOCK_SIZE_ + 1;
  insert_kernel<<<blocks, BLOCK_SIZE_, 0, stream>>>(
      container_, d_keys, len, pool, feature_value_size, start_index);
}

template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::dump_to_cpu(int devid, StreamType stream) {
  // Writes every occupied table slot back into the CPU-side feature value its
  // `cpu_ptr` points to, using a fixed pool of 8 host threads.
  //
  // The table storage is prefetched towards the host (cudaCpuDeviceId) on
  // `stream`. That prefetch is asynchronous and stream-ordered, so wait for it,
  // and for any kernels still writing the table on `stream`, before the host
  // threads below read the entries.
  container_->prefetch(cudaCpuDeviceId, stream);
  cudaStreamSynchronize(stream);

  // NOTE(review): size() is assumed to count slots (empty slots are skipped
  // below via unuse_key) -- confirm against TableContainer.
  const size_t num = container_->size();
  // Presumably the table's empty-slot sentinel key.
  const KeyType unuse_key = std::numeric_limits<KeyType>::max();
  thrust::pair<KeyType, ValType>* kv = container_->data();

  // Copies slots [left, right). Indices are size_t so tables with more than
  // INT_MAX slots are walked completely (the old int indices overflowed).
  auto dump_func = [unuse_key, kv](size_t left, size_t right) {
    for (size_t i = left; i < right; i++) {
      if (kv[i].first == unuse_key) {
        continue;
      }
      ValType& gpu_val = kv[i].second;
#ifdef PADDLE_WITH_PSLIB
      auto* downpour_value =
          (paddle::ps::DownpourFixedFeatureValue*)(gpu_val.cpu_ptr);
      int downpour_value_size = downpour_value->size();
      // Indices 0..6 hold scalar fields and mf starts at index 7, so grow the
      // CPU value to make room for mf the first time the GPU side has one.
      if (gpu_val.mf_size > 0 && downpour_value_size == 7) {
        downpour_value->resize(gpu_val.mf_size + downpour_value_size);
      }
      float* cpu_val = downpour_value->data();
      // cpu_val[0] = 0;
      cpu_val[1] = gpu_val.delta_score;
      cpu_val[2] = gpu_val.show;
      cpu_val[3] = gpu_val.clk;
      cpu_val[4] = gpu_val.lr;
      cpu_val[5] = gpu_val.lr_g2sum;
      cpu_val[6] = gpu_val.slot;
      if (gpu_val.mf_size > 0) {
        for (int x = 0; x < gpu_val.mf_size; x++) {
          cpu_val[x + 7] = gpu_val.mf[x];
        }
      }
#endif
#ifdef PADDLE_WITH_PSCORE
      auto* downpour_value =
          (paddle::distributed::FixedFeatureValue*)(gpu_val.cpu_ptr);
      int downpour_value_size = downpour_value->size();
      // Same growth rule as above: mf starts at index 7.
      if (gpu_val.mf_size > 0 && downpour_value_size == 7) {
        downpour_value->resize(gpu_val.mf_size + downpour_value_size);
      }
      float* cpu_val = downpour_value->data();
      // cpu_val[0] = 0;
      // PSCORE layout differs from PSLIB: slot goes to index 0 and index 1 is
      // not written here.
      cpu_val[2] = gpu_val.delta_score;
      cpu_val[3] = gpu_val.show;
      cpu_val[4] = gpu_val.clk;
      cpu_val[5] = gpu_val.lr;
      cpu_val[6] = gpu_val.lr_g2sum;
      cpu_val[0] = gpu_val.slot;
      if (gpu_val.mf_size > 0) {
        for (int x = 0; x < gpu_val.mf_size; x++) {
          cpu_val[x + 7] = gpu_val.mf[x];
        }
      }
#endif
    }
  };

  // Split [0, num) into thread_num contiguous ranges; the first `remain`
  // ranges take one extra slot each.
  const size_t thread_num = 8;
  const size_t len_per_thread = num / thread_num;
  const size_t remain = num % thread_num;
  std::vector<std::thread> threads;
  threads.reserve(thread_num);
  size_t begin = 0;
  for (size_t t = 0; t < thread_num; ++t) {
    const size_t end = begin + len_per_thread + (t < remain ? 1 : 0);
    threads.emplace_back(dump_func, begin, end);
    begin = end;
  }
  for (std::thread& t : threads) {
    t.join();
  }

  // `devid` is only referenced by the disabled prefetch back to the device.
  // container_->prefetch(devid, stream);
}

template <typename KeyType, typename ValType>
template <typename GradType, typename Sgd, typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                         const GradType* d_grads,
                                         size_t len,
                                         Sgd sgd,
                                         StreamType stream) {
  // Fixed-size push: enqueues update_kernel on `stream`, applying d_grads[i]
  // to the entry for d_keys[i]. The kernel reads the optimizer settings from
  // the device copy (*device_optimizer_config_).
  if (len > 0) {
    const int blocks = (len - 1) / BLOCK_SIZE_ + 1;
    update_kernel<<<blocks, BLOCK_SIZE_, 0, stream>>>(
        container_, *device_optimizer_config_, d_keys, d_grads, len, sgd);
  }
}

template <typename KeyType, typename ValType>
template <typename Sgd, typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                         const char* d_grads,
                                         size_t len,
                                         Sgd sgd,
                                         StreamType stream) {
  // Variable-length push: gradient i occupies push_grad_value_size_ bytes of
  // d_grads. dy_mf_update_kernel is enqueued on `stream` and reads the
  // optimizer settings from the device copy (*device_optimizer_config_).
  if (len == 0) return;

  const int blocks = (len - 1) / BLOCK_SIZE_ + 1;
  dy_mf_update_kernel<<<blocks, BLOCK_SIZE_, 0, stream>>>(
      container_,
      *device_optimizer_config_,
      d_keys,
      d_grads,
      len,
      sgd,
      push_grad_value_size_);
}

// ---------------------------------------------------------------------------
// Explicit instantiations. The member definitions live in this .cu file, so
// every specialization used by other translation units is instantiated here.
// ---------------------------------------------------------------------------

// Class specializations (instantiate the non-template members defined above).
template class HashTable<unsigned long, FeatureValue>;
template class HashTable<unsigned long, FeatureValue*>;
template class HashTable<unsigned long, int>;
template class HashTable<unsigned long, unsigned long>;
template class HashTable<unsigned long, long>;
template class HashTable<unsigned long, long*>;
template class HashTable<long, int>;
template class HashTable<long, long>;
template class HashTable<long, unsigned long>;
template class HashTable<long, unsigned int>;

// get(): fixed-size values.
template void HashTable<unsigned long, FeatureValue>::get<cudaStream_t>(
    const unsigned long* d_keys,
    FeatureValue* d_vals,
    size_t len,
    cudaStream_t stream);
template void HashTable<unsigned long, int>::get<cudaStream_t>(
    const unsigned long* d_keys, int* d_vals, size_t len, cudaStream_t stream);
template void HashTable<unsigned long, long>::get<cudaStream_t>(
    const unsigned long* d_keys, long* d_vals, size_t len, cudaStream_t stream);
template void HashTable<long, int>::get<cudaStream_t>(const long* d_keys,
                                                      int* d_vals,
                                                      size_t len,
                                                      cudaStream_t stream);
template void HashTable<long, long>::get<cudaStream_t>(const long* d_keys,
                                                       long* d_vals,
                                                       size_t len,
                                                       cudaStream_t stream);
template void HashTable<long, unsigned long>::get<cudaStream_t>(
    const long* d_keys, unsigned long* d_vals, size_t len, cudaStream_t stream);
template void HashTable<long, unsigned int>::get<cudaStream_t>(
    const long* d_keys, unsigned int* d_vals, size_t len, cudaStream_t stream);

// get(): variable-length values copied into a char buffer.
template void HashTable<unsigned long, FeatureValue*>::get<cudaStream_t>(
    const unsigned long* d_keys, char* d_vals, size_t len, cudaStream_t stream);

// insert(): fixed-size values.
template void HashTable<unsigned long, FeatureValue>::insert<cudaStream_t>(
    const unsigned long* d_keys,
    const FeatureValue* d_vals,
    size_t len,
    cudaStream_t stream);
template void HashTable<unsigned long, int>::insert<cudaStream_t>(
    const unsigned long* d_keys,
    const int* d_vals,
    size_t len,
    cudaStream_t stream);
template void HashTable<unsigned long, long>::insert<cudaStream_t>(
    const unsigned long* d_keys,
    const long* d_vals,
    size_t len,
    cudaStream_t stream);
template void HashTable<long, int>::insert<cudaStream_t>(const long* d_keys,
                                                         const int* d_vals,
                                                         size_t len,
                                                         cudaStream_t stream);
template void HashTable<long, long>::insert<cudaStream_t>(const long* d_keys,
                                                          const long* d_vals,
                                                          size_t len,
                                                          cudaStream_t stream);
template void HashTable<long, unsigned long>::insert<cudaStream_t>(
    const long* d_keys,
    const unsigned long* d_vals,
    size_t len,
    cudaStream_t stream);
template void HashTable<long, unsigned int>::insert<cudaStream_t>(
    const long* d_keys,
    const unsigned int* d_vals,
    size_t len,
    cudaStream_t stream);

// insert(): values are pointers into a caller-owned byte pool.
template void HashTable<unsigned long, FeatureValue*>::insert<cudaStream_t>(
    const unsigned long* d_keys,
    size_t len,
    char* pool,
    size_t feature_value_size,
    size_t start_index,
    cudaStream_t stream);

// dump_to_cpu()
template void HashTable<unsigned long, FeatureValue>::dump_to_cpu<
    cudaStream_t>(int devid, cudaStream_t stream);

// update(): fixed-size gradients.
template void HashTable<unsigned long, FeatureValue>::update<
    FeaturePushValue,
    Optimizer<FeatureValue, FeaturePushValue>,
    cudaStream_t>(const unsigned long* d_keys,
                  const FeaturePushValue* d_grads,
                  size_t len,
                  Optimizer<FeatureValue, FeaturePushValue> sgd,
                  cudaStream_t stream);

// update(): variable-length gradients in a char buffer.
template void HashTable<unsigned long, FeatureValue*>::update<
    Optimizer<FeatureValue, FeaturePushValue>,
    cudaStream_t>(const unsigned long* d_keys,
                  const char* d_grads,
                  size_t len,
                  Optimizer<FeatureValue, FeaturePushValue> sgd,
                  cudaStream_t stream);

// Intentionally NOT instantiated (kept for reference):
//   HashTable<unsigned long, FeatureValue>::get<cudaStream_t>(
//       const unsigned long* d_keys, char* d_vals, size_t len,
//       cudaStream_t stream);
//   HashTable<unsigned long, FeatureValue>::update<
//       Optimizer<FeatureValue, FeaturePushValue>, cudaStream_t>(
//       const unsigned long* d_keys, const char* d_grads, size_t len,
//       Optimizer<FeatureValue, FeaturePushValue> sgd, cudaStream_t stream);

#endif
}  // end namespace framework
}  // end namespace paddle
#endif