/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_HETERPS
#include <thread>

#include "paddle/fluid/framework/fleet/heter_ps/hashtable.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"

namespace paddle {
namespace framework {

#if defined(PADDLE_WITH_CUDA)

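// Insertion functor for the concurrent hash table: on a key collision the
// stored value is unconditionally overwritten with the new one.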
template <typename value_type>
struct ReplaceOp {
  __host__ __device__ value_type operator()(value_type new_value,
                                            value_type old_value) {
    return new_value;
  }
};

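// One thread per key: build a (key, value) pair and insert it, replacing any
// existing value via ReplaceOp. The assert fires if the table is full.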
template <typename Table>
__global__ void insert_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              const typename Table::mapped_type* const vals,
                              size_t len) {
  ReplaceOp<typename Table::mapped_type> op;
  thrust::pair<typename Table::key_type, typename Table::mapped_type> kv;

  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    kv.first = keys[i];
    kv.second = vals[i];
    auto it = table->insert(kv, op);
    assert(it != table->end() && "error: insert fails: table is full");
  }
}

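// Pool-backed insert: the mapped value is a pointer into an externally
// managed memory pool, located by start_index and feature_value_size.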
template <typename Table>
__global__ void insert_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              size_t len,
                              char* pool,
                              size_t feature_value_size,
                              int start_index) {
  ReplaceOp<typename Table::mapped_type> op;
  thrust::pair<typename Table::key_type, typename Table::mapped_type> kv;

  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;

  if (i < len) {
    kv.first = keys[i];
    uint64_t offset = uint64_t(start_index + i) * feature_value_size;
    kv.second = (typename Table::mapped_type)(pool + offset);
    auto it = table->insert(kv, op);
    assert(it != table->end() && "error: insert fails: table is full");
  }
}

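// One thread per key: copy the stored value into vals[i] on a hit; entries
// for keys that are not found are left untouched.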
template <typename Table>
__global__ void search_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              typename Table::mapped_type* const vals,
                              size_t len) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      vals[i] = it->second;
    }
  }
}

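// Lookup for variable-length (dymf) feature values: each hit is written into
// the flat output buffer through GPUAccessor::PullValueFill.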
template <typename Table, typename GPUAccessor>
__global__ void dy_mf_search_kernel(Table* table,
                                    const typename Table::key_type* const keys,
                                    char* vals,
                                    size_t len,
                                    size_t pull_feature_value_size,
                                    GPUAccessor gpu_accessor) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  // return;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      uint64_t offset = i * pull_feature_value_size;
      float* cur = (float*)(vals + offset);
      float* input = it->second;
      gpu_accessor.PullValueFill(cur, input);
    }
  }
}

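// Apply one gradient per key through the Sgd policy; keys that are missing
// from the table are silently skipped.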
template <typename Table, typename GradType, typename Sgd>
__global__ void update_kernel(Table* table,
                              const OptimizerConfig& optimizer_config,
                              const typename Table::key_type* const keys,
                              const GradType* const grads,
                              size_t len,
                              Sgd sgd) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      sgd.update_value(optimizer_config, (it.getter())->second, grads[i]);
    }
  }
}

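// Dymf update: gradients are packed as raw bytes with a stride of
// grad_value_size; missing keys are reported via device-side printf.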
template <typename Table, typename Sgd>
__global__ void dy_mf_update_kernel(Table* table,
                                    const OptimizerConfig& optimizer_config,
                                    const typename Table::key_type* const keys,
                                    const char* const grads,
                                    size_t len,
                                    Sgd sgd,
                                    size_t grad_value_size) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      float* cur = (float*)(grads + i * grad_value_size);
      sgd.dy_mf_update_value(optimizer_config, (it.getter())->second, cur);
    } else {
      printf("warning: push miss key: %lu", keys[i]);
    }
  }
}

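// The constructor allocates a device-side copy of the optimizer config and
// initializes it from the host defaults; set_sparse_sgd / set_embedx_sgd
// re-copy the host config to the device after changes.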
template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::HashTable(size_t capacity) {
  container_ = new TableContainer<KeyType, ValType>(capacity);
  CUDA_RT_CALL(
      cudaMalloc((void**)&device_optimizer_config_, sizeof(OptimizerConfig)));
  CUDA_RT_CALL(cudaMemcpy((void*)device_optimizer_config_,
                          &host_optimizer_config_,
                          sizeof(OptimizerConfig),
                          cudaMemcpyHostToDevice));
  rwlock_.reset(new phi::RWLock);
}

template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::~HashTable() {
  delete container_;
  cudaFree(device_optimizer_config_);
}

template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::set_sparse_sgd(
    const OptimizerConfig& optimizer_config) {
  host_optimizer_config_.set_sparse_sgd(optimizer_config);
  cudaMemcpy((void*)device_optimizer_config_,
             &host_optimizer_config_,
             sizeof(OptimizerConfig),
             cudaMemcpyHostToDevice);
}

template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::set_embedx_sgd(
    const OptimizerConfig& optimizer_config) {
  host_optimizer_config_.set_embedx_sgd(optimizer_config);
  cudaMemcpy((void*)device_optimizer_config_,
             &host_optimizer_config_,
             sizeof(OptimizerConfig),
             cudaMemcpyHostToDevice);
}

template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::show() {
  container_->print();
}

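// Batched lookup of fixed-size values on the given stream.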
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys,
                                      ValType* d_vals,
                                      size_t len,
                                      StreamType stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  search_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_, d_keys, d_vals, len);
}

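// Batched lookup of variable-length values: results are packed into d_vals
// with pull_feature_value_size_ bytes per key.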
template <typename KeyType, typename ValType>
template <typename StreamType, typename GPUAccessor>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys,
                                      char* d_vals,
                                      size_t len,
                                      StreamType stream,
                                      GPUAccessor& fv_accessor) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  dy_mf_search_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_, d_keys, d_vals, len, pull_feature_value_size_, fv_accessor);
}

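// Batched insert of fixed-size values.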
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
                                         const ValType* d_vals,
                                         size_t len,
                                         StreamType stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_, d_keys, d_vals, len);
}

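// Batched insert of pool-backed values: only pointers into the provided pool
// are stored in the table.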
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
                                         size_t len,
                                         char* pool,
                                         size_t feature_value_size,
                                         size_t start_index,
                                         StreamType stream) {
  if (len == 0) {
    return;
  }
  if (pool == NULL) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_, d_keys, len, pool, feature_value_size, start_index);
}

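// Prefetch the table to host memory and, under PADDLE_WITH_PSLIB, copy each
// occupied slot back into its DownpourFixedFeatureValue using 8 CPU threads.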
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::dump_to_cpu(int devid, StreamType stream) {
  container_->prefetch(cudaCpuDeviceId, stream);
  std::vector<std::thread> threads;
  size_t num = container_->size();
  KeyType unuse_key = std::numeric_limits<KeyType>::max();
  thrust::pair<KeyType, ValType>* kv = container_->data();

  int thread_num = 8;
  int len_per_thread = num / thread_num;
  int remain = num % thread_num;
  int begin = 0;

  auto dump_func = [unuse_key, kv](int left, int right) {
    for (int i = left; i < right; i++) {
      if (kv[i].first == unuse_key) {
        continue;
      }
      ValType& gpu_val = kv[i].second;
#ifdef PADDLE_WITH_PSLIB
      auto* downpour_value =
          (paddle::ps::DownpourFixedFeatureValue*)(gpu_val.cpu_ptr);
      int downpour_value_size = downpour_value->size();
      if (gpu_val.mf_size > 0 && downpour_value_size == 7) {
        downpour_value->resize(gpu_val.mf_size + downpour_value_size);
      }
      float* cpu_val = downpour_value->data();
      // cpu_val[0] = 0;
      cpu_val[1] = gpu_val.delta_score;
      cpu_val[2] = gpu_val.show;
      cpu_val[3] = gpu_val.clk;
      cpu_val[4] = gpu_val.lr;
      cpu_val[5] = gpu_val.lr_g2sum;
      cpu_val[6] = gpu_val.slot;
      if (gpu_val.mf_size > 0) {
        for (int x = 0; x < gpu_val.mf_size; x++) {
          cpu_val[x + 7] = gpu_val.mf[x];
        }
      }
#endif
    }
  };

  for (int i = 0; i < thread_num; i++) {
    threads.push_back(std::thread(
        dump_func, begin, begin + len_per_thread + (i < remain ? 1 : 0)));
    begin += len_per_thread + (i < remain ? 1 : 0);
  }
  for (std::thread& t : threads) {
    t.join();
  }

  // container_->prefetch(devid, stream);
}

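// Apply a batch of fixed-size gradients with the given optimizer policy.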
template <typename KeyType, typename ValType>
template <typename Sgd, typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                         const float* d_grads,
                                         size_t len,
                                         Sgd sgd,
                                         StreamType stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  update_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_, *device_optimizer_config_, d_keys, d_grads, len, sgd);
}

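// Apply a batch of dymf gradients packed with push_grad_value_size_ bytes
// per key.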
template <typename KeyType, typename ValType>
template <typename Sgd, typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                         const char* d_grads,
                                         size_t len,
                                         Sgd sgd,
                                         StreamType stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  dy_mf_update_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_,
      *device_optimizer_config_,
      d_keys,
      d_grads,
      len,
      sgd,
      push_grad_value_size_);
}

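// Explicit instantiations for the key/value combinations used by HeterPS.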
template class HashTable<unsigned long, float>;
template class HashTable<unsigned long, float*>;
template class HashTable<long, int>;
template class HashTable<unsigned long, int>;
template class HashTable<unsigned long, unsigned long>;
template class HashTable<unsigned long, unsigned long*>;
template class HashTable<unsigned long, long>;
template class HashTable<unsigned long, long*>;
template class HashTable<long, long>;
template class HashTable<long, unsigned long>;
template class HashTable<long, unsigned int>;

template void HashTable<unsigned long, float>::get<cudaStream_t>(
    const unsigned long* d_keys,
    float* d_vals,
    size_t len,
    cudaStream_t stream);

template void
HashTable<unsigned long, float*>::get<cudaStream_t, CommonFeatureValueAccessor>(
    const unsigned long* d_keys,
    char* d_vals,
    size_t len,
    cudaStream_t stream,
    CommonFeatureValueAccessor& fv_accessor);

template void HashTable<long, int>::get<cudaStream_t>(const long* d_keys,
                                                      int* d_vals,
                                                      size_t len,
                                                      cudaStream_t stream);

template void HashTable<unsigned long, int>::get<cudaStream_t>(
    const unsigned long* d_keys, int* d_vals, size_t len, cudaStream_t stream);
template void HashTable<unsigned long, unsigned long>::get<cudaStream_t>(
    const unsigned long* d_keys,
    unsigned long* d_vals,
    size_t len,
    cudaStream_t stream);
template void HashTable<unsigned long, long>::get<cudaStream_t>(
    const unsigned long* d_keys, long* d_vals, size_t len, cudaStream_t stream);
template void HashTable<long, unsigned long>::get<cudaStream_t>(
    const long* d_keys, unsigned long* d_vals, size_t len, cudaStream_t stream);
template void HashTable<long, long>::get<cudaStream_t>(const long* d_keys,
                                                       long* d_vals,
                                                       size_t len,
                                                       cudaStream_t stream);
template void HashTable<long, unsigned int>::get<cudaStream_t>(
    const long* d_keys, unsigned int* d_vals, size_t len, cudaStream_t stream);
// template void
// HashTable<unsigned long, paddle::framework::FeatureValue>::get<cudaStream_t>(
//    const unsigned long* d_keys, char* d_vals, size_t len, cudaStream_t
//    stream);

template void HashTable<unsigned long, float>::insert<cudaStream_t>(
    const unsigned long* d_keys,
    const float* d_vals,
    size_t len,
    cudaStream_t stream);

template void HashTable<unsigned long, float*>::insert<cudaStream_t>(
    const unsigned long* d_keys,
    size_t len,
    char* pool,
    size_t feature_value_size,
    size_t start_index,
    cudaStream_t stream);

template void HashTable<long, int>::insert<cudaStream_t>(const long* d_keys,
                                                         const int* d_vals,
                                                         size_t len,
                                                         cudaStream_t stream);
template void HashTable<long, long>::insert<cudaStream_t>(const long* d_keys,
                                                          const long* d_vals,
                                                          size_t len,
                                                          cudaStream_t stream);

template void HashTable<unsigned long, int>::insert<cudaStream_t>(
    const unsigned long* d_keys,
    const int* d_vals,
    size_t len,
    cudaStream_t stream);

template void HashTable<unsigned long, long>::insert<cudaStream_t>(
    const unsigned long* d_keys,
    const long* d_vals,
    size_t len,
    cudaStream_t stream);

template void HashTable<long, unsigned long>::insert<cudaStream_t>(
    const long* d_keys,
    const unsigned long* d_vals,
    size_t len,
    cudaStream_t stream);

template void HashTable<long, unsigned int>::insert<cudaStream_t>(
    const long* d_keys,
    const unsigned int* d_vals,
    size_t len,
    cudaStream_t stream);

template void HashTable<unsigned long, unsigned long>::insert<cudaStream_t>(
    const unsigned long* d_keys,
    const unsigned long* d_vals,
    size_t len,
    cudaStream_t stream);

template void HashTable<unsigned long, float*>::dump_to_cpu<cudaStream_t>(
    int devid, cudaStream_t stream);

template void HashTable<unsigned long, float*>::update<
    SparseAdagradOptimizer<CommonFeatureValueAccessor>,
    cudaStream_t>(const unsigned long* d_keys,
                  const char* d_grads,
                  size_t len,
                  SparseAdagradOptimizer<CommonFeatureValueAccessor> sgd,
                  cudaStream_t stream);
template void HashTable<unsigned long, float*>::update<
    SparseAdamOptimizer<CommonFeatureValueAccessor>,
    cudaStream_t>(const unsigned long* d_keys,
                  const char* d_grads,
                  size_t len,
                  SparseAdamOptimizer<CommonFeatureValueAccessor> sgd,
                  cudaStream_t stream);
template void HashTable<unsigned long, float*>::update<
    SparseAdamSharedOptimizer<CommonFeatureValueAccessor>,
    cudaStream_t>(const unsigned long* d_keys,
                  const char* d_grads,
                  size_t len,
                  SparseAdamSharedOptimizer<CommonFeatureValueAccessor> sgd,
                  cudaStream_t stream);

// template void HashTable<unsigned long,
// paddle::framework::FeatureValue>::update<
//    Optimizer<paddle::framework::FeatureValue,
//              paddle::framework::FeaturePushValue>,
//    cudaStream_t>(const unsigned long* d_keys, const char* d_grads, size_t
//    len,
//                  Optimizer<paddle::framework::FeatureValue,
//                            paddle::framework::FeaturePushValue>
//                      sgd,
//                  cudaStream_t stream);

#endif
}  // end namespace framework
}  // end namespace paddle
#endif