hashtable_kernel.cu 15.6 KB
Newer Older
1
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

T
Thunderbrook 已提交
15
#ifdef PADDLE_WITH_HETERPS
16 17 18
#include <thread>
#include "paddle/fluid/framework/fleet/heter_ps/hashtable.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
T
Thunderbrook 已提交
19 20 21 22

namespace paddle {
namespace framework {

23 24
#if defined(PADDLE_WITH_CUDA)

T
Thunderbrook 已提交
25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49
// Binary functor passed as the merge op to the table's insert() by the
// insert kernels in this file. It always yields the incoming value and
// discards the other one. NOTE(review): presumably the container invokes it
// when the key already exists, so stored values are overwritten — confirm
// against TableContainer::insert.
template <typename value_type>
struct ReplaceOp {
  // Returns `new_value` unchanged; the second argument is deliberately unused.
  __host__ __device__ value_type operator()(value_type new_value,
                                            value_type /*old_value*/) {
    return new_value;
  }
};

// Inserts the pair (keys[i], vals[i]) for every i in [0, len), one pair per
// thread of a 1-D launch. Threads whose flat index is >= len do nothing.
// ReplaceOp is passed as the merge op. In builds with assertions enabled, a
// failed insert (insert() returning end()) triggers the device assert.
template <typename Table>
__global__ void insert_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              const typename Table::mapped_type* const vals,
                              size_t len) {
  using KeyT = typename Table::key_type;
  using ValT = typename Table::mapped_type;

  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= len) {
    return;
  }

  ReplaceOp<ValT> replace;
  thrust::pair<KeyT, ValT> entry;
  entry.first = keys[idx];
  entry.second = vals[idx];
  auto pos = table->insert(entry, replace);
  assert(pos != table->end() && "error: insert fails: table is full");
}

50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
// Inserts keys[i] for every i in [0, len), one key per thread of a 1-D launch.
// Each key is mapped to an address inside the caller-provided byte pool:
//   value(keys[i]) = pool + (start_index + i) * kPoolStride
// cast to the table's mapped_type. This overload is therefore only meaningful
// when mapped_type is a pointer (or pointer-sized) type.
//
// `typename` is required on the dependent mapped_type inside the cast.
// Without it, the cast is ill-formed as soon as this template is
// instantiated.
// start_index is size_t so the host launcher's size_t offset is not narrowed.
template <typename Table>
__global__ void insert_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              size_t len, char* pool, size_t start_index) {
  // Byte distance between consecutive pool slots. The value 80 is kept from
  // the original code. NOTE(review): presumably the size of one value record
  // in the pool — confirm it matches whoever allocates `pool`.
  constexpr size_t kPoolStride = 80;

  ReplaceOp<typename Table::mapped_type> op;
  thrust::pair<typename Table::key_type, typename Table::mapped_type> kv;

  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;

  if (i < len) {
    kv.first = keys[i];
    kv.second = (typename Table::mapped_type)(pool + (start_index + i) *
                                                         kPoolStride);
    auto it = table->insert(kv, op);
    assert(it != table->end() && "error: insert fails: table is full");
  }
}

T
Thunderbrook 已提交
67 68 69 70 71 72 73 74 75 76 77 78 79 80
// Looks up keys[i] for every i in [0, len), one key per thread of a 1-D
// launch. When the key is found, its value is written to vals[i]. On a miss,
// vals[i] is left untouched, so callers that need a default must pre-fill
// `vals`.
template <typename Table>
__global__ void search_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              typename Table::mapped_type* const vals,
                              size_t len) {
  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= len) {
    return;
  }
  auto found = table->find(keys[tid]);
  if (found != table->end()) {
    vals[tid] = found->second;
  }
}

81 82 83 84 85 86 87 88 89 90 91 92 93 94
// Variable-stride lookup. `vals` is a byte buffer of `len` records, each
// pull_feature_value_size bytes long.
// For each key found, the FeatureValue that the table entry points to
// (it->second is dereferenced) is copied to the start of record i. Only
// sizeof(FeatureValue) bytes are written per record; misses are skipped.
template <typename Table>
__global__ void dy_mf_search_kernel(Table* table,
                                    const typename Table::key_type* const keys,
                                    char* const vals, size_t len,
                                    size_t pull_feature_value_size) {
  const size_t row = blockIdx.x * blockDim.x + threadIdx.x;
  if (row >= len) {
    return;
  }
  auto hit = table->find(keys[row]);
  if (hit != table->end()) {
    char* record = vals + row * pull_feature_value_size;
    FeatureValue* out = reinterpret_cast<FeatureValue*>(record);
    *out = *(hit->second);
  }
}
95

