hashtable_kernel.cu 12.4 KB
Newer Older
1
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
T
Thunderbrook 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

T
Thunderbrook 已提交
15
#ifdef PADDLE_WITH_HETERPS
16 17 18
#include <thread>
#include "paddle/fluid/framework/fleet/heter_ps/hashtable.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
T
Thunderbrook 已提交
19 20 21 22

namespace paddle {
namespace framework {

23 24
#if defined(PADDLE_WITH_CUDA)

T
Thunderbrook 已提交
25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49
// Binary functor passed as the `op` argument of Table::insert by the insert
// kernels in this file. It ignores its second argument and returns the first
// one unchanged.
template <typename value_type>
struct ReplaceOp {
  // Declared const: the functor holds no state, so it should also be callable
  // through a const object or a const reference.
  __host__ __device__ value_type operator()(value_type new_value,
                                            value_type old_value) const {
    return new_value;
  }
};

// Inserts the pair (keys[i], vals[i]) into `table` for each i in [0, len).
// Each thread handles at most one element; the element index is derived from
// the x components of blockIdx / blockDim / threadIdx only.
//
//   table : table object whose insert()/end() are called from device code.
//   keys  : `len` keys to insert.
//   vals  : `len` values, vals[i] paired with keys[i].
//   len   : number of elements; threads with index >= len do nothing.
//
// A ReplaceOp instance is handed to table->insert() as its second argument.
// The device-side assert fires when insert() returns table->end().
template <typename Table>
__global__ void insert_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              const typename Table::mapped_type* const vals,
                              size_t len) {
  ReplaceOp<typename Table::mapped_type> op;
  thrust::pair<typename Table::key_type, typename Table::mapped_type> kv;

  // Widen before multiplying: blockIdx.x * blockDim.x would otherwise be
  // computed in 32-bit unsigned arithmetic and wrap for very large launches.
  const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < len) {
    kv.first = keys[i];
    kv.second = vals[i];
    auto it = table->insert(kv, op);
    assert(it != table->end() && "error: insert fails: table is full");
  }
}

50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
// Inserts keys[i] into `table` for each i in [0, len), mapping it to an
// address inside `pool`: pool + (start_index + i) * 80 bytes, cast to
// Table::mapped_type. Each thread handles at most one element; the element
// index is derived from the x components of blockIdx / blockDim / threadIdx.
//
//   table       : table object whose insert()/end() are called on device.
//   keys        : `len` keys to insert.
//   len         : number of elements; threads with index >= len do nothing.
//   pool        : base address the per-key value addresses are derived from.
//   start_index : slot offset added to the element index before scaling.
//                 Taken as size_t so the size_t passed by the host wrapper
//                 HashTable::insert is not narrowed.
//
// The device-side assert fires when insert() returns table->end().
template <typename Table>
__global__ void insert_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              size_t len, char* pool, size_t start_index) {
  // Byte stride between consecutive slots in `pool`.
  // NOTE(review): 80 is hard-coded; presumably the byte size of one pooled
  // value -- confirm against the pool's producer.
  constexpr size_t kPoolStrideBytes = 80;

  ReplaceOp<typename Table::mapped_type> op;
  thrust::pair<typename Table::key_type, typename Table::mapped_type> kv;

  // Widen before multiplying: blockIdx.x * blockDim.x would otherwise be
  // computed in 32-bit unsigned arithmetic and wrap for very large launches.
  const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  if (i < len) {
    kv.first = keys[i];
    // `typename` is required: Table::mapped_type is a dependent type name.
    kv.second = (typename Table::mapped_type)(pool + (start_index + i) *
                                                         kPoolStrideBytes);
    auto it = table->insert(kv, op);
    assert(it != table->end() && "error: insert fails: table is full");
  }
}

