hashtable_kernel.kps 12.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/framework/fleet/heter_ps/hashtable.h"
17
#include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h"
18 19 20 21 22 23

namespace paddle {
namespace framework {

#if defined(PADDLE_WITH_XPU_KP)

// Adagrad-style update of a single scalar weight (the feature's "lr" slot).
//
//   optimizer_config : learning_rate, initial_g2sum, min_bound and max_bound
//                      are read from it.
//   w                : weight, updated in place and then clamped to
//                      [min_bound, max_bound] (low bound checked first).
//   g2sum            : running sum of squared scaled gradients; grows by
//                      (g / scale)^2.
//   g                : raw gradient.
//   scale            : divisor applied to g (update_value passes grad.show).
//
// NOTE(review): the step is *added* to w; presumably the pushed gradient is
// already sign-adjusted upstream — confirm.
// NOTE(review): scale == 0 makes g / scale inf/NaN.
__device__ void update_lr(OptimizerConfig& optimizer_config, float& w,
                          float& g2sum,
                          float g,  // NOLINT
                          float scale) {
  const float lr = optimizer_config.learning_rate;
  const float g2_init = optimizer_config.initial_g2sum;
  const float lower = optimizer_config.min_bound;
  const float upper = optimizer_config.max_bound;

  // The effective step shrinks as the accumulated squared gradient grows.
  const double step_ratio = lr * sqrt(g2_init / (g2_init + g2sum));
  // g / scale is evaluated in float and only then widened to double.
  const double grad = g / scale;

  w += grad * step_ratio;
  if (w < lower) {
    w = lower;
  }
  if (w > upper) {
    w = upper;
  }

  g2sum += grad * grad;
}

// Adagrad-style update of an n-wide embedding vector that shares a single
// squared-gradient accumulator.
//
//   optimizer_config : mf_learning_rate, mf_initial_g2sum, mf_min_bound and
//                      mf_max_bound are read from it.
//   n                : number of weights in w and gradients in g. Must be > 0;
//                      with n == 0 the final division adds NaN (0.0 / 0) to
//                      g2sum.
//   w                : weights, updated in place and clamped to
//                      [mf_min_bound, mf_max_bound] (low bound checked first).
//   g2sum            : shared accumulator; grows by the mean of the squared
//                      scaled gradients.
//   g                : raw gradients (n values).
//   scale            : divisor applied to every g[d].
//
// The step ratio is computed once from the incoming g2sum, so all n weights
// use the same learning-rate factor.
__device__ void update_mf(OptimizerConfig& optimizer_config, int n, float* w,
                          float& g2sum, const float* g, float scale) {
  const float lr = optimizer_config.mf_learning_rate;
  const float g2_init = optimizer_config.mf_initial_g2sum;
  const float lower = optimizer_config.mf_min_bound;
  const float upper = optimizer_config.mf_max_bound;

  const double step_ratio = lr * sqrt(g2_init / (g2_init + g2sum));

  double sq_grad_total = 0;
  for (int d = 0; d < n; ++d) {
    // Division happens in float, then the result is widened to double.
    const double grad = g[d] / scale;
    float& weight = w[d];
    weight += grad * step_ratio;
    if (weight < lower) {
      weight = lower;
    }
    if (weight > upper) {
      weight = upper;
    }
    sq_grad_total += grad * grad;
  }

  g2sum += sq_grad_total / n;
}

// Placeholder for a uniform random number source on XPU.
// NOTE(review): this stub always returns 0.1, so every embedding created in
// update_value gets identical components (0.1 * mf_initial_range). Replace
// with a real per-thread RNG before relying on embedding initialization.
__device__ float xpu_rand_uniform() {
  const float kPlaceholderSample = 0.1f;
  return kPlaceholderSample;
}

// Applies one pushed gradient record `grad` to the stored feature `val`:
//   * copies the slot id and accumulates the show / clk counters,
//   * accumulates delta_score with the non-click / click coefficients,
//   * updates the scalar lr weight via update_lr (scale = grad.show),
//   * if no embedding exists yet (mf_size == 0), creates it once the
//     accumulated score reaches mf_create_thresholds; otherwise runs
//     update_mf on the existing MF_DIM-wide embedding.
// Layout of val.mf as used here: mf[0] is the embedding's g2sum accumulator
// and mf[1..MF_DIM] are the weights.
template <typename ValType, typename GradType>
__device__ void update_value(OptimizerConfig& optimizer_config, ValType& val,
                             const GradType& grad) {  // NOLINT
  const float nonclk_coeff = optimizer_config.nonclk_coeff;
  const float clk_coeff = optimizer_config.clk_coeff;
  const float create_threshold = optimizer_config.mf_create_thresholds;
  const float init_range = optimizer_config.mf_initial_range;

  val.slot = grad.slot;
  val.show += grad.show;
  val.clk += grad.clk;
  val.delta_score +=
      nonclk_coeff * (grad.show - grad.clk) + clk_coeff * grad.clk;

  update_lr(optimizer_config, val.lr, val.lr_g2sum, grad.lr_g, grad.show);

  if (val.mf_size != 0) {
    // Embedding already exists: regular optimizer step.
    update_mf(optimizer_config, MF_DIM, &val.mf[1], val.mf[0], grad.mf_g,
              grad.show);
    return;
  }

  // No embedding yet: create it only once the (already updated) show/clk
  // score reaches the threshold.
  const bool reached_threshold =
      create_threshold <=
      nonclk_coeff * (val.show - val.clk) + clk_coeff * val.clk;
  if (!reached_threshold) {
    return;
  }

  val.mf_size = MF_DIM + 1;
  val.mf[0] = 0;
  for (int d = 1; d <= MF_DIM; ++d) {
    val.mf[d] = (xpu_rand_uniform()) * init_range;
  }
}

// Inserts `len` (key, value) pairs from global memory into `table`.
//
// Work split: every core gets a flat id
// thread_id = core_num() * cluster_id() + core_id(). Each core walks the input
// in chunks of at most `buf_size` items, strided by nthreads * len_per_loop.
// Each chunk is staged into __local__ buffers with GM2LM and then inserted
// pair by pair. A failed insert trips the device assert ("table is full").
//
// Each chunk is read from `keys + i` / `vals + i`. Reading from the array base
// would insert only the first chunk (repeatedly) and drop the rest.
template <typename KeyType, typename ValType, typename Table>
__global__ void insert_kernel(Table& table, const KeyType* const keys,
                              const ValType* const vals, long long len) {
  int cid = core_id();
  int ncores = core_num();
  if (cid >= ncores) {
    return;
  }
  int thread_id = ncores * cluster_id() + cid;
  int nthreads = ncores * cluster_num();

  const int buf_size = 150;
  __local__ KeyType local_keys[buf_size];
  __local__ ValType local_vals[buf_size];
  int len_per_loop = min(buf_size, roundup_div(len, nthreads));

  // 64-bit index so the loop stays well-defined when len exceeds INT_MAX.
  for (long long i = (long long)thread_id * len_per_loop; i < len;
       i += (long long)nthreads * len_per_loop) {
    int read_len = min(len_per_loop, len - i);
    // Stage the chunk that starts at element i, not at the array base.
    GM2LM(keys + i, local_keys, read_len * sizeof(KeyType));
    GM2LM(vals + i, local_vals, read_len * sizeof(ValType));
    for (int k = 0; k < read_len; k++) {
      auto status = table.insert(local_keys[k], local_vals[k]);
      assert(status != false && "error: insert fails: table is full");
    }
  }
}

// Looks up `len` keys in `table` and writes the found values to `vals`
// (vals[j] corresponds to keys[j]).
//
// Work split: every core gets a flat id
// thread_id = core_num() * cluster_id() + core_id() and processes chunks of
// at most `buf_size` keys, strided by nthreads * len_per_loop. Each chunk is
// staged with GM2LM, resolved with table.find(), and written back with LM2GM.
//
// Keys that are not in the table leave their vals[] entry unchanged. To
// guarantee this, the output chunk is preloaded before the lookups so the
// write-back never copies uninitialized or stale local_vals entries.
template <typename KeyType, typename ValType, typename Table>
__global__ void search_kernel(Table& table, const KeyType* const keys,
                              ValType* const vals, long long len) {
  int cid = core_id();
  int ncores = core_num();
  if (cid >= ncores) {
    return;
  }
  int thread_id = ncores * cluster_id() + cid;
  int nthreads = ncores * cluster_num();

  const int buf_size = 150;
  __local__ KeyType local_keys[buf_size];
  __local__ ValType local_vals[buf_size];

  int len_per_loop = min(buf_size, roundup_div(len, nthreads));
  // 64-bit index so the loop stays well-defined when len exceeds INT_MAX.
  for (long long i = (long long)thread_id * len_per_loop; i < len;
       i += (long long)nthreads * len_per_loop) {
    int read_len = min(len_per_loop, len - i);
    // Stage the key chunk that starts at element i, not at the array base.
    GM2LM(keys + i, local_keys, read_len * sizeof(KeyType));
    // Preload the current outputs so misses are written back unchanged.
    GM2LM(vals + i, local_vals, read_len * sizeof(ValType));
    for (int k = 0; k < read_len; k++) {
      ValType* val = table.find(local_keys[k]);
      if (val != NULL) {
        local_vals[k] = *val;
      }
    }
    LM2GM(local_vals, vals + i, read_len * sizeof(ValType));
  }
}

// Applies `len` pushed gradients to the values stored under `keys`
// (grads[j] belongs to keys[j]).
//
// Work split: every core gets a flat id
// thread_id = core_num() * cluster_id() + core_id() and processes chunks of
// at most `buf_size` items, strided by nthreads * len_per_loop. Keys and
// gradients are staged into __local__ buffers with GM2LM. Each key is looked
// up with table.find(); when found, the stored value is updated in place via
// update_value(). Keys that are not present are skipped.
//
// Each chunk is read from `keys + i` / `grads + i`. The gradient for local
// slot k is local_grads[k]; indexing with the global offset i would read past
// the local buffer.
//
// NOTE(review): if a key appears more than once in `keys`, different cores
// may update the same stored value concurrently without synchronization —
// confirm that callers de-duplicate keys beforehand.
template <typename KeyType, typename ValType, typename Table, typename GradType>
__global__ void update_kernel(Table& table, OptimizerConfig& optimizer_config,
                              const KeyType* const keys,
                              const GradType* const grads, long long len) {
  int cid = core_id();
  int ncores = core_num();
  if (cid >= ncores) {
    return;
  }
  int thread_id = ncores * cluster_id() + cid;
  int nthreads = ncores * cluster_num();

  const int buf_size = 250;
  __local__ KeyType local_keys[buf_size];
  __local__ GradType local_grads[buf_size];

  int len_per_loop = min(buf_size, roundup_div(len, nthreads));
  // 64-bit index so the loop stays well-defined when len exceeds INT_MAX.
  for (long long i = (long long)thread_id * len_per_loop; i < len;
       i += (long long)nthreads * len_per_loop) {
    int read_len = min(len_per_loop, len - i);

    // Stage the chunk that starts at element i, not at the array base.
    GM2LM(keys + i, local_keys, read_len * sizeof(KeyType));
    GM2LM(grads + i, local_grads, read_len * sizeof(GradType));

    for (int k = 0; k < read_len; k++) {
      ValType* val = table.find(local_keys[k]);
      if (val != NULL) {
        update_value(optimizer_config, *val, local_grads[k]);
      }
    }
  }
}

// Builds the table.
// 1. A host-side XPUCacheArray with `capacity` slots is constructed, and its
//    bytes (sizeof(XPUCacheArray)) are copied into a freshly xpu_malloc'ed
//    buffer that container_ then points to.
// 2. host_optimizer_config_ is mirrored into a newly allocated
//    device_optimizer_config_ the same way.
// 3. A reader/writer lock is created.
// NOTE(review): return codes of xpu_malloc / xpu_memcpy are not checked.
// NOTE(review): host_container is destroyed when the constructor returns;
// presumably XPUCacheArray's destructor does not free storage the device copy
// still refers to — confirm against the XPUCacheArray definition.
template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::HashTable(size_t capacity) {
  using Container = XPUCacheArray<KeyType, ValType>;
  const size_t container_bytes = sizeof(Container);
  const size_t config_bytes = sizeof(OptimizerConfig);

  Container host_container = Container(capacity);
  xpu_malloc(reinterpret_cast<void**>(&container_), container_bytes);
  xpu_memcpy(static_cast<void*>(container_), &host_container,
             container_bytes, XPU_HOST_TO_DEVICE);

  xpu_malloc(reinterpret_cast<void**>(&device_optimizer_config_),
             config_bytes);
  xpu_memcpy(static_cast<void*>(device_optimizer_config_),
             &host_optimizer_config_, config_bytes, XPU_HOST_TO_DEVICE);

  rwlock_.reset(new phi::RWLock);
}

// Releases the two buffers the constructor obtained with xpu_malloc: the
// container copy and the optimizer-config copy.
// NOTE(review): if XPUCacheArray owns further device storage internally, it
// is not released here — confirm whether that is intentional.
template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::~HashTable() {
  void* const container_mem = static_cast<void*>(container_);
  void* const config_mem = static_cast<void*>(device_optimizer_config_);
  xpu_free(container_mem);
  xpu_free(config_mem);
}

// Prints the table contents by delegating to the container's print().
// NOTE(review): container_ is allocated with xpu_malloc in the constructor;
// calling a member through it from host code looks unsafe unless that memory
// is host-accessible — confirm.
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::show() {
  auto* const table = container_;
  table->print();
}

// Applies OptimizerConfig::set_sparse_sgd(optimizer_config) to the host copy
// of the config, then re-uploads the whole host copy to
// device_optimizer_config_.
// NOTE(review): presumably only the sparse/lr fields change; see
// OptimizerConfig::set_sparse_sgd to confirm.
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::set_sparse_sgd(
    const OptimizerConfig& optimizer_config) {
  OptimizerConfig& host_cfg = host_optimizer_config_;
  host_cfg.set_sparse_sgd(optimizer_config);

  void* const device_cfg = static_cast<void*>(device_optimizer_config_);
  xpu_memcpy(device_cfg, &host_cfg, sizeof(OptimizerConfig),
             XPU_HOST_TO_DEVICE);
}

// Applies OptimizerConfig::set_embedx_sgd(optimizer_config) to the host copy
// of the config, then re-uploads the whole host copy to
// device_optimizer_config_.
// NOTE(review): presumably only the embedding (mf_*) fields change; see
// OptimizerConfig::set_embedx_sgd to confirm.
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::set_embedx_sgd(
    const OptimizerConfig& optimizer_config) {
  OptimizerConfig& host_cfg = host_optimizer_config_;
  host_cfg.set_embedx_sgd(optimizer_config);

  void* const device_cfg = static_cast<void*>(device_optimizer_config_);
  xpu_memcpy(device_cfg, &host_cfg, sizeof(OptimizerConfig),
             XPU_HOST_TO_DEVICE);
}

// Looks up `len` keys (d_keys) and writes the found values into d_vals by
// launching search_kernel on `stream`. Does nothing when len == 0.
// The launch shape <<<4, 64, stream>>> is hard-coded.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys, ValType* d_vals,
                                      size_t len, StreamType stream) {
  if (len > 0) {
    const long long n = static_cast<long long>(len);
    search_kernel<KeyType, ValType,
                  XPUCacheArray<KeyType, ValType>><<<4, 64, stream>>>(
        *container_, d_keys, d_vals, n);
  }
}

// Byte-buffer variant of get(). Not implemented for the XPU backend yet: the
// call is accepted and does nothing (for any len).
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys, char* d_vals,
                                      size_t len, StreamType stream) {
  // TODO(zhangminxu): to be implemented
  (void)d_keys;
  (void)d_vals;
  (void)len;
  (void)stream;
}