T
Thunderbrook 已提交
96 97
// Applies one optimizer step per key, one key per thread of a 1-D launch.
// For every i in [0, len) whose key is present, it calls
//   sgd.update_value(optimizer_config, <stored value>, grads[i]).
// Keys not found in the table are silently skipped.
//
// NOTE(review): `optimizer_config` is a reference parameter. The host
// launcher binds it to `*device_optimizer_config_`, a cudaMalloc'd object.
// This relies on the reference reaching the kernel as a device address —
// confirm before changing how the config is stored.
template <typename Table, typename GradType, typename Sgd>
__global__ void update_kernel(Table* table,
                              const OptimizerConfig& optimizer_config,
                              const typename Table::key_type* const keys,
                              const GradType* const grads, size_t len,
                              Sgd sgd) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      sgd.update_value(optimizer_config, (it.getter())->second, grads[i]);
    }
  }
}

111 112
// Variable-stride optimizer step, one key per thread of a 1-D launch.
// `grads` is a byte buffer of `len` records, each grad_value_size bytes long.
// The start of record i is read as a FeaturePushValue and passed to
// sgd.dy_mf_update_value together with the stored value for keys[i].
// Keys missing from the table are reported via device printf.
template <typename Table, typename Sgd>
__global__ void dy_mf_update_kernel(Table* table,
                                    const OptimizerConfig& optimizer_config,
                                    const typename Table::key_type* const keys,
                                    const char* const grads, size_t len,
                                    Sgd sgd, size_t grad_value_size) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      FeaturePushValue* cur = (FeaturePushValue*)(grads + i * grad_value_size);
      sgd.dy_mf_update_value(optimizer_config, (it.getter())->second, *cur);
    } else {
      // Keys are 64-bit integers here, so %d was a format/argument mismatch.
      // Widen explicitly and use %llu. The trailing newline keeps messages
      // from different threads on separate lines.
      printf("yxf::push miss key: %llu\n",
             static_cast<unsigned long long>(keys[i]));
    }
  }
}

T
Thunderbrook 已提交
129 130 131
// Creates the underlying TableContainer with `capacity`. It also allocates a
// device-side copy of the optimizer config and initializes it from
// host_optimizer_config_. The update() launchers pass that copy to the
// kernels. Finally, it creates the table's reader/writer lock.
template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::HashTable(size_t capacity) {
  container_ = new TableContainer<KeyType, ValType>(capacity);
  // Device copy of the optimizer config; kept in sync by set_*_sgd().
  cudaMalloc((void**)&device_optimizer_config_, sizeof(OptimizerConfig));
  cudaMemcpy((void*)device_optimizer_config_, &host_optimizer_config_,
             sizeof(OptimizerConfig), cudaMemcpyHostToDevice);
  rwlock_.reset(new phi::RWLock);
}

// Releases the key/value container and the device-side optimizer config that
// the constructor allocated with cudaMalloc.
template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::~HashTable() {
  delete container_;
  // Paired with the cudaMalloc in the constructor. The status is deliberately
  // ignored: at process teardown the CUDA runtime may already be unloaded,
  // and a destructor must not throw.
  cudaFree(device_optimizer_config_);
}

Z
zmxdream 已提交
143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158
// Forwards `optimizer_config` to host_optimizer_config_.set_sparse_sgd().
// It then re-uploads the entire host config into the device copy that the
// update kernels read.
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::set_sparse_sgd(
    const OptimizerConfig& optimizer_config) {
  host_optimizer_config_.set_sparse_sgd(optimizer_config);
  void* const device_dst = (void*)device_optimizer_config_;
  const void* const host_src = &host_optimizer_config_;
  cudaMemcpy(device_dst, host_src, sizeof(OptimizerConfig),
             cudaMemcpyHostToDevice);
}