T
Thunderbrook 已提交
67 68 69 70 71 72 73 74 75 76 77 78 79 80
// Looks up keys[i] in `table` for each i in [0, len) and, when found, copies
// the mapped value into vals[i]. vals[i] is left untouched when the key is
// not found. Each thread handles at most one element; the element index is
// derived from the x components of blockIdx / blockDim / threadIdx.
//
//   table : table object whose find()/end() are called on device.
//   keys  : `len` keys to look up.
//   vals  : output array with room for `len` values.
//   len   : number of elements; threads with index >= len do nothing.
template <typename Table>
__global__ void search_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              typename Table::mapped_type* const vals,
                              size_t len) {
  // Widen before multiplying: blockIdx.x * blockDim.x would otherwise be
  // computed in 32-bit unsigned arithmetic and wrap for very large launches.
  const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      vals[i] = it->second;
    }
  }
}

81 82 83 84 85 86 87 88 89 90 91 92 93 94
// Looks up keys[i] in `table` for each i in [0, len). When the key is found,
// the FeatureValue that the mapped value points to is copied (by FeatureValue
// assignment) to byte offset i * pull_feature_value_size inside `vals`.
// Nothing is written for keys that are not found. Each thread handles at most
// one element; the element index is derived from the x components of
// blockIdx / blockDim / threadIdx.
//
//   table                   : table object whose find()/end() are called.
//   keys                    : `len` keys to look up.
//   vals                    : raw output buffer, addressed in bytes.
//   len                     : number of elements.
//   pull_feature_value_size : byte stride between consecutive output records.
template <typename Table>
__global__ void dy_mf_search_kernel(Table* table,
                                    const typename Table::key_type* const keys,
                                    char* const vals, size_t len,
                                    size_t pull_feature_value_size) {
  // Widen before multiplying: blockIdx.x * blockDim.x would otherwise be
  // computed in 32-bit unsigned arithmetic and wrap for very large launches.
  const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);

    if (it != table->end()) {
      *(FeatureValue*)(vals + i * pull_feature_value_size) = *(it->second);
    }
  }
}
95

T
Thunderbrook 已提交
96 97 98 99 100 101 102 103 104 105 106 107 108 109
// For each i in [0, len): looks up keys[i] in `table` and, when found, calls
// sgd.update_value() with the stored value (reached through it.getter()) and
// grads[i]. Keys that are not found are skipped silently. Each thread handles
// at most one element; the element index is derived from the x components of
// blockIdx / blockDim / threadIdx.
//
//   table : table object whose find()/end() are called on device.
//   keys  : `len` keys to update.
//   grads : `len` gradients, grads[i] paired with keys[i].
//   len   : number of elements; threads with index >= len do nothing.
//   sgd   : optimizer object, passed to the kernel by value.
template <typename Table, typename GradType, typename Sgd>
__global__ void update_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              const GradType* const grads, size_t len,
                              Sgd sgd) {
  // Widen before multiplying: blockIdx.x * blockDim.x would otherwise be
  // computed in 32-bit unsigned arithmetic and wrap for very large launches.
  const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      sgd.update_value((it.getter())->second, grads[i]);
    }
  }
}

110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
// For each i in [0, len): looks up keys[i] in `table`. When found, the bytes
// at grads + i * grad_value_size are reinterpreted as a FeaturePushValue and
// passed, together with the stored value (reached through it.getter()), to
// sgd.dy_mf_update_value(). When the key is missing a message is printed from
// the device. Each thread handles at most one element; the element index is
// derived from the x components of blockIdx / blockDim / threadIdx.
//
//   table           : table object whose find()/end() are called on device.
//   keys            : `len` keys to update.
//   grads           : raw gradient buffer, addressed in bytes.
//   len             : number of elements.
//   sgd             : optimizer object, passed to the kernel by value.
//   grad_value_size : byte stride between consecutive gradient records.
template <typename Table, typename Sgd>
__global__ void dy_mf_update_kernel(Table* table,
                                    const typename Table::key_type* const keys,
                                    const char* const grads, size_t len,
                                    Sgd sgd, size_t grad_value_size) {
  // Widen before multiplying: blockIdx.x * blockDim.x would otherwise be
  // computed in 32-bit unsigned arithmetic and wrap for very large launches.
  const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      FeaturePushValue* cur = (FeaturePushValue*)(grads + i * grad_value_size);
      sgd.dy_mf_update_value((it.getter())->second, *cur);
    } else {
      // Keys are 64-bit (this file instantiates the table with unsigned
      // long); "%d" mismatched the argument width, so print as %llu.
      printf("yxf::push miss key: %llu",
             static_cast<unsigned long long>(keys[i]));
    }
  }
}

