// hashtable_kernel.kps
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/framework/fleet/heter_ps/hashtable.h"

// Optimizer hyper-parameters shared with the device code in this file.
// Every entry is only declared here (extern); the definitions live in another
// translation unit. The device helpers below copy a single float out of most
// of these pointers into a __local__ variable with GM2LM before using it.
namespace optimizer_config {
// Coefficients combined with show/clk counts when scoring a feature.
extern _global_ptr_ float* nonclk_coeff;
extern _global_ptr_ float* clk_coeff;

// Parameters read by update_lr for the scalar weight.
extern _global_ptr_ float* min_bound;
extern _global_ptr_ float* max_bound;
extern _global_ptr_ float* learning_rate;
extern _global_ptr_ float* initial_g2sum;
// NOTE(review): initial_range is not referenced by the code visible in this
// file.
extern _global_ptr_ float* initial_range;

// Parameters for the "mf" vector (see update_mf / update_value).
extern _global_ptr_ float* mf_create_thresholds;
extern _global_ptr_ float* mf_learning_rate;
extern _global_ptr_ float* mf_initial_g2sum;
extern _global_ptr_ float* mf_initial_range;
extern _global_ptr_ float* mf_min_bound;
extern _global_ptr_ float* mf_max_bound;
}  // namespace optimizer_config