// Forwards `optimizer_config` to host_optimizer_config_.set_embedx_sgd().
// It then re-uploads the entire host config into the device copy that the
// update kernels read.
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::set_embedx_sgd(
    const OptimizerConfig& optimizer_config) {
  host_optimizer_config_.set_embedx_sgd(optimizer_config);
  void* const device_dst = (void*)device_optimizer_config_;
  const void* const host_src = &host_optimizer_config_;
  cudaMemcpy(device_dst, host_src, sizeof(OptimizerConfig),
             cudaMemcpyHostToDevice);
}

T
Thunderbrook 已提交
159 160 161 162 163 164
// Debug helper: delegates to the container's print().
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::show() {
  this->container_->print();
}

// Batched lookup of `len` keys. `d_keys` and `d_vals` are dereferenced inside
// the kernel, so both must be device-accessible. For each key found, its value
// is written to the matching slot of d_vals; misses leave that slot unchanged.
// The kernel is only enqueued on `stream`; no synchronization happens here.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys, ValType* d_vals,
                                      size_t len, StreamType stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  search_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(container_, d_keys,
                                                       d_vals, len);
}

176
// Batched lookup into a variable-stride byte buffer. Each of the `len` output
// records in d_vals is pull_feature_value_size_ bytes long. On a hit, the
// stored FeatureValue is copied to the start of the record (see
// dy_mf_search_kernel). The kernel is only enqueued on `stream`.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys, char* d_vals,
                                      size_t len, StreamType stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  dy_mf_search_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_, d_keys, d_vals, len, pull_feature_value_size_);
}

T
Thunderbrook 已提交
188
// Batched insert of `len` (d_keys[i], d_vals[i]) pairs. Both arrays must be
// device-accessible. ReplaceOp is the merge op. The kernel is only enqueued
// on `stream`.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
                                         const ValType* d_vals, size_t len,
                                         StreamType stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(container_, d_keys,
                                                       d_vals, len);
}

201
// Batched insert of `len` keys whose values are addresses inside `pool`.
// The slots used are start_index .. start_index + len - 1 (see the pool
// overload of insert_kernel). It is a no-op when there are no keys or no
// pool. The kernel is only enqueued on `stream`.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys, size_t len,
                                         char* pool, size_t start_index,
                                         StreamType stream) {
  if (len == 0 || pool == NULL) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(container_, d_keys, len,
                                                       pool, start_index);
}