// Inserts `len` (key, value) pairs from d_keys / d_vals by launching
// insert_kernel on `stream`. Does nothing when len == 0.
// The launch shape <<<4, 64, stream>>> is hard-coded.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
                                         const ValType* d_vals, size_t len,
                                         StreamType stream) {
  if (len > 0) {
    const long long n = static_cast<long long>(len);
    insert_kernel<KeyType, ValType,
                  XPUCacheArray<KeyType, ValType>><<<4, 64, stream>>>(
        *container_, d_keys, d_vals, n);
  }
}

// Not implemented for the XPU backend yet; currently a no-op.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::dump_to_cpu(int devid, StreamType stream) {
  // TODO(zhangminxu): to be implemented
  (void)devid;
  (void)stream;
}

// Applies `len` gradients (d_grads) to the values stored under d_keys by
// launching update_kernel on `stream` with the device-side optimizer config.
// Does nothing when len == 0. The launch shape <<<4, 64, stream>>> is
// hard-coded.
template <typename KeyType, typename ValType>
template <typename GradType, typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                         const GradType* d_grads, size_t len,
                                         StreamType stream) {
  if (len > 0) {
    const long long n = static_cast<long long>(len);
    update_kernel<KeyType, ValType, XPUCacheArray<KeyType, ValType>,
                  GradType><<<4, 64, stream>>>(
        *container_, *device_optimizer_config_, d_keys, d_grads, n);
  }
}