T
Thunderbrook 已提交
127 128 129
// Constructs the table: allocates the backing TableContainer with the
// requested capacity and creates the phi::RWLock held by rwlock_.
template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::HashTable(size_t capacity) {
  using Container = TableContainer<KeyType, ValType>;
  this->container_ = new Container(capacity);
  this->rwlock_.reset(new phi::RWLock);
}

// Releases the container allocated with `new` in the constructor.
template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::~HashTable() {
  delete this->container_;
}

// Forwards to the container's print() method.
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::show() {
  this->container_->print();
}

// Launches search_kernel on `stream` to look up `len` keys from d_keys and
// write the values that are found into d_vals. Returns immediately without
// launching anything when len is 0. The launch is not followed by any
// synchronization in this function.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys, ValType* d_vals,
                                      size_t len, StreamType stream) {
  if (len != 0) {
    // One thread per key, rounded up to whole blocks of BLOCK_SIZE_ threads.
    const int num_blocks = (len - 1) / BLOCK_SIZE_ + 1;
    search_kernel<<<num_blocks, BLOCK_SIZE_, 0, stream>>>(container_, d_keys,
                                                          d_vals, len);
  }
}

155
// Raw-buffer variant of get(): launches dy_mf_search_kernel on `stream` for
// `len` keys, passing pull_feature_value_size_ as the byte stride of the
// output records in d_vals. Returns immediately when len is 0. The launch is
// not followed by any synchronization in this function.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys, char* d_vals,
                                      size_t len, StreamType stream) {
  if (len != 0) {
    // One thread per key, rounded up to whole blocks of BLOCK_SIZE_ threads.
    const int num_blocks = (len - 1) / BLOCK_SIZE_ + 1;
    dy_mf_search_kernel<<<num_blocks, BLOCK_SIZE_, 0, stream>>>(
        container_, d_keys, d_vals, len, pull_feature_value_size_);
  }
}

T
Thunderbrook 已提交
167
// Launches insert_kernel on `stream` to insert `len` (key, value) pairs taken
// from d_keys / d_vals. Returns immediately when len is 0. The launch is not
// followed by any synchronization in this function.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
                                         const ValType* d_vals, size_t len,
                                         StreamType stream) {
  if (len != 0) {
    // One thread per pair, rounded up to whole blocks of BLOCK_SIZE_ threads.
    const int num_blocks = (len - 1) / BLOCK_SIZE_ + 1;
    insert_kernel<<<num_blocks, BLOCK_SIZE_, 0, stream>>>(container_, d_keys,
                                                          d_vals, len);
  }
}

180
// Pool variant of insert(): launches the insert_kernel overload that maps
// each key to an address derived from `pool` and `start_index`. Does nothing
// when len is 0 or when pool is NULL. The launch is not followed by any
// synchronization in this function.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys, size_t len,
                                         char* pool, size_t start_index,
                                         StreamType stream) {
  // Nothing to insert, or no pool to point the values into.
  if (len == 0 || pool == NULL) {
    return;
  }
  // One thread per key, rounded up to whole blocks of BLOCK_SIZE_ threads.
  const int num_blocks = (len - 1) / BLOCK_SIZE_ + 1;
  insert_kernel<<<num_blocks, BLOCK_SIZE_, 0, stream>>>(container_, d_keys,
                                                        len, pool, start_index);
}