T
Thunderbrook 已提交
217
// Writes the GPU-side statistics of every occupied slot back into the CPU
// value object that the slot's `cpu_ptr` points to. That object is
// paddle::ps::DownpourFixedFeatureValue under PADDLE_WITH_PSLIB, or
// paddle::distributed::FixedFeatureValue under PADDLE_WITH_PSCORE.
//
// Steps:
//  1. Prefetch the container storage to the host (cudaCpuDeviceId) on
//     `stream`.
//  2. Synchronize `stream`. Both the prefetch and any kernels previously
//     enqueued on it (e.g. updates) must finish before host threads read the
//     table; otherwise the host reads race with device writes.
//  3. Split the first container_->size() slots of container_->data() evenly
//     across 8 std::threads. Slots whose key equals
//     std::numeric_limits<KeyType>::max() are skipped (presumably the
//     container's empty-slot sentinel — confirm).
// `devid` is currently unused: the prefetch back to the device is disabled.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::dump_to_cpu(int devid, StreamType stream) {
  container_->prefetch(cudaCpuDeviceId, stream);
  cudaStreamSynchronize(stream);

  size_t num = container_->size();
  KeyType unuse_key = std::numeric_limits<KeyType>::max();
  thrust::pair<KeyType, ValType>* kv = container_->data();

  // size_t throughout: slot counts above INT_MAX must not be truncated.
  const size_t thread_num = 8;
  const size_t len_per_thread = num / thread_num;
  const size_t remain = num % thread_num;

  auto dump_func = [unuse_key, kv](size_t left, size_t right) {
    for (size_t i = left; i < right; i++) {
      if (kv[i].first == unuse_key) {
        continue;
      }
      ValType& gpu_val = kv[i].second;
#ifdef PADDLE_WITH_PSLIB
      auto* downpour_value =
          (paddle::ps::DownpourFixedFeatureValue*)(gpu_val.cpu_ptr);
      int downpour_value_size = downpour_value->size();
      // Grow the CPU value to hold the mf vector when it only has the fixed
      // 7-float prefix. NOTE(review): layout inferred from the indices below.
      if (gpu_val.mf_size > 0 && downpour_value_size == 7) {
        downpour_value->resize(gpu_val.mf_size + downpour_value_size);
      }
      float* cpu_val = downpour_value->data();
      // cpu_val[0] = 0;
      cpu_val[1] = gpu_val.delta_score;
      cpu_val[2] = gpu_val.show;
      cpu_val[3] = gpu_val.clk;
      cpu_val[4] = gpu_val.lr;
      cpu_val[5] = gpu_val.lr_g2sum;
      cpu_val[6] = gpu_val.slot;
      if (gpu_val.mf_size > 0) {
        for (int x = 0; x < gpu_val.mf_size; x++) {
          cpu_val[x + 7] = gpu_val.mf[x];
        }
      }
#endif
#ifdef PADDLE_WITH_PSCORE
      auto* downpour_value =
          (paddle::distributed::FixedFeatureValue*)(gpu_val.cpu_ptr);
      int downpour_value_size = downpour_value->size();
      // Same growth rule as the PSLIB branch; note the PSCORE layout stores
      // slot at index 0 and the other fields at 2..6.
      if (gpu_val.mf_size > 0 && downpour_value_size == 7) {
        downpour_value->resize(gpu_val.mf_size + downpour_value_size);
      }
      float* cpu_val = downpour_value->data();
      // cpu_val[0] = 0;
      cpu_val[2] = gpu_val.delta_score;
      cpu_val[3] = gpu_val.show;
      cpu_val[4] = gpu_val.clk;
      cpu_val[5] = gpu_val.lr;
      cpu_val[6] = gpu_val.lr_g2sum;
      cpu_val[0] = gpu_val.slot;
      if (gpu_val.mf_size > 0) {
        for (int x = 0; x < gpu_val.mf_size; x++) {
          cpu_val[x + 7] = gpu_val.mf[x];
        }
      }
#endif
    }
  };

  std::vector<std::thread> threads;
  threads.reserve(thread_num);
  size_t begin = 0;
  for (size_t i = 0; i < thread_num; i++) {
    // The first `remain` threads take one extra slot each.
    const size_t chunk = len_per_thread + (i < remain ? 1 : 0);
    threads.push_back(std::thread(dump_func, begin, begin + chunk));
    begin += chunk;
  }
  for (std::thread& t : threads) {
    t.join();
  }

  // container_->prefetch(devid, stream);
}

T
Thunderbrook 已提交
294
// Batched optimizer step. For each of the `len` keys present in the table,
// `sgd` is applied with the matching d_grads[i] (see update_kernel). The
// kernel is only enqueued on `stream`.
//
// NOTE(review): `*device_optimizer_config_` binds the kernel's
// const OptimizerConfig& parameter to the cudaMalloc'd config. This relies on
// the reference being passed to the kernel as an address rather than the host
// reading device memory — confirm before changing.
template <typename KeyType, typename ValType>
template <typename GradType, typename Sgd, typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                         const GradType* d_grads, size_t len,
                                         Sgd sgd, StreamType stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  update_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_, *device_optimizer_config_, d_keys, d_grads, len, sgd);
}

307
// Batched optimizer step with variable-stride gradients. d_grads holds `len`
// records of push_grad_value_size_ bytes each (see dy_mf_update_kernel). The
// kernel is only enqueued on `stream`. The optimizer config is passed exactly
// as in the fixed-size update() overload (by reference to the
// device-allocated copy).
template <typename KeyType, typename ValType>
template <typename Sgd, typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                         const char* d_grads, size_t len,
                                         Sgd sgd, StreamType stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
  dy_mf_update_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
      container_, *device_optimizer_config_, d_keys, d_grads, len, sgd,
      push_grad_value_size_);
}

