ps_gpu_wrapper.cu
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_HETERPS
#include <algorithm>
#include <ctime>
#include <memory>
#include <numeric>

#include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h"
#include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"

namespace paddle {
namespace framework {

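// Scatter pulled feature values from the flat per-key array `src` into the
// per-slot output tensors `dest`. Each thread handles one key: it binary
// searches the prefix-sum array `len` for its slot x and in-slot offset y,
// then writes [show, clk, lr, mf...] into row y of dest[x]. Rows whose key is
// 0 are zero-filled, and rows whose value carries no embedding (mf_size == 0)
// get a zero embedding.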
__global__ void PullCopy(float** dest,
                         const FeatureValue* src,
                         const int64_t* len,
                         int hidden,
                         int slot_num,
                         int total_len,
                         uint64_t** keys) {
  CUDA_KERNEL_LOOP(i, total_len) {
    int low = 0;
    int high = slot_num - 1;
    while (low < high) {
      int mid = (low + high) / 2;
      if (i < len[mid])
        high = mid;
      else
        low = mid + 1;
    }
    int x = low;
    int y = i - (x ? len[x - 1] : 0);
    if (*(keys[x] + y) == 0) {
      *(dest[x] + y * hidden) = 0;
      *(dest[x] + y * hidden + 1) = 0;
      *(dest[x] + y * hidden + 2) = 0;
    } else {
      *(dest[x] + y * hidden) = (src + i)->show;
      *(dest[x] + y * hidden + 1) = (src + i)->clk;
      *(dest[x] + y * hidden + 2) = (src + i)->lr;
    }
    if ((src + i)->mf_size == 0 || *(keys[x] + y) == 0) {
      for (int j = 0; j < hidden - 3; j++) {
        *(dest[x] + y * hidden + 3 + j) = 0;
      }
    } else {
      for (int j = 0; j < hidden - 3; j++) {
        *(dest[x] + y * hidden + 3 + j) = (src + i)->mf[1 + j];
      }
    }
  }
}

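// Variable-dimension variant of PullCopy: `src` holds packed records of
// `max_val_size` bytes each, and the embedding width of slot x is
// gpu_dim[x] - 3 (the three leading floats are show, clk and lr).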
__global__ void PullCopy(float** dest,
                         const FeatureValue* src,
                         const int64_t* len,
                         int slot_num,
                         int total_len,
                         uint64_t** keys,
                         uint64_t max_val_size,
                         int* gpu_dim) {
  CUDA_KERNEL_LOOP(i, total_len) {
    int low = 0;
    int high = slot_num - 1;
    while (low < high) {
      int mid = (low + high) / 2;
      if (i < len[mid])
        high = mid;
      else
        low = mid + 1;
    }
    int x = low;
    int y = i - (x ? len[x - 1] : 0);
    FeatureValue* feature_value_ptr =
        (FeatureValue*)((char*)src + uint64_t(i) * uint64_t(max_val_size));
    int mf_dim = gpu_dim[x] - 3;
    if (*(keys[x] + y) == 0) {
      *(dest[x] + y * (mf_dim + 3)) = 0;
      *(dest[x] + y * (mf_dim + 3) + 1) = 0;
      *(dest[x] + y * (mf_dim + 3) + 2) = 0;
    } else {
      *(dest[x] + y * (mf_dim + 3)) = feature_value_ptr->show;
      *(dest[x] + y * (mf_dim + 3) + 1) = feature_value_ptr->clk;
      *(dest[x] + y * (mf_dim + 3) + 2) = feature_value_ptr->lr;
    }
    if ((feature_value_ptr)->mf_size == 0 || *(keys[x] + y) == 0) {
      for (int j = 0; j < mf_dim; j++) {
        *(dest[x] + y * (mf_dim + 3) + 3 + j) = 0;
      }
    } else {
      for (int j = 0; j < mf_dim; j++) {
        *(dest[x] + y * (mf_dim + 3) + 3 + j) = feature_value_ptr->mf[1 + j];
      }
    }
  }
}

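// Flatten the per-slot key arrays `src_keys` into one contiguous buffer,
// using the same prefix-sum binary search to map a global index to a
// (slot, offset) pair.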
__global__ void CopyKeysKernel(uint64_t** src_keys,
                               uint64_t* dest_total_keys,
                               const int64_t* len,
                               int slot_num,
                               int total_len) {
  CUDA_KERNEL_LOOP(i, total_len) {
    int low = 0;
    int high = slot_num - 1;
    while (low < high) {
      int mid = (low + high) / 2;
      if (i < len[mid])
        high = mid;
      else
        low = mid + 1;
    }
    int x = low;
    int y = i - (x ? len[x - 1] : 0);
    dest_total_keys[i] = src_keys[x][y];
  }
}

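// Gather per-slot gradients from `src` into the flat FeaturePushValue array
// `dest` consumed by the sparse-table update. Gradients are negated and
// scaled by the batch size `bs`; slot ids come from `slot_vector`.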
__global__ void PushCopy(FeaturePushValue* dest,
                         float** src,
                         int64_t* len,
                         int hidden,
                         int slot_num,
                         int total_len,
                         int bs,
                         int* slot_vector) {
  CUDA_KERNEL_LOOP(i, total_len) {
    int low = 0;
    int high = slot_num - 1;
    while (low < high) {
      int mid = (low + high) / 2;
      if (i < len[mid])
        high = mid;
      else
        low = mid + 1;
    }
    int x = low;
    int y = i - (x ? len[low - 1] : 0);
    (dest + i)->slot = slot_vector[x];
    (dest + i)->show = *(src[x] + y * hidden);
    (dest + i)->clk = *(src[x] + y * hidden + 1);
    (dest + i)->lr_g = *(src[x] + y * hidden + 2) * -1. * bs;
    for (int j = 0; j < hidden - 3; j++) {
      (dest + i)->mf_g[j] = *(src[x] + y * hidden + 3 + j) * -1. * bs;
    }
  }
}

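// Variable-dimension variant of PushCopy: each output record occupies
// `grad_value_size` bytes and the embedding width of slot x is
// mf_dim_vector[x].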
__global__ void PushCopyWithPool(FeaturePushValue* dest,
                                 float** src,
                                 int64_t* len,
                                 int slot_num,
                                 uint64_t total_len,
                                 int bs,
                                 int* slot_vector,
                                 int* mf_dim_vector,
                                 size_t grad_value_size) {
  CUDA_KERNEL_LOOP(i, total_len) {
    int low = 0;
    int high = slot_num - 1;
    while (low < high) {
      int mid = (low + high) / 2;
      if (i < len[mid])
        high = mid;
      else
        low = mid + 1;
    }
    int x = low;
    int y = i - (x ? len[low - 1] : 0);
    FeaturePushValue* cur =
        (FeaturePushValue*)((char*)dest + i * grad_value_size);
    cur->slot = slot_vector[x];
    int mf_dim = mf_dim_vector[x];
    cur->mf_dim = mf_dim;
    cur->show = *(src[x] + y * (mf_dim + 3));
    cur->clk = *(src[x] + y * (mf_dim + 3) + 1);
    cur->lr_g = *(src[x] + y * (mf_dim + 3) + 2) * -1. * bs;
    for (int j = 0; j < cur->mf_dim; j++) {
      cur->mf_g[j] = *(src[x] + y * (mf_dim + 3) + 3 + j) * -1. * bs;
    }
  }
}
PSGPUWrapper::~PSGPUWrapper() { delete HeterPs_; }

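// Copy pulled values back into the per-slot output tensors (fixed hidden
// size): stage the destination pointers on the device, launch PullCopy with
// one thread per pulled key, then synchronize the stream.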
void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place,
                               uint64_t** gpu_keys,
                               const std::vector<float*>& values,
                               const FeatureValue* total_values_gpu,
                               const int64_t* gpu_len,
                               const int slot_num,
                               const int hidden_size,
                               const int64_t total_length) {
  auto stream = dynamic_cast<platform::CUDADeviceContext*>(
                    platform::DeviceContextPool::Instance().Get(place))
                    ->stream();
  auto buf_value = memory::Alloc(place, values.size() * sizeof(float*));
  float** gpu_values = reinterpret_cast<float**>(buf_value->ptr());
  cudaMemcpy(gpu_values,
             values.data(),
             values.size() * sizeof(float*),
             cudaMemcpyHostToDevice);

  PullCopy<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
      gpu_values,
      total_values_gpu,
      gpu_len,
      hidden_size,
      slot_num,
      total_length,
      gpu_keys);
  cudaStreamSynchronize(stream);
}

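// CopyForPull for variable embedding dimensions: also passes the packed value
// size (val_type_size_) and the per-slot dimensions (gpu_dim) to the kernel.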
void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place,
                               uint64_t** gpu_keys,
                               const std::vector<float*>& values,
                               const FeatureValue* total_values_gpu,
                               const int64_t* gpu_len,
                               const int slot_num,
                               const int hidden_size,
                               const int64_t total_length,
                               int* gpu_dim) {
  auto stream = dynamic_cast<platform::CUDADeviceContext*>(
                    platform::DeviceContextPool::Instance().Get(place))
                    ->stream();
  auto buf_value = memory::Alloc(place, values.size() * sizeof(float*));
  float** gpu_values = reinterpret_cast<float**>(buf_value->ptr());
  cudaMemcpy(gpu_values,
             values.data(),
             values.size() * sizeof(float*),
             cudaMemcpyHostToDevice);
  PullCopy<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
      gpu_values,
      total_values_gpu,
      gpu_len,
      slot_num,
      total_length,
      gpu_keys,
      val_type_size_,
      gpu_dim);
  cudaStreamSynchronize(stream);
}

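// Flatten the per-slot feature keys into one contiguous device buffer.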
void PSGPUWrapper::CopyKeys(const paddle::platform::Place& place,
                            uint64_t** origin_keys,
                            uint64_t* total_keys,
                            const int64_t* gpu_len,
                            int slot_num,
                            int total_len) {
  auto stream = dynamic_cast<platform::CUDADeviceContext*>(
                    platform::DeviceContextPool::Instance().Get(place))
                    ->stream();
  CopyKeysKernel<<<(total_len + 1024 - 1) / 1024, 1024, 0, stream>>>(
      origin_keys, total_keys, gpu_len, slot_num, total_len);
  cudaStreamSynchronize(stream);
}

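// Pack per-slot gradients into the flat FeaturePushValue buffer expected by
// the sparse table (fixed hidden size): build the prefix sum of slot lengths
// on the host, stage pointers, lengths and slot ids on the device, then
// launch PushCopy.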
void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place,
                               const std::vector<const float*>& grad_values,
                               FeaturePushValue* total_grad_values_gpu,
                               const std::vector<int64_t>& slot_lengths,
                               const int hidden_size,
                               const int64_t total_length,
                               const int batch_size) {
  auto stream = dynamic_cast<platform::CUDADeviceContext*>(
                    platform::DeviceContextPool::Instance().Get(place))
                    ->stream();
  auto slot_lengths_lod = slot_lengths;
  for (int i = 1; i < slot_lengths_lod.size(); i++) {
    slot_lengths_lod[i] += slot_lengths_lod[i - 1];
  }
  auto buf_grad_value =
      memory::Alloc(place, grad_values.size() * sizeof(float*));
  auto buf_length = memory::Alloc(place, slot_lengths.size() * sizeof(int64_t));
  auto buf_slot_vector =
      memory::Alloc(place, slot_lengths_lod.size() * sizeof(int));

  float** gpu_values = reinterpret_cast<float**>(buf_grad_value->ptr());
  int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
  int* d_slot_vector = reinterpret_cast<int*>(buf_slot_vector->ptr());

  cudaMemcpy(gpu_values,
             grad_values.data(),
             grad_values.size() * sizeof(float*),
             cudaMemcpyHostToDevice);
  cudaMemcpy(gpu_len,
             slot_lengths_lod.data(),
             slot_lengths.size() * sizeof(int64_t),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_slot_vector,
             slot_vector_.data(),
             slot_lengths_lod.size() * sizeof(int),
             cudaMemcpyHostToDevice);

  PushCopy<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
      total_grad_values_gpu,
      gpu_values,
      gpu_len,
      hidden_size,
      slot_lengths.size(),
      total_length,
      batch_size,
      d_slot_vector);
  cudaStreamSynchronize(stream);
}

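// CopyForPush for variable embedding dimensions: additionally stages the
// per-slot embedding dims and launches PushCopyWithPool with the per-record
// byte size grad_value_size.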
void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place,
                               const std::vector<const float*>& grad_values,
                               FeaturePushValue* total_grad_values_gpu,
                               const std::vector<int64_t>& slot_lengths,
                               const uint64_t total_length,
                               const int batch_size,
                               size_t grad_value_size) {
  auto stream = dynamic_cast<platform::CUDADeviceContext*>(
                    platform::DeviceContextPool::Instance().Get(place))
                    ->stream();
  auto slot_lengths_lod = slot_lengths;
  for (int i = 1; i < slot_lengths_lod.size(); i++) {
    slot_lengths_lod[i] += slot_lengths_lod[i - 1];
  }
  auto buf_grad_value =
      memory::Alloc(place, grad_values.size() * sizeof(float*));
  auto buf_length = memory::Alloc(place, slot_lengths.size() * sizeof(int64_t));
  auto buf_slot_vector =
      memory::Alloc(place, slot_lengths_lod.size() * sizeof(int));
  auto buf_mf_dim_vector =
      memory::Alloc(place, slot_lengths_lod.size() * sizeof(int));
  float** gpu_values = reinterpret_cast<float**>(buf_grad_value->ptr());
  int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
  int* d_slot_vector = reinterpret_cast<int*>(buf_slot_vector->ptr());
  int* d_mf_dim_vector = reinterpret_cast<int*>(buf_mf_dim_vector->ptr());
  cudaMemcpy(gpu_values,
             grad_values.data(),
             grad_values.size() * sizeof(float*),
             cudaMemcpyHostToDevice);
  cudaMemcpy(gpu_len,
             slot_lengths_lod.data(),
             slot_lengths.size() * sizeof(int64_t),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_slot_vector,
             slot_vector_.data(),
             slot_lengths_lod.size() * sizeof(int),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_mf_dim_vector,
             slot_mf_dim_vector_.data(),
             slot_lengths_lod.size() * sizeof(int),
             cudaMemcpyHostToDevice);
  PushCopyWithPool<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
      total_grad_values_gpu,
      gpu_values,
      gpu_len,
      slot_lengths.size(),
      total_length,
      batch_size,
      d_slot_vector,
      d_mf_dim_vector,
      grad_value_size);
  cudaStreamSynchronize(stream);
}

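// Forward the sparse-SGD hyper-parameters to the heterogeneous PS via an
// OptimizerConfig.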
void PSGPUWrapper::SetSparseSGD(float nonclk_coeff,
                                float clk_coeff,
                                float min_bound,
                                float max_bound,
                                float learning_rate,
                                float initial_g2sum,
                                float initial_range) {
  OptimizerConfig optimizer_config;
  optimizer_config.set_sparse_sgd(nonclk_coeff,
                                  clk_coeff,
                                  min_bound,
                                  max_bound,
                                  learning_rate,
                                  initial_g2sum,
                                  initial_range);
  HeterPs_->set_sparse_sgd(optimizer_config);
}

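// Forward the embedx (mf) SGD hyper-parameters to the heterogeneous PS via an
// OptimizerConfig.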
void PSGPUWrapper::SetEmbedxSGD(float mf_create_thresholds,
                                float mf_learning_rate,
                                float mf_initial_g2sum,
                                float mf_initial_range,
                                float mf_min_bound,
                                float mf_max_bound) {
  OptimizerConfig optimizer_config;
  optimizer_config.set_embedx_sgd(mf_create_thresholds,
                                  mf_learning_rate,
                                  mf_initial_g2sum,
                                  mf_initial_range,
                                  mf_min_bound,
                                  mf_max_bound);
  HeterPs_->set_embedx_sgd(optimizer_config);
}

}  // end namespace framework
}  // end namespace paddle
#endif