namespace paddle {
namespace framework {

#if defined(PADDLE_WITH_XPU_KP)

41
__device__ void update_lr(float& w, float& g2sum, float g,  // NOLINT
42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
                          float scale) {
  __local__ float local_learning_rate;
  __local__ float local_initial_g2sum;
  __local__ float local_min_bound;
  __local__ float local_max_bound;

  GM2LM(optimizer_config::learning_rate, &local_learning_rate, sizeof(float));
  GM2LM(optimizer_config::initial_g2sum, &local_initial_g2sum, sizeof(float));
  GM2LM(optimizer_config::min_bound, &local_min_bound, sizeof(float));
  GM2LM(optimizr_config::max_bound, &local_max_bound, sizeof(float));

  double add_g2sum = 0;
  double ratio = local_learning_rate *
                 sqrt(local_initial_g2sum / (local_initial_g2sum + g2sum));
  double scaled_grad = g / scale;

58
  w += scaled_grad * ratio;
59 60 61 62 63 64

  if (w < local_min_bound) w = local_min_bound;
  if (w > local_max_bound) w = local_max_bound;

  add_g2sum += scaled_grad * scaled_grad;

65
  g2sum += add_g2sum;
66 67
}

68
__device__ void update_mf(int n, float* w, float& g2sum, const float* g,
69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94
                          float scale) {
  __local__ float local_mf_learning_rate;
  __local__ float local_mf_initial_g2sum;
  __local__ float local_mf_min_bound;
  __local__ float local_mf_max_bound;

  GM2LM(optimizer_config::mf_learning_rate, &local_mf_learning_rate,
        sizeof(float));
  GM2LM(optimizer_config::mf_initial_g2sum, &local_mf_initial_g2sum,
        sizeof(float));
  GM2LM(optimizer_config::mf_min_bound, &local_mf_min_bound, sizeof(float));
  GM2LM(optimizer_config::mf_max_bound, &local_mf_max_bound, sizeof(float));

  double add_g2sum = 0;
  double ratio =
      local_mf_learning_rate *
      sqrt(local_mf_initial_g2sum / (local_mf_initial_g2sum + g2sum));
  for (int i = 0; i < n; ++i) {
    double scaled_grad = g[i] / scale;
    w[i] += scaled_grad * ratio;

    if (w[i] < local_mf_min_bound) w[i] = local_mf_min_bound;
    if (w[i] > local_mf_max_bound) w[i] = local_mf_max_bound;
    add_g2sum += scaled_grad * scaled_grad;
  }

95
  g2sum += add_g2sum / n;
96 97 98 99 100
}

// Stand-in for a uniform random source: always yields the same constant.
__device__ float xpu_rand_uniform() {
  const float fixed_sample = 0.1f;
  return fixed_sample;
}

template <typename ValType, typename GradType>
101 102 103 104
__device__ void update_value(ValType& val, const GradType& grad) {  // NOLINT
  val.slot = grad.slot;
  val.show += grad.show;
  val.clk += grad.clk;
105 106 107 108 109 110 111 112 113 114 115 116

  __local__ float local_nonclk_coeff;
  __local__ float local_clk_coeff;

  __local__ float local_mf_create_thresholds;
  __local__ float local_mf_initial_range;

  GM2LM(optimizer_config::nonclk_coeff, &local_nonclk_coeff, sizeof(float));
  GM2LM(optimizer_config::clk_coeff, &local_clk_coeff, sizeof(float));
  GM2LM(optimizer_config::mf_create_thresholds, &local_mf_create_thresholds,
        sizeof(float));

117 118
  val.delta_score +=
      local_nonclk_coeff * (grad.show - grad.clk) + local_clk_coeff * grad.clk;
119

120
  update_lr(val.lr, val.lr_g2sum, grad.lr_g, grad.show);
121 122 123

  if (val.mf_size == 0) {
    if (local_mf_create_thresholds <=
124
        local_nonclk_coeff * (val.show - val.clk) + local_clk_coeff * val.clk) {
125 126 127 128
      val.mf_size = MF_DIM + 1;
      val.mf[0] = 0;

      for (int i = 0; i < MF_DIM; ++i) {
129
        val.mf[i + 1] = (xpu_rand_uniform()) * local_mf_initial_range;
130 131 132
      }
    }
  } else {
133
    update_mf(MF_DIM, &val.mf[1], val.mf[0], grad.mf_g, grad.show);
134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344
  }
}

// Stages (key, value) pairs from global memory into local buffers, chunk by
// chunk, for insertion into `table`.
//
//   table  destination container (currently unused: the insert call is
//          commented out below).
//   keys   `len` keys in global memory.
//   vals   `len` values in global memory, parallel to `keys`.
//   len    number of pairs.
//
// Work split: thread_id = core_num() * cluster_id() + core_id(). Each thread
// handles chunks of `len_per_loop` elements starting at
// thread_id * len_per_loop and advancing by nthreads * len_per_loop.
// `len_per_loop` is capped at buf_size, the capacity of the local buffers.
template <typename KeyType, typename ValType, typename Table>
__global__ void insert_kernel(Table* table, const KeyType* const keys,
                              const ValType* const vals, size_t len) {
  int cid = core_id();
  int ncores = core_num();
  if (cid >= ncores) {
    return;
  }
  int thread_id = ncores * cluster_id() + cid;
  int nthreads = ncores * cluster_num();

  const int buf_size = 150;
  __local__ KeyType local_keys[buf_size];
  __local__ ValType local_vals[buf_size];
  int len_per_loop = min(buf_size, roundup_div(len, nthreads));

  for (int i = thread_id * len_per_loop; i < len;
       i += nthreads * len_per_loop) {
    int read_len = min(len_per_loop, len - i);
    // Load this chunk, not the head of the arrays: the source pointers must
    // be offset by i (they previously always read from element 0).
    GM2LM(keys + i, local_keys, read_len * sizeof(KeyType));
    GM2LM(vals + i, local_vals, read_len * sizeof(ValType));
    for (int k = 0; k < read_len; k++) {
      // auto status = table->insert(local_keys[k], local_vals[k]);
      // assert(status != false && "error: insert fails: table is full");
    }
  }
}

// Stages keys from global memory chunk by chunk, looks them up in `table`,
// and writes one value per key back to `vals`.
//
//   table  container to search (currently unused: the lookup is commented
//          out below).
//   keys   `len` keys in global memory.
//   vals   output buffer for `len` values in global memory.
//   len    number of keys.
//
// Work split: thread_id = core_num() * cluster_id() + core_id(). Each thread
// handles chunks of `len_per_loop` elements starting at
// thread_id * len_per_loop and advancing by nthreads * len_per_loop.
//
// NOTE(review): while the lookup stays commented out, local_vals is never
// written before being copied to `vals`, so the output content is
// unspecified.
template <typename KeyType, typename ValType, typename Table>
__global__ void search_kernel(Table* table, const KeyType* const keys,
                              ValType* const vals, size_t len) {
  int cid = core_id();
  int ncores = core_num();
  if (cid >= ncores) {
    return;
  }
  int thread_id = ncores * cluster_id() + cid;
  int nthreads = ncores * cluster_num();

  const int buf_size = 150;
  __local__ KeyType local_keys[buf_size];
  __local__ ValType local_vals[buf_size];

  int len_per_loop = min(buf_size, roundup_div(len, nthreads));
  for (int i = thread_id * len_per_loop; i < len;
       i += nthreads * len_per_loop) {
    int read_len = min(len_per_loop, len - i);
    // Read the keys of this chunk; the source must be offset by i to match
    // the `vals + i` destination used for the write-back below.
    GM2LM(keys + i, local_keys, read_len * sizeof(KeyType));
    for (int k = 0; k < read_len; k++) {
      // ValType* val = table->find(local_keys[k]);
      // if (val != NULL) {
      //  local_vals[k] = *val;
      // }
    }
    LM2GM(local_vals, vals + i, read_len * sizeof(ValType));
  }
}

// Stages (key, gradient) pairs from global memory chunk by chunk so the
// matching table entries can be updated.
//
//   table  container holding the values to update (currently unused: the
//          lookup/update is commented out below).
//   keys   `len` keys in global memory.
//   grads  `len` gradient records in global memory, parallel to `keys`.
//   len    number of pairs.
//
// Work split: thread_id = core_num() * cluster_id() + core_id(). Each thread
// handles chunks of `len_per_loop` elements starting at
// thread_id * len_per_loop and advancing by nthreads * len_per_loop.
template <typename KeyType, typename ValType, typename Table, typename GradType>
__global__ void update_kernel(Table* table, const KeyType* const keys,
                              const GradType* const grads, size_t len) {
  int cid = core_id();
  int ncores = core_num();
  if (cid >= ncores) {
    return;
  }
  int thread_id = ncores * cluster_id() + cid;
  int nthreads = ncores * cluster_num();

  const int buf_size = 250;
  __local__ KeyType local_keys[buf_size];
  __local__ GradType local_grads[buf_size];

  int len_per_loop = min(buf_size, roundup_div(len, nthreads));
  for (int i = thread_id * len_per_loop; i < len;
       i += nthreads * len_per_loop) {
    int read_len = min(len_per_loop, len - i);

    // Load this chunk, not the head of the arrays: the source pointers must
    // be offset by i (they previously always read from element 0).
    GM2LM(keys + i, local_keys, read_len * sizeof(KeyType));
    GM2LM(grads + i, local_grads, read_len * sizeof(GradType));

    for (int k = 0; k < read_len; k++) {
      // ValType* val = table->find(local_keys[k]);
      // if (val != NULL) {
      //  update_value(*val, local_grads[k]);
      //}
    }
  }
}

// Creates the table: builds an XPUCacheArray of the requested capacity on the
// host, copies its object bytes into a newly allocated device buffer pointed
// to by container_, and installs a fresh read/write lock.
// NOTE(review): the results of xpu_malloc / xpu_memcpy are not inspected.
template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::HashTable(size_t capacity) {
  using Container = XPUCacheArray<KeyType, ValType>;

  auto host_image = Container(capacity);

  void** device_slot = reinterpret_cast<void**>(&container_);
  xpu_malloc(device_slot, sizeof(Container));
  xpu_memcpy(container_, &host_image, sizeof(Container), XPU_HOST_TO_DEVICE);

  rwlock_.reset(new phi::RWLock);
}

// Releases the buffer that the constructor obtained with xpu_malloc.
template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::~HashTable() {
  void* raw_container = static_cast<void*>(container_);
  xpu_free(raw_container);
}

// Forwards to the container's print() routine.
// NOTE(review): container_ is filled by xpu_malloc in the constructor;
// confirm it is valid to dereference from the caller's context.
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::show() {
  (*container_).print();
}

// Looks up `len` keys and writes the matching values to `d_vals` by launching
// search_kernel on `stream`. Does nothing when `len` is zero.
// NOTE(review): the launch arguments 4 and 64 are presumably the cluster and
// core counts; confirm against the XPU runtime documentation.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys, ValType* d_vals,
                                      size_t len, StreamType stream) {
  const bool has_work = (len != 0);
  if (has_work) {
    search_kernel<<<4, 64, stream>>>(container_, d_keys, d_vals, len);
  }
}

// Byte-buffer flavour of get(): placeholder that performs no work for any
// input.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys, char* d_vals,
                                      size_t len, StreamType stream) {
  if (len != 0) {
    // TODO(zhangminxu): to be implemented
  }
}

