/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_HETERPS
#include <thread>
#include "paddle/fluid/framework/fleet/heter_ps/hashtable.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"

namespace paddle {
namespace framework {

#if defined(PADDLE_WITH_CUDA)

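// Insert-time merge functor: on a key collision the stored value is simply
// overwritten by the incoming one ("last write wins").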
template <typename value_type>
struct ReplaceOp {
  __host__ __device__ value_type operator()(value_type new_value,
                                            value_type old_value) {
    return new_value;
  }
};

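// One thread per key: each thread builds a (key, value) pair and inserts it
// into the concurrent map; the assert fires if the table has run out of
// slots.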
template <typename Table>
__global__ void insert_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              const typename Table::mapped_type* const vals,
                              size_t len) {
  ReplaceOp<typename Table::mapped_type> op;
  thrust::pair<typename Table::key_type, typename Table::mapped_type> kv;

  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    kv.first = keys[i];
    kv.second = vals[i];
    auto it = table->insert(kv, op);
    assert(it != table->end() && "error: insert fails: table is full");
  }
}

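// Variant for dynamically sized feature values: the values live in a
// pre-allocated slab ("pool") of fixed-stride entries, and the table stores a
// pointer to entry (start_index + i) instead of the value itself.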
template <typename Table>
__global__ void insert_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              size_t len, char* pool, size_t feature_value_size,
                              int start_index) {
  ReplaceOp<typename Table::mapped_type> op;
  thrust::pair<typename Table::key_type, typename Table::mapped_type> kv;

  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;

  if (i < len) {
    kv.first = keys[i];
    uint64_t offset = uint64_t(start_index + i) * feature_value_size;
    kv.second = (typename Table::mapped_type)(pool + offset);
    auto it = table->insert(kv, op);
    assert(it != table->end() && "error: insert fails: table is full");
  }
}

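// Plain lookup: for every key that is present the stored value is copied to
// vals[i]; missing keys leave the output slot untouched.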
template <typename Table>
__global__ void search_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              typename Table::mapped_type* const vals,
                              size_t len) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      vals[i] = it->second;
    }
  }
}

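// Dynamic-mf lookup: here the table maps each key to a FeatureValue*, so a
// hit is deep-copied field by field into the caller's flat output buffer
// (stride pull_feature_value_size); a miss writes a default-initialized
// value and warns on non-zero keys.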
template <typename Table>
__global__ void dy_mf_search_kernel(Table* table,
                                    const typename Table::key_type* const keys,
                                    char* vals, size_t len,
                                    size_t pull_feature_value_size) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);

    if (it != table->end()) {
      uint64_t offset = i * pull_feature_value_size;
      FeatureValue* cur = (FeatureValue*)(vals + offset);
      FeatureValue& input = *(FeatureValue*)(it->second);
      cur->slot = input.slot;
      cur->show = input.show;
      cur->clk = input.clk;
      cur->mf_dim = input.mf_dim;
      cur->lr = input.lr;
      cur->mf_size = input.mf_size;
      cur->cpu_ptr = input.cpu_ptr;
      cur->delta_score = input.delta_score;
      cur->lr_g2sum = input.lr_g2sum;
      for (int j = 0; j < cur->mf_dim + 1; ++j) {
        cur->mf[j] = input.mf[j];
      }
    } else {
      if (keys[i] != 0) {
        printf("warning::pull miss key: %d", keys[i]);
      }
      FeatureValue* cur = (FeatureValue*)(vals + i * pull_feature_value_size);
      cur->delta_score = 0;
      cur->show = 0;
      cur->clk = 0;
      cur->slot = -1;
      cur->lr = 0;
      cur->lr_g2sum = 0;
      cur->mf_size = 0;
      cur->mf_dim = 8;
      cur->cpu_ptr = 0;
      for (int j = 0; j < cur->mf_dim + 1; j++) {
        cur->mf[j] = 0;
      }
    }
  }
}

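// Applies one gradient per key through the Sgd policy object; keys that are
// no longer present in the table are silently skipped.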
template <typename Table, typename GradType, typename Sgd>
__global__ void update_kernel(Table* table,
                              const OptimizerConfig& optimizer_config,
                              const typename Table::key_type* const keys,
                              const GradType* const grads, size_t len,
                              Sgd sgd) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      sgd.update_value(optimizer_config, (it.getter())->second, grads[i]);
    }
  }
}

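// Dynamic-mf update: gradients arrive as a flat byte buffer with
// grad_value_size bytes per entry, so each thread casts its slice to a
// FeaturePushValue before handing it to the optimizer.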
template <typename Table, typename Sgd>
__global__ void dy_mf_update_kernel(Table* table,
                                    const OptimizerConfig& optimizer_config,
                                    const typename Table::key_type* const keys,
                                    const char* const grads, size_t len,
                                    Sgd sgd, size_t grad_value_size) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      FeaturePushValue* cur = (FeaturePushValue*)(grads + i * grad_value_size);
      sgd.dy_mf_update_value(optimizer_config, (it.getter())->second, *cur);
    } else {
      if (keys[i] != 0) {
        printf("warning::push miss key: %d", keys[i]);
      }
    }
  }
}

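// The optimizer configuration is mirrored into device memory so kernels can
// read it directly; set_sparse_sgd()/set_embedx_sgd() update the host copy
// and re-upload it.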
template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::HashTable(size_t capacity) {
  container_ = new TableContainer<KeyType, ValType>(capacity);
  cudaMalloc((void**)&device_optimizer_config_, sizeof(OptimizerConfig));
  cudaMemcpy((void*)device_optimizer_config_, &host_optimizer_config_,
             sizeof(OptimizerConfig), cudaMemcpyHostToDevice);
  rwlock_.reset(new phi::RWLock);
}

template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::~HashTable() {
  delete container_;
  // Release the device-side OptimizerConfig allocated in the constructor.
  cudaFree(device_optimizer_config_);
}

template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::set_sparse_sgd(
    const OptimizerConfig& optimizer_config) {
  host_optimizer_config_.set_sparse_sgd(optimizer_config);
  cudaMemcpy((void*)device_optimizer_config_, &host_optimizer_config_,
             sizeof(OptimizerConfig), cudaMemcpyHostToDevice);
}

template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::set_embedx_sgd(
    const OptimizerConfig& optimizer_config) {
  host_optimizer_config_.set_embedx_sgd(optimizer_config);
  cudaMemcpy((void*)device_optimizer_config_, &host_optimizer_config_,
             sizeof(OptimizerConfig), cudaMemcpyHostToDevice);
}

template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::show() {
  container_->print();
}

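// Host-side wrappers: each launches ceil(len / BLOCK_SIZE_) blocks on the
// caller's stream and returns immediately; the caller owns synchronization.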
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys, ValType* d_vals,
                                      size_t len, StreamType stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  search_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(container_, d_keys,
                                                       d_vals, len);
}

template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys, char* d_vals,
                                      size_t len, StreamType stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  dy_mf_search_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_, d_keys, d_vals, len, pull_feature_value_size_);
}

template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
                                         const ValType* d_vals, size_t len,
                                         StreamType stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(container_, d_keys,
                                                       d_vals, len);
}

template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys, size_t len,
                                         char* pool, size_t feature_value_size,
                                         size_t start_index,
                                         StreamType stream) {
  if (len == 0) {
    return;
  }
  if (pool == NULL) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_, d_keys, len, pool, feature_value_size, start_index);
}

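// Copies every live entry back to its CPU-side counterpart via the cpu_ptr
// stored in each GPU value. The container's storage is prefetched to the
// host first (the prefetch implies it lives in CUDA managed memory), then 8
// worker threads walk disjoint slices of the slot array; slots whose key
// equals numeric_limits<KeyType>::max() are empty and skipped.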
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::dump_to_cpu(int devid, StreamType stream) {
  container_->prefetch(cudaCpuDeviceId, stream);
  std::vector<std::thread> threads;
  size_t num = container_->size();
  KeyType unuse_key = std::numeric_limits<KeyType>::max();
  thrust::pair<KeyType, ValType>* kv = container_->data();

  int thread_num = 8;
  int len_per_thread = num / thread_num;
  int remain = num % thread_num;
  int begin = 0;

  auto dump_func = [unuse_key, kv](int left, int right) {
    for (int i = left; i < right; i++) {
      if (kv[i].first == unuse_key) {
        continue;
      }
      ValType& gpu_val = kv[i].second;
#ifdef PADDLE_WITH_PSLIB
      auto* downpour_value =
          (paddle::ps::DownpourFixedFeatureValue*)(gpu_val.cpu_ptr);
      int downpour_value_size = downpour_value->size();
      if (gpu_val.mf_size > 0 && downpour_value_size == 7) {
        downpour_value->resize(gpu_val.mf_size + downpour_value_size);
      }
      float* cpu_val = downpour_value->data();
      // cpu_val[0] = 0;
      cpu_val[1] = gpu_val.delta_score;
      cpu_val[2] = gpu_val.show;
      cpu_val[3] = gpu_val.clk;
      cpu_val[4] = gpu_val.lr;
      cpu_val[5] = gpu_val.lr_g2sum;
      cpu_val[6] = gpu_val.slot;
      if (gpu_val.mf_size > 0) {
        for (int x = 0; x < gpu_val.mf_size; x++) {
          cpu_val[x + 7] = gpu_val.mf[x];
        }
      }
#endif
#ifdef PADDLE_WITH_PSCORE
      auto* downpour_value =
          (paddle::distributed::FixedFeatureValue*)(gpu_val.cpu_ptr);
      int downpour_value_size = downpour_value->size();
      if (gpu_val.mf_size > 0 && downpour_value_size == 7) {
        downpour_value->resize(gpu_val.mf_size + downpour_value_size);
      }
      float* cpu_val = downpour_value->data();
      cpu_val[2] = gpu_val.delta_score;
      cpu_val[3] = gpu_val.show;
      cpu_val[4] = gpu_val.clk;
      cpu_val[5] = gpu_val.lr;
      cpu_val[6] = gpu_val.lr_g2sum;
      cpu_val[0] = gpu_val.slot;
      if (gpu_val.mf_size > 0) {
        for (int x = 0; x < gpu_val.mf_size; x++) {
          cpu_val[x + 7] = gpu_val.mf[x];
        }
      }
#endif
    }
  };

  for (int i = 0; i < thread_num; i++) {
    threads.push_back(std::thread(
        dump_func, begin, begin + len_per_thread + (i < remain ? 1 : 0)));
    begin += len_per_thread + (i < remain ? 1 : 0);
  }
  for (std::thread& t : threads) {
    t.join();
  }

  // container_->prefetch(devid, stream);
}

template <typename KeyType, typename ValType>
template <typename GradType, typename Sgd, typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                         const GradType* d_grads, size_t len,
                                         Sgd sgd, StreamType stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  update_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_, *device_optimizer_config_, d_keys, d_grads, len, sgd);
}

template <typename KeyType, typename ValType>
template <typename Sgd, typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                         const char* d_grads, size_t len,
                                         Sgd sgd, StreamType stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  dy_mf_update_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_, *device_optimizer_config_, d_keys, d_grads, len, sgd,
      push_grad_value_size_);
}

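// Typical single-GPU call sequence (a sketch only; buffer and stream names
// are illustrative, not part of the API):
//
//   HashTable<unsigned long, paddle::framework::FeatureValue> table(cap);
//   table.insert(d_keys, d_vals, len, stream);        // build
//   table.get(d_keys, d_out, len, stream);            // pull values
//   table.update(d_keys, d_grads, len, sgd, stream);  // push gradients
//
// The explicit instantiations below pin down every supported
// (KeyType, ValType) combination, since the member templates are defined in
// this .cu file rather than in the header.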
template class HashTable<unsigned long, paddle::framework::FeatureValue>;
template class HashTable<unsigned long, paddle::framework::FeatureValue*>;
template class HashTable<long, int>;
template class HashTable<unsigned long, int>;
template class HashTable<unsigned long, unsigned long>;
template class HashTable<unsigned long, long>;
template class HashTable<unsigned long, long*>;
template class HashTable<long, long>;
template class HashTable<long, unsigned long>;
template class HashTable<long, unsigned int>;

template void HashTable<unsigned long, paddle::framework::FeatureValue>::get<
    cudaStream_t>(const unsigned long* d_keys,
                  paddle::framework::FeatureValue* d_vals, size_t len,
                  cudaStream_t stream);

template void
HashTable<unsigned long, paddle::framework::FeatureValue*>::get<cudaStream_t>(
    const unsigned long* d_keys, char* d_vals, size_t len, cudaStream_t stream);

template void HashTable<long, int>::get<cudaStream_t>(const long* d_keys,
                                                      int* d_vals, size_t len,
                                                      cudaStream_t stream);

template void HashTable<unsigned long, int>::get<cudaStream_t>(
    const unsigned long* d_keys, int* d_vals, size_t len, cudaStream_t stream);
template void HashTable<long, unsigned long>::get<cudaStream_t>(
    const long* d_keys, unsigned long* d_vals, size_t len, cudaStream_t stream);
template void HashTable<long, long>::get<cudaStream_t>(const long* d_keys,
                                                       long* d_vals, size_t len,
                                                       cudaStream_t stream);
template void HashTable<long, unsigned int>::get<cudaStream_t>(
    const long* d_keys, unsigned int* d_vals, size_t len, cudaStream_t stream);
template void HashTable<unsigned long, long>::get<cudaStream_t>(
    const unsigned long* d_keys, long* d_vals, size_t len, cudaStream_t stream);
// template void
// HashTable<unsigned long, paddle::framework::FeatureValue>::get<cudaStream_t>(
//    const unsigned long* d_keys, char* d_vals, size_t len, cudaStream_t
//    stream);

template void HashTable<unsigned long, paddle::framework::FeatureValue>::insert<
    cudaStream_t>(const unsigned long* d_keys,
                  const paddle::framework::FeatureValue* d_vals, size_t len,
                  cudaStream_t stream);

template void HashTable<unsigned long, paddle::framework::FeatureValue*>::
    insert<cudaStream_t>(const unsigned long* d_keys, size_t len, char* pool,
                         size_t feature_value_size, size_t start_index,
                         cudaStream_t stream);

template void HashTable<long, int>::insert<cudaStream_t>(const long* d_keys,
                                                         const int* d_vals,
                                                         size_t len,
                                                         cudaStream_t stream);
template void HashTable<long, long>::insert<cudaStream_t>(const long* d_keys,
                                                          const long* d_vals,
                                                          size_t len,
                                                          cudaStream_t stream);

template void HashTable<unsigned long, int>::insert<cudaStream_t>(
    const unsigned long* d_keys, const int* d_vals, size_t len,
    cudaStream_t stream);
template void HashTable<long, unsigned long>::insert<cudaStream_t>(
    const long* d_keys, const unsigned long* d_vals, size_t len,
    cudaStream_t stream);

template void HashTable<long, unsigned int>::insert<cudaStream_t>(
    const long* d_keys, const unsigned int* d_vals, size_t len,
    cudaStream_t stream);

template void HashTable<unsigned long, long>::insert<cudaStream_t>(
    const unsigned long* d_keys, const long* d_vals, size_t len,
    cudaStream_t stream);

template void HashTable<unsigned long, paddle::framework::FeatureValue>::
    dump_to_cpu<cudaStream_t>(int devid, cudaStream_t stream);

template void HashTable<unsigned long, paddle::framework::FeatureValue>::update<
    paddle::framework::FeaturePushValue,
    Optimizer<paddle::framework::FeatureValue,
              paddle::framework::FeaturePushValue>,
    cudaStream_t>(const unsigned long* d_keys,
                  const paddle::framework::FeaturePushValue* d_grads,
                  size_t len, Optimizer<paddle::framework::FeatureValue,
                                        paddle::framework::FeaturePushValue>
                                  sgd,
                  cudaStream_t stream);

template void
HashTable<unsigned long, paddle::framework::FeatureValue*>::update<
    Optimizer<paddle::framework::FeatureValue,
              paddle::framework::FeaturePushValue>,
    cudaStream_t>(const unsigned long* d_keys, const char* d_grads, size_t len,
                  Optimizer<paddle::framework::FeatureValue,
                            paddle::framework::FeaturePushValue>
                      sgd,
                  cudaStream_t stream);

// template void HashTable<unsigned long,
// paddle::framework::FeatureValue>::update<
//    Optimizer<paddle::framework::FeatureValue,
//              paddle::framework::FeaturePushValue>,
//    cudaStream_t>(const unsigned long* d_keys, const char* d_grads, size_t
//    len,
//                  Optimizer<paddle::framework::FeatureValue,
//                            paddle::framework::FeaturePushValue>
//                      sgd,
//                  cudaStream_t stream);

#endif
}  // end namespace framework
}  // end namespace paddle
#endif