T
Thunderbrook 已提交
196
// Writes the contents of every occupied table slot back into the CPU-side
// value object referenced by that slot's `cpu_ptr`.
//
// Steps:
//   1. Ask the container to prefetch itself to the CPU on `stream`.
//   2. Split the slot range [0, container_->size()) into 8 contiguous chunks
//      and process each chunk on its own std::thread.
//   3. Join all threads before returning.
//
// Slots whose key equals std::numeric_limits<KeyType>::max() are skipped.
//
//   devid  : not used by the active code (only by the commented-out prefetch
//            at the end of the function).
//   stream : stream handed to container_->prefetch().
//
// NOTE(review): this function does not itself synchronize `stream` between
// the prefetch call and the host threads reading the slots -- confirm that
// prefetch() / the container's memory kind makes that safe.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::dump_to_cpu(int devid, StreamType stream) {
  container_->prefetch(cudaCpuDeviceId, stream);

  const size_t num = container_->size();
  const KeyType unuse_key = std::numeric_limits<KeyType>::max();
  thrust::pair<KeyType, ValType>* kv = container_->data();

  // Slot indices are kept in size_t end to end so the count returned by
  // size() is never narrowed to int while partitioning or iterating.
  const size_t thread_num = 8;
  const size_t len_per_thread = num / thread_num;
  const size_t remain = num % thread_num;

  // Copies the slots in [left, right) into their CPU-side values.
  auto dump_func = [unuse_key, kv](size_t left, size_t right) {
    for (size_t i = left; i < right; i++) {
      if (kv[i].first == unuse_key) {
        continue;
      }
      ValType& gpu_val = kv[i].second;
#ifdef PADDLE_WITH_PSLIB
      auto* downpour_value =
          (paddle::ps::DownpourFixedFeatureValue*)(gpu_val.cpu_ptr);
      int downpour_value_size = downpour_value->size();
      // Grow the CPU value to make room for the mf part when it currently
      // holds exactly the 7 leading floats.
      if (gpu_val.mf_size > 0 && downpour_value_size == 7) {
        downpour_value->resize(gpu_val.mf_size + downpour_value_size);
      }
      float* cpu_val = downpour_value->data();
      // cpu_val[0] = 0;
      cpu_val[1] = gpu_val.delta_score;
      cpu_val[2] = gpu_val.show;
      cpu_val[3] = gpu_val.clk;
      cpu_val[4] = gpu_val.lr;
      cpu_val[5] = gpu_val.lr_g2sum;
      cpu_val[6] = gpu_val.slot;
      // The loop body never runs when mf_size is 0, so no extra guard needed.
      for (int x = 0; x < gpu_val.mf_size; x++) {
        cpu_val[x + 7] = gpu_val.mf[x];
      }
#endif
#ifdef PADDLE_WITH_PSCORE
      auto* downpour_value =
          (paddle::distributed::FixedFeatureValue*)(gpu_val.cpu_ptr);
      int downpour_value_size = downpour_value->size();
      // Grow the CPU value to make room for the mf part when it currently
      // holds exactly the 7 leading floats.
      if (gpu_val.mf_size > 0 && downpour_value_size == 7) {
        downpour_value->resize(gpu_val.mf_size + downpour_value_size);
      }
      float* cpu_val = downpour_value->data();
      // cpu_val[0] = 0;
      // Unlike the PSLIB branch, slot goes to index 0 and index 1 is not
      // written here.
      cpu_val[2] = gpu_val.delta_score;
      cpu_val[3] = gpu_val.show;
      cpu_val[4] = gpu_val.clk;
      cpu_val[5] = gpu_val.lr;
      cpu_val[6] = gpu_val.lr_g2sum;
      cpu_val[0] = gpu_val.slot;
      // The loop body never runs when mf_size is 0, so no extra guard needed.
      for (int x = 0; x < gpu_val.mf_size; x++) {
        cpu_val[x + 7] = gpu_val.mf[x];
      }
#endif
    }
  };

  std::vector<std::thread> threads;
  threads.reserve(thread_num);
  size_t begin = 0;
  for (size_t i = 0; i < thread_num; i++) {
    // The first `remain` chunks take one extra slot so the whole range is
    // covered when num is not a multiple of thread_num.
    const size_t end = begin + len_per_thread + (i < remain ? 1 : 0);
    threads.emplace_back(dump_func, begin, end);
    begin = end;
  }
  for (std::thread& t : threads) {
    t.join();
  }

  // container_->prefetch(devid, stream);
}