// Inserts `len` (key, value) pairs by launching insert_kernel on `stream`.
// Does nothing when `len` is zero.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
                                         const ValType* d_vals, size_t len,
                                         StreamType stream) {
  const bool has_work = (len != 0);
  if (has_work) {
    insert_kernel<<<4, 64, stream>>>(container_, d_keys, d_vals, len);
  }
}

// Placeholder: the body is empty, so calling this has no effect and both
// `devid` and `stream` are ignored.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::dump_to_cpu(int devid, StreamType stream) {
  // TODO(zhangminxu): to be implemented
}

// Applies `len` gradient records to the table by launching update_kernel on
// `stream`. Does nothing when `len` is zero.
template <typename KeyType, typename ValType>
template <typename GradType, typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                         const GradType* d_grads, size_t len,
                                         StreamType stream) {
  const bool has_work = (len != 0);
  if (has_work) {
    update_kernel<<<4, 64, stream>>>(container_, d_keys, d_grads, len);
  }
}

// Byte-buffer flavour of update(): placeholder that performs no work for any
// input.
template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                         const char* d_grads, size_t len,
                                         StreamType stream) {
  if (len != 0) {
    // TODO(zhangminxu): to be implemented
  }
}

// Explicit instantiations for the one key/value combination used here:
// unsigned long keys with FeatureValue payloads. The member templates are
// instantiated for XPUStream only.
template class HashTable<unsigned long, paddle::framework::FeatureValue>;

