/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_HETERPS
#include <limits>
#include <thread>
#include <vector>

#include "paddle/fluid/framework/fleet/heter_ps/hashtable.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"

namespace paddle {
namespace framework {

#if defined(PADDLE_WITH_CUDA)

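// Insertion functor for the GPU hash table: when a key already exists,
// the new value simply overwrites the old one.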
template <typename value_type>
struct ReplaceOp {
  __host__ __device__ value_type operator()(value_type new_value,
                                            value_type old_value) {
    return new_value;
  }
};

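// One thread per key: insert the (key, value) pair, overwriting any
// existing entry; the assert fires if the table has run out of space.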
template <typename Table>
__global__ void insert_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              const typename Table::mapped_type* const vals,
                              size_t len) {
  ReplaceOp<typename Table::mapped_type> op;
  thrust::pair<typename Table::key_type, typename Table::mapped_type> kv;

  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    kv.first = keys[i];
    kv.second = vals[i];
    auto it = table->insert(kv, op);
    assert(it != table->end() && "error: insert failed: table is full");
  }
}

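// Pool-backed variant for variable-length values: the table stores a
// pointer into a pre-allocated buffer, slot i living at
// pool + (start_index + i) * feature_value_size.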
template <typename Table>
__global__ void insert_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              size_t len,
                              char* pool,
                              size_t feature_value_size,
                              int start_index) {
  ReplaceOp<typename Table::mapped_type> op;
  thrust::pair<typename Table::key_type, typename Table::mapped_type> kv;

  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;

  if (i < len) {
    kv.first = keys[i];
    uint64_t offset = uint64_t(start_index + i) * feature_value_size;
    kv.second = (typename Table::mapped_type)(pool + offset);
    auto it = table->insert(kv, op);
    assert(it != table->end() && "error: insert failed: table is full");
  }
}

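// One thread per key: copy the mapped value of every key that is found;
// output slots for missing keys are left untouched.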
template <typename Table>
__global__ void search_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              typename Table::mapped_type* const vals,
                              size_t len) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      vals[i] = it->second;
    }
  }
}

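// Variable-length lookup: each found value is unpacked into the byte
// buffer vals through the accessor, using the mf_dim stored in the value
// itself to decide how much to copy.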
template <typename Table, typename FVAccessor>
__global__ void dy_mf_search_kernel(Table* table,
                                    const typename Table::key_type* const keys,
                                    char* vals,
                                    size_t len,
                                    size_t pull_feature_value_size,
                                    FVAccessor feature_value_accessor) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);

    if (it != table->end()) {
      uint64_t offset = i * pull_feature_value_size;
      float* cur = (float*)(vals + offset);
      float* input = it->second;
      int mf_dim =
          int(input[feature_value_accessor.common_feature_value.MfDimIndex()]);

      feature_value_accessor.FeatureValueFill(cur, input, mf_dim);
    }
  }
}

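// One thread per key: apply an optimizer step to every key that is found,
// reading hyper-parameters from the device-resident OptimizerConfig.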
template <typename Table, typename GradType, typename Sgd>
__global__ void update_kernel(Table* table,
                              const OptimizerConfig& optimizer_config,
                              const typename Table::key_type* const keys,
                              const GradType* const grads,
                              size_t len,
                              Sgd sgd) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      sgd.update_value(optimizer_config, (it.getter())->second, grads[i]);
    }
  }
}

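// Variable-length update: gradients are packed into a byte buffer with a
// fixed stride of grad_value_size; a push for a missing non-zero key is
// reported via printf.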
template <typename Table, typename Sgd>
__global__ void dy_mf_update_kernel(Table* table,
                                    const OptimizerConfig& optimizer_config,
                                    const typename Table::key_type* const keys,
                                    const char* const grads,
                                    size_t len,
                                    Sgd sgd,
                                    size_t grad_value_size) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      float* cur = (float*)(grads + i * grad_value_size);
      sgd.dy_mf_update_value(optimizer_config, (it.getter())->second, cur);
    } else {
      if (keys[i] != 0) {
        printf("warning: push miss key: %llu\n",
               (unsigned long long)keys[i]);
      }
    }
  }
}

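// The optimizer configuration is mirrored into device memory so that the
// update kernels can read it without a host round trip.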
template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::HashTable(size_t capacity) {
  container_ = new TableContainer<KeyType, ValType>(capacity);
  cudaMalloc((void**)&device_optimizer_config_, sizeof(OptimizerConfig));
  cudaMemcpy((void*)device_optimizer_config_,
             &host_optimizer_config_,
             sizeof(OptimizerConfig),
             cudaMemcpyHostToDevice);
  rwlock_.reset(new phi::RWLock);
}

template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::~HashTable() {
  delete container_;
  cudaFree(device_optimizer_config_);
}

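// Each setter updates the host-side copy of the config, then refreshes the
// device mirror.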
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::set_sparse_sgd(
    const OptimizerConfig& optimizer_config) {
  host_optimizer_config_.set_sparse_sgd(optimizer_config);
  cudaMemcpy((void*)device_optimizer_config_,
             &host_optimizer_config_,
             sizeof(OptimizerConfig),
             cudaMemcpyHostToDevice);
}

template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::set_embedx_sgd(
    const OptimizerConfig& optimizer_config) {
  host_optimizer_config_.set_embedx_sgd(optimizer_config);
  cudaMemcpy((void*)device_optimizer_config_,
             &host_optimizer_config_,
             sizeof(OptimizerConfig),
             cudaMemcpyHostToDevice);
}

template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::show() {
  container_->print();
}

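// Batched lookup: fills d_vals[i] with the value stored for d_keys[i].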
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys,
                                      ValType* d_vals,
                                      size_t len,
                                      StreamType stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  search_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_, d_keys, d_vals, len);
}

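// Batched variable-length lookup into a raw byte buffer laid out in
// pull_feature_value_size-byte slots.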
template <typename KeyType, typename ValType>
template <typename StreamType, typename FVAccessor>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys,
                                      char* d_vals,
                                      size_t len,
                                      StreamType stream,
                                      FVAccessor& fv_accessor) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  dy_mf_search_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_, d_keys, d_vals, len, pull_feature_value_size_, fv_accessor);
}

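// Batched insert of fixed-size values.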
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
                                         const ValType* d_vals,
                                         size_t len,
                                         StreamType stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_, d_keys, d_vals, len);
}

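// Batched insert that binds each key to a slot of the pre-allocated value
// pool, starting at start_index.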
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
                                         size_t len,
                                         char* pool,
                                         size_t feature_value_size,
                                         size_t start_index,
                                         StreamType stream) {
  if (len == 0) {
    return;
  }
  if (pool == nullptr) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_, d_keys, len, pool, feature_value_size, start_index);
}

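// Prefetch the managed-memory table to the host and copy live GPU values
// back into their CPU-side DownpourFixedFeatureValue objects (PSLIB builds
// only), sharded across 8 host threads.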
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::dump_to_cpu(int devid, StreamType stream) {
  container_->prefetch(cudaCpuDeviceId, stream);
  std::vector<std::thread> threads;
  size_t num = container_->size();
  KeyType unuse_key = std::numeric_limits<KeyType>::max();
  thrust::pair<KeyType, ValType>* kv = container_->data();

  int thread_num = 8;
  int len_per_thread = num / thread_num;
  int remain = num % thread_num;
  int begin = 0;

  auto dump_func = [unuse_key, kv](int left, int right) {
    for (int i = left; i < right; i++) {
      if (kv[i].first == unuse_key) {
        continue;
      }
      ValType& gpu_val = kv[i].second;
#ifdef PADDLE_WITH_PSLIB
      auto* downpour_value =
          (paddle::ps::DownpourFixedFeatureValue*)(gpu_val.cpu_ptr);
      int downpour_value_size = downpour_value->size();
      if (gpu_val.mf_size > 0 && downpour_value_size == 7) {
        downpour_value->resize(gpu_val.mf_size + downpour_value_size);
      }
      float* cpu_val = downpour_value->data();
      // cpu_val[0] = 0;
      cpu_val[1] = gpu_val.delta_score;
      cpu_val[2] = gpu_val.show;
      cpu_val[3] = gpu_val.clk;
      cpu_val[4] = gpu_val.lr;
      cpu_val[5] = gpu_val.lr_g2sum;
      cpu_val[6] = gpu_val.slot;
      if (gpu_val.mf_size > 0) {
        for (int x = 0; x < gpu_val.mf_size; x++) {
          cpu_val[x + 7] = gpu_val.mf[x];
        }
      }
#endif
    }
  };

  for (int i = 0; i < thread_num; i++) {
    threads.push_back(std::thread(
        dump_func, begin, begin + len_per_thread + (i < remain ? 1 : 0)));
    begin += len_per_thread + (i < remain ? 1 : 0);
  }
  for (std::thread& t : threads) {
    t.join();
  }

  // container_->prefetch(devid, stream);
}

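// Batched optimizer update with fixed-size gradients.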
template <typename KeyType, typename ValType>
template <typename Sgd, typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                         const float* d_grads,
                                         size_t len,
                                         Sgd sgd,
                                         StreamType stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  update_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_, *device_optimizer_config_, d_keys, d_grads, len, sgd);
}

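// Batched optimizer update with variable-length gradients packed into a
// byte buffer.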
template <typename KeyType, typename ValType>
template <typename Sgd, typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                         const char* d_grads,
                                         size_t len,
                                         Sgd sgd,
                                         StreamType stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  dy_mf_update_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_,
      *device_optimizer_config_,
      d_keys,
      d_grads,
      len,
      sgd,
      push_grad_value_size_);
}

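// Explicit instantiations for the key/value/optimizer combinations used by
// HeterPS.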
template class HashTable<unsigned long, float>;
template class HashTable<unsigned long, float*>;
template class HashTable<long, int>;
template class HashTable<unsigned long, int>;
template class HashTable<unsigned long, unsigned long>;
template class HashTable<unsigned long, long>;
template class HashTable<unsigned long, long*>;
template class HashTable<long, long>;
template class HashTable<long, unsigned long>;
template class HashTable<long, unsigned int>;

template void HashTable<unsigned long, float>::get<cudaStream_t>(
    const unsigned long* d_keys,
    float* d_vals,
    size_t len,
    cudaStream_t stream);

template void
HashTable<unsigned long, float*>::get<cudaStream_t, CommonFeatureValueAccessor>(
    const unsigned long* d_keys,
    char* d_vals,
    size_t len,
    cudaStream_t stream,
    CommonFeatureValueAccessor& fv_accessor);

template void HashTable<long, int>::get<cudaStream_t>(const long* d_keys,
                                                      int* d_vals,
                                                      size_t len,
                                                      cudaStream_t stream);

template void HashTable<unsigned long, int>::get<cudaStream_t>(
    const unsigned long* d_keys, int* d_vals, size_t len, cudaStream_t stream);
template void HashTable<unsigned long, unsigned long>::get<cudaStream_t>(
    const unsigned long* d_keys,
    unsigned long* d_vals,
    size_t len,
    cudaStream_t stream);

template void HashTable<long, unsigned long>::get<cudaStream_t>(
    const long* d_keys, unsigned long* d_vals, size_t len, cudaStream_t stream);
template void HashTable<long, long>::get<cudaStream_t>(const long* d_keys,
                                                       long* d_vals,
                                                       size_t len,
                                                       cudaStream_t stream);
template void HashTable<long, unsigned int>::get<cudaStream_t>(
    const long* d_keys, unsigned int* d_vals, size_t len, cudaStream_t stream);
template void HashTable<unsigned long, long>::get<cudaStream_t>(
    const unsigned long* d_keys, long* d_vals, size_t len, cudaStream_t stream);
// template void
// HashTable<unsigned long, paddle::framework::FeatureValue>::get<cudaStream_t>(
//    const unsigned long* d_keys, char* d_vals, size_t len, cudaStream_t
//    stream);

template void HashTable<unsigned long, float>::insert<cudaStream_t>(
    const unsigned long* d_keys,
    const float* d_vals,
    size_t len,
    cudaStream_t stream);

template void HashTable<unsigned long, float*>::insert<cudaStream_t>(
    const unsigned long* d_keys,
    size_t len,
    char* pool,
    size_t feature_value_size,
    size_t start_index,
    cudaStream_t stream);

template void HashTable<long, int>::insert<cudaStream_t>(const long* d_keys,
                                                         const int* d_vals,
                                                         size_t len,
                                                         cudaStream_t stream);
template void HashTable<long, long>::insert<cudaStream_t>(const long* d_keys,
                                                          const long* d_vals,
                                                          size_t len,
                                                          cudaStream_t stream);

template void HashTable<unsigned long, int>::insert<cudaStream_t>(
    const unsigned long* d_keys,
    const int* d_vals,
    size_t len,
    cudaStream_t stream);
template void HashTable<long, unsigned long>::insert<cudaStream_t>(
    const long* d_keys,
    const unsigned long* d_vals,
    size_t len,
    cudaStream_t stream);

template void HashTable<long, unsigned int>::insert<cudaStream_t>(
    const long* d_keys,
    const unsigned int* d_vals,
    size_t len,
    cudaStream_t stream);

template void HashTable<unsigned long, long>::insert<cudaStream_t>(
    const unsigned long* d_keys,
    const long* d_vals,
    size_t len,
    cudaStream_t stream);

template void HashTable<unsigned long, unsigned long>::insert<cudaStream_t>(
    const unsigned long* d_keys,
    const unsigned long* d_vals,
    size_t len,
    cudaStream_t stream);

template void HashTable<unsigned long, float*>::dump_to_cpu<cudaStream_t>(
    int devid, cudaStream_t stream);

template void
HashTable<unsigned long, float*>::update<SparseAdagradOptimizer, cudaStream_t>(
    const unsigned long* d_keys,
    const char* d_grads,
    size_t len,
    SparseAdagradOptimizer sgd,
    cudaStream_t stream);
template void
HashTable<unsigned long, float*>::update<SparseAdamOptimizer, cudaStream_t>(
    const unsigned long* d_keys,
    const char* d_grads,
    size_t len,
    SparseAdamOptimizer sgd,
    cudaStream_t stream);
template void HashTable<unsigned long, float*>::update<
    SparseAdamSharedOptimizer,
    cudaStream_t>(const unsigned long* d_keys,
                  const char* d_grads,
                  size_t len,
                  SparseAdamSharedOptimizer sgd,
                  cudaStream_t stream);

// template void HashTable<unsigned long,
// paddle::framework::FeatureValue>::update<
//    Optimizer<paddle::framework::FeatureValue,
//              paddle::framework::FeaturePushValue>,
//    cudaStream_t>(const unsigned long* d_keys, const char* d_grads, size_t
//    len,
//                  Optimizer<paddle::framework::FeatureValue,
//                            paddle::framework::FeaturePushValue>
//                      sgd,
//                  cudaStream_t stream);

#endif
}  // end namespace framework
}  // end namespace paddle
#endif