321
// Explicit instantiations of HashTable and of its member templates for the
// key/value types and the cudaStream_t stream type used by callers. The
// variable-stride (char*) overloads of get/insert/update are not instantiated;
// their instantiations are kept below as comments.
template class HashTable<unsigned long, paddle::framework::FeatureValue>;
template class HashTable<long, int>;
template class HashTable<unsigned long, int>;
template class HashTable<unsigned long, unsigned long>;
template class HashTable<long, long>;
template class HashTable<long, unsigned long>;
template class HashTable<long, unsigned int>;

template void HashTable<unsigned long, paddle::framework::FeatureValue>::get<
    cudaStream_t>(const unsigned long* d_keys,
                  paddle::framework::FeatureValue* d_vals, size_t len,
                  cudaStream_t stream);

template void HashTable<long, int>::get<cudaStream_t>(const long* d_keys,
                                                      int* d_vals, size_t len,
                                                      cudaStream_t stream);

template void HashTable<unsigned long, int>::get<cudaStream_t>(
    const unsigned long* d_keys, int* d_vals, size_t len, cudaStream_t stream);
template void HashTable<long, unsigned long>::get<cudaStream_t>(
    const long* d_keys, unsigned long* d_vals, size_t len, cudaStream_t stream);
template void HashTable<long, long>::get<cudaStream_t>(const long* d_keys,
                                                       long* d_vals, size_t len,
                                                       cudaStream_t stream);
template void HashTable<long, unsigned int>::get<cudaStream_t>(
    const long* d_keys, unsigned int* d_vals, size_t len, cudaStream_t stream);
// template void
// HashTable<unsigned long, paddle::framework::FeatureValue>::get<cudaStream_t>(
//    const unsigned long* d_keys, char* d_vals, size_t len, cudaStream_t
//    stream);

template void HashTable<unsigned long, paddle::framework::FeatureValue>::insert<
    cudaStream_t>(const unsigned long* d_keys,
                  const paddle::framework::FeatureValue* d_vals, size_t len,
                  cudaStream_t stream);

template void HashTable<long, int>::insert<cudaStream_t>(const long* d_keys,
                                                         const int* d_vals,
                                                         size_t len,
                                                         cudaStream_t stream);
template void HashTable<long, long>::insert<cudaStream_t>(const long* d_keys,
                                                          const long* d_vals,
                                                          size_t len,
                                                          cudaStream_t stream);

template void HashTable<unsigned long, int>::insert<cudaStream_t>(
    const unsigned long* d_keys, const int* d_vals, size_t len,
    cudaStream_t stream);
template void HashTable<long, unsigned long>::insert<cudaStream_t>(
    const long* d_keys, const unsigned long* d_vals, size_t len,
    cudaStream_t stream);

template void HashTable<long, unsigned int>::insert<cudaStream_t>(
    const long* d_keys, const unsigned int* d_vals, size_t len,
    cudaStream_t stream);

// template void HashTable<unsigned long,
// paddle::framework::FeatureValue>::insert<
//    cudaStream_t>(const unsigned long* d_keys, size_t len, char* pool,
//                  size_t start_index, cudaStream_t stream);

template void HashTable<unsigned long, paddle::framework::FeatureValue>::
    dump_to_cpu<cudaStream_t>(int devid, cudaStream_t stream);

template void HashTable<unsigned long, paddle::framework::FeatureValue>::update<
    paddle::framework::FeaturePushValue,
    Optimizer<paddle::framework::FeatureValue,
              paddle::framework::FeaturePushValue>,
    cudaStream_t>(const unsigned long* d_keys,
                  const paddle::framework::FeaturePushValue* d_grads,
                  size_t len, Optimizer<paddle::framework::FeatureValue,
                                        paddle::framework::FeaturePushValue>
                                  sgd,
                  cudaStream_t stream);

// template void HashTable<unsigned long,
// paddle::framework::FeatureValue>::update<
//    Optimizer<paddle::framework::FeatureValue,
//              paddle::framework::FeaturePushValue>,
//    cudaStream_t>(const unsigned long* d_keys, const char* d_grads, size_t
//    len,
//                  Optimizer<paddle::framework::FeatureValue,
//                            paddle::framework::FeaturePushValue>
//                      sgd,
//                  cudaStream_t stream);

#endif
T
Thunderbrook 已提交
408 409 410
}  // end namespace framework
}  // end namespace paddle
#endif