// get() returning FeatureValue records.
template void HashTable<unsigned long, paddle::framework::FeatureValue>::get<
    XPUStream>(const unsigned long* d_keys,
               paddle::framework::FeatureValue* d_vals, size_t len,
               XPUStream stream);

// The char* overload of get() is not instantiated (left disabled below).
// template void
// HashTable<unsigned long, paddle::framework::FeatureValue>::get<XPUStream>(
//    const unsigned long* d_keys, char* d_vals, size_t len, XPUStream stream);

// insert() taking FeatureValue records.
template void HashTable<unsigned long, paddle::framework::FeatureValue>::insert<
    XPUStream>(const unsigned long* d_keys,
               const paddle::framework::FeatureValue* d_vals, size_t len,
               XPUStream stream);

// A pool-based insert() signature is left disabled below.
// template void HashTable<unsigned long,
// paddle::framework::FeatureValue>::insert<
//    XPUStream>(const unsigned long* d_keys, size_t len, char* pool,
//               size_t start_index, XPUStream stream);

template void HashTable<unsigned long, paddle::framework::FeatureValue>::
    dump_to_cpu<XPUStream>(int devid, XPUStream stream);

// update() taking FeaturePushValue gradient records.
template void HashTable<unsigned long, paddle::framework::FeatureValue>::update<
    paddle::framework::FeaturePushValue, XPUStream>(
    const unsigned long* d_keys,
    const paddle::framework::FeaturePushValue* d_grads, size_t len,
    XPUStream stream);

// The char* overload of update() is not instantiated (left disabled below).
// template void HashTable<unsigned long,
// paddle::framework::FeatureValue>::update<
//    XPUStream>(const unsigned long* d_keys, const char* d_grads,
//                          size_t len, XPUStream stream);

#endif
}  // end namespace framework
}  // end namespace paddle
#endif