// Byte-buffer variant of update(). Not implemented for the XPU backend yet:
// the call is accepted and does nothing (for any len).
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                         const char* d_grads, size_t len,
                                         StreamType stream) {
  // TODO(zhangminxu): to be implemented
  (void)d_keys;
  (void)d_grads;
  (void)len;
  (void)stream;
}

// Explicit instantiations for the single (key, value) combination this XPU
// translation unit provides: unsigned long keys with FeatureValue values.
// The member templates are instantiated individually for XPUStream.
template class HashTable<unsigned long, paddle::framework::FeatureValue>;

// Lookup into a typed FeatureValue output buffer.
template void HashTable<unsigned long, paddle::framework::FeatureValue>::get<
    XPUStream>(const unsigned long* d_keys,
               paddle::framework::FeatureValue* d_vals, size_t len,
               XPUStream stream);

// Disabled: the char* overload of get() in this file is still an
// unimplemented stub.
// template void
// HashTable<unsigned long, paddle::framework::FeatureValue>::get<XPUStream>(
//    const unsigned long* d_keys, char* d_vals, size_t len, XPUStream stream);

// Insert from typed FeatureValue input.
template void HashTable<unsigned long, paddle::framework::FeatureValue>::insert<
    XPUStream>(const unsigned long* d_keys,
               const paddle::framework::FeatureValue* d_vals, size_t len,
               XPUStream stream);

// Disabled: no insert(keys, len, pool, start_index, stream) overload is
// defined in this file.
// template void HashTable<unsigned long,
// paddle::framework::FeatureValue>::insert<
//    XPUStream>(const unsigned long* d_keys, size_t len, char* pool,
//               size_t start_index, XPUStream stream);

// dump_to_cpu is currently a no-op stub in this file.
template void HashTable<unsigned long, paddle::framework::FeatureValue>::
    dump_to_cpu<XPUStream>(int devid, XPUStream stream);

// Gradient update with FeaturePushValue gradients.
template void HashTable<unsigned long, paddle::framework::FeatureValue>::update<
    paddle::framework::FeaturePushValue, XPUStream>(
    const unsigned long* d_keys,
    const paddle::framework::FeaturePushValue* d_grads, size_t len,
    XPUStream stream);

// Disabled: the char* overload of update() in this file is still an
// unimplemented stub.
// template void HashTable<unsigned long,
// paddle::framework::FeatureValue>::update<
//    XPUStream>(const unsigned long* d_keys, const char* d_grads,
//                          size_t len, XPUStream stream);

#endif
}  // end namespace framework
}  // end namespace paddle
#endif