T
Thunderbrook 已提交
273
// Launches update_kernel on `stream` to apply `len` gradients from d_grads to
// the values stored under the matching keys in d_keys, using `sgd` as the
// optimizer. Returns immediately when len is 0. The launch is not followed by
// any synchronization in this function.
template <typename KeyType, typename ValType>
template <typename GradType, typename Sgd, typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                         const GradType* d_grads, size_t len,
                                         Sgd sgd, StreamType stream) {
  if (len != 0) {
    // One thread per key, rounded up to whole blocks of BLOCK_SIZE_ threads.
    const int num_blocks = (len - 1) / BLOCK_SIZE_ + 1;
    update_kernel<<<num_blocks, BLOCK_SIZE_, 0, stream>>>(container_, d_keys,
                                                          d_grads, len, sgd);
  }
}

286
// Raw-buffer variant of update(): launches dy_mf_update_kernel on `stream`
// for `len` keys, passing push_grad_value_size_ as the byte stride of the
// gradient records in d_grads. Returns immediately when len is 0. The launch
// is not followed by any synchronization in this function.
template <typename KeyType, typename ValType>
template <typename Sgd, typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                         const char* d_grads, size_t len,
                                         Sgd sgd, StreamType stream) {
  if (len != 0) {
    // One thread per key, rounded up to whole blocks of BLOCK_SIZE_ threads.
    const int num_blocks = (len - 1) / BLOCK_SIZE_ + 1;
    dy_mf_update_kernel<<<num_blocks, BLOCK_SIZE_, 0, stream>>>(
        container_, d_keys, d_grads, len, sgd, push_grad_value_size_);
  }
}

299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346
// Explicit instantiations for the one key/value combination built in this
// translation unit: HashTable<unsigned long, FeatureValue>, with cudaStream_t
// as the stream type of every member template instantiated below.
template class HashTable<unsigned long, paddle::framework::FeatureValue>;

// get(): typed-value overload.
template void HashTable<unsigned long, paddle::framework::FeatureValue>::get<
    cudaStream_t>(const unsigned long* d_keys,
                  paddle::framework::FeatureValue* d_vals, size_t len,
                  cudaStream_t stream);

// The raw char* overload of get() is not instantiated (left commented out).
// template void
// HashTable<unsigned long, paddle::framework::FeatureValue>::get<cudaStream_t>(
//    const unsigned long* d_keys, char* d_vals, size_t len, cudaStream_t
//    stream);

// insert(): (keys, values) overload.
template void HashTable<unsigned long, paddle::framework::FeatureValue>::insert<
    cudaStream_t>(const unsigned long* d_keys,
                  const paddle::framework::FeatureValue* d_vals, size_t len,
                  cudaStream_t stream);

// The pool-based overload of insert() is not instantiated (left commented
// out).
// template void HashTable<unsigned long,
// paddle::framework::FeatureValue>::insert<
//    cudaStream_t>(const unsigned long* d_keys, size_t len, char* pool,
//                  size_t start_index, cudaStream_t stream);

// dump_to_cpu().
template void HashTable<unsigned long, paddle::framework::FeatureValue>::
    dump_to_cpu<cudaStream_t>(int devid, cudaStream_t stream);

// update(): typed-gradient overload, with FeaturePushValue gradients and the
// Optimizer<FeatureValue, FeaturePushValue> optimizer.
template void HashTable<unsigned long, paddle::framework::FeatureValue>::update<
    paddle::framework::FeaturePushValue,
    Optimizer<paddle::framework::FeatureValue,
              paddle::framework::FeaturePushValue>,
    cudaStream_t>(const unsigned long* d_keys,
                  const paddle::framework::FeaturePushValue* d_grads,
                  size_t len, Optimizer<paddle::framework::FeatureValue,
                                        paddle::framework::FeaturePushValue>
                                  sgd,
                  cudaStream_t stream);

// The raw char* overload of update() is not instantiated (left commented
// out).
// template void HashTable<unsigned long,
// paddle::framework::FeatureValue>::update<
//    Optimizer<paddle::framework::FeatureValue,
//              paddle::framework::FeaturePushValue>,
//    cudaStream_t>(const unsigned long* d_keys, const char* d_grads, size_t
//    len,
//                  Optimizer<paddle::framework::FeatureValue,
//                            paddle::framework::FeaturePushValue>
//                      sgd,
//                  cudaStream_t stream);

#endif
T
Thunderbrook 已提交
347 348 349
}  // end namespace framework
}  // end namespace paddle
#endif