ps_gpu_wrapper.cu 13.4 KB
Newer Older
T
Thunderbrook 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

T
Thunderbrook 已提交
15
#ifdef PADDLE_WITH_HETERPS
T
Thunderbrook 已提交
16 17 18 19
#include <algorithm>
#include <ctime>
#include <memory>
#include <numeric>
Y
yaoxuefeng 已提交
20
#include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h"
T
Thunderbrook 已提交
21 22
#include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h"
#include "paddle/fluid/framework/lod_tensor.h"
23
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
T
Thunderbrook 已提交
24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52

namespace paddle {
namespace framework {

// Scatters pulled feature values from the flat `src` array into the per-slot
// output buffers referenced by `dest`.
//
// dest      : table of per-slot output pointers; each row is `hidden` floats
//             laid out as [show, clk, lr, mf...].
// src       : flat array of FeatureValue, one entry per global index i.
// len       : treated as cumulative end offsets per slot (len[s] is one past
//             the last global index that belongs to slot s).
// hidden    : number of floats written per output row.
// slot_num  : number of slots (entries in `len`, `dest` and `keys`).
// total_len : number of global indices to process.
// keys      : table of per-slot key arrays; a key of 0 produces an all-zero
//             output row.
__global__ void PullCopy(float** dest, const FeatureValue* src,
                         const int64_t* len, int hidden, int slot_num,
                         int total_len, uint64_t** keys) {
  CUDA_KERNEL_LOOP(i, total_len) {
    // Binary search for the first slot whose cumulative end offset exceeds i.
    int lo = 0;
    int hi = slot_num - 1;
    while (lo < hi) {
      const int mid = (lo + hi) / 2;
      if (i < len[mid]) {
        hi = mid;
      } else {
        lo = mid + 1;
      }
    }
    const int slot = lo;
    // Position of this element inside its slot.
    const int offset = i - (slot ? len[slot - 1] : 0);

    const FeatureValue* value = src + i;
    float* out = dest[slot] + offset * hidden;
    const bool empty_key = (*(keys[slot] + offset) == 0);

    // Leading three statistics: show, clk, lr.
    if (empty_key) {
      out[0] = 0;
      out[1] = 0;
      out[2] = 0;
    } else {
      out[0] = value->show;
      out[1] = value->clk;
      out[2] = value->lr;
    }

    // Remaining hidden - 3 floats: zeros when the key is 0 or the value has
    // no mf data, otherwise mf[1..] (mf[0] is skipped).
    const int mf_len = hidden - 3;
    if (empty_key || value->mf_size == 0) {
      for (int j = 0; j < mf_len; ++j) {
        out[3 + j] = 0;
      }
    } else {
      for (int j = 0; j < mf_len; ++j) {
        out[3 + j] = value->mf[1 + j];
      }
    }
  }
}

Y
yaoxuefeng 已提交
64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102
// Variant of PullCopy for per-slot embedding widths and byte-strided values.
//
// dest         : table of per-slot output pointers; a row of slot s holds
//                gpu_dim[s] floats laid out as [show, clk, lr, mf...].
// src          : base address of the packed feature values; element i starts
//                at byte offset i * max_val_size.
// len          : treated as cumulative end offsets per slot.
// slot_num     : number of slots.
// total_len    : number of global indices to process.
// keys         : table of per-slot key arrays; a key of 0 produces an
//                all-zero output row.
// max_val_size : byte stride between consecutive values in `src`.
// gpu_dim      : per-slot output row width (mf dimension + 3).
__global__ void PullCopy(float** dest, const FeatureValue* src,
                         const int64_t* len, int slot_num, int total_len,
                         uint64_t** keys, uint64_t max_val_size, int* gpu_dim) {
  CUDA_KERNEL_LOOP(i, total_len) {
    // Binary search for the first slot whose cumulative end offset exceeds i.
    int lo = 0;
    int hi = slot_num - 1;
    while (lo < hi) {
      const int mid = (lo + hi) / 2;
      if (i < len[mid]) {
        hi = mid;
      } else {
        lo = mid + 1;
      }
    }
    const int slot = lo;
    // Position of this element inside its slot.
    const int offset = i - (slot ? len[slot - 1] : 0);

    // Values are addressed by byte stride, not by sizeof(FeatureValue).
    const FeatureValue* value = reinterpret_cast<const FeatureValue*>(
        reinterpret_cast<const char*>(src) +
        uint64_t(i) * uint64_t(max_val_size));

    const int mf_dim = gpu_dim[slot] - 3;
    const int row_width = mf_dim + 3;
    float* out = dest[slot] + offset * row_width;
    const bool empty_key = (*(keys[slot] + offset) == 0);

    // Leading three statistics: show, clk, lr.
    if (empty_key) {
      out[0] = 0;
      out[1] = 0;
      out[2] = 0;
    } else {
      out[0] = value->show;
      out[1] = value->clk;
      out[2] = value->lr;
    }

    // Embedding part: zeros when the key is 0 or the value has no mf data,
    // otherwise mf[1..mf_dim] (mf[0] is skipped).
    if (empty_key || value->mf_size == 0) {
      for (int j = 0; j < mf_dim; ++j) {
        out[3 + j] = 0;
      }
    } else {
      for (int j = 0; j < mf_dim; ++j) {
        out[3 + j] = value->mf[1 + j];
      }
    }
  }
}

T
Thunderbrook 已提交
103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
// Flattens the per-slot key arrays in `src_keys` into one contiguous array.
//
// src_keys        : table of per-slot key arrays.
// dest_total_keys : output array with room for total_len keys.
// len             : treated as cumulative end offsets per slot.
// slot_num        : number of slots.
// total_len       : total number of keys across all slots.
__global__ void CopyKeysKernel(uint64_t** src_keys, uint64_t* dest_total_keys,
                               const int64_t* len, int slot_num,
                               int total_len) {
  CUDA_KERNEL_LOOP(i, total_len) {
    // Binary search for the first slot whose cumulative end offset exceeds i.
    int lo = 0;
    int hi = slot_num - 1;
    while (lo < hi) {
      const int mid = (lo + hi) / 2;
      if (i < len[mid]) {
        hi = mid;
      } else {
        lo = mid + 1;
      }
    }
    const int slot = lo;
    // Position of this key inside its slot.
    const int offset = i - (slot ? len[slot - 1] : 0);
    dest_total_keys[i] = *(src_keys[slot] + offset);
  }
}

// Gathers per-slot gradient rows from `src` into the flat push buffer `dest`.
//
// dest        : flat array of FeaturePushValue, one entry per global index i.
// src         : table of per-slot gradient pointers; each row is `hidden`
//               floats laid out as [show, clk, lr_g, mf_g...].
// len         : cumulative end offsets per slot.
// hidden      : number of floats per input row.
// slot_num    : number of slots.
// total_len   : number of global indices to process.
// bs          : batch size; lr and mf gradients are multiplied by -bs.
// slot_vector : per-slot id written into FeaturePushValue::slot.
__global__ void PushCopy(FeaturePushValue* dest, float** src, int64_t* len,
                         int hidden, int slot_num, int total_len, int bs,
                         int* slot_vector) {
  // Single-precision scale factor. The previous `* -1. * bs` used a double
  // literal, which promoted every multiply in the kernel to double.
  const float grad_scale = -1.0f * bs;
  CUDA_KERNEL_LOOP(i, total_len) {
    // Binary search for the first slot whose cumulative end offset exceeds i.
    int low = 0;
    int high = slot_num - 1;
    while (low < high) {
      int mid = (low + high) / 2;
      if (i < len[mid])
        high = mid;
      else
        low = mid + 1;
    }
    const int x = low;
    // Position of this element inside its slot.
    const int y = i - (x ? len[x - 1] : 0);

    const float* grad_row = src[x] + y * hidden;
    FeaturePushValue* cur = dest + i;
    cur->slot = slot_vector[x];
    cur->show = grad_row[0];
    cur->clk = grad_row[1];
    cur->lr_g = grad_row[2] * grad_scale;
    for (int j = 0; j < hidden - 3; j++) {
      cur->mf_g[j] = grad_row[3 + j] * grad_scale;
    }
  }
}

Y
yaoxuefeng 已提交
147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175
// Gathers per-slot gradient rows from `src` into a byte-strided push buffer,
// supporting a different embedding width per slot.
//
// dest            : base address of the push buffer; element i starts at
//                   byte offset i * grad_value_size.
// src             : table of per-slot gradient pointers; a row of slot s
//                   holds mf_dim_vector[s] + 3 floats laid out as
//                   [show, clk, lr_g, mf_g...].
// len             : cumulative end offsets per slot.
// slot_num        : number of slots.
// total_len       : number of global indices to process.
// bs              : batch size; lr and mf gradients are multiplied by -bs.
// slot_vector     : per-slot id written into FeaturePushValue::slot.
// mf_dim_vector   : per-slot embedding width.
// grad_value_size : byte stride between consecutive push values in `dest`.
__global__ void PushCopyWithPool(FeaturePushValue* dest, float** src,
                                 int64_t* len, int slot_num, uint64_t total_len,
                                 int bs, int* slot_vector, int* mf_dim_vector,
                                 size_t grad_value_size) {
  // Single-precision scale factor. The previous `* -1. * bs` used a double
  // literal, which promoted every multiply in the kernel to double.
  const float grad_scale = -1.0f * bs;
  CUDA_KERNEL_LOOP(i, total_len) {
    // Binary search for the first slot whose cumulative end offset exceeds i.
    int low = 0;
    int high = slot_num - 1;
    while (low < high) {
      int mid = (low + high) / 2;
      if (i < len[mid])
        high = mid;
      else
        low = mid + 1;
    }
    const int x = low;
    // Position of this element inside its slot.
    const int y = i - (x ? len[x - 1] : 0);

    // Push values are addressed by byte stride, not sizeof(FeaturePushValue).
    FeaturePushValue* cur =
        (FeaturePushValue*)((char*)dest + i * grad_value_size);
    const int mf_dim = mf_dim_vector[x];
    const float* grad_row = src[x] + y * (mf_dim + 3);

    cur->slot = slot_vector[x];
    cur->mf_dim = mf_dim;
    cur->show = grad_row[0];
    cur->clk = grad_row[1];
    cur->lr_g = grad_row[2] * grad_scale;
    // Loop on the local mf_dim rather than re-reading cur->mf_dim from memory
    // on every iteration.
    for (int j = 0; j < mf_dim; j++) {
      cur->mf_g[j] = grad_row[3 + j] * grad_scale;
    }
  }
}
F
Fan Zhang 已提交
176 177
// Destroys the wrapper and deletes the object pointed to by HeterPs_.
PSGPUWrapper::~PSGPUWrapper() {
  delete HeterPs_;
}

T
Thunderbrook 已提交
178 179 180 181 182 183 184 185
// Copies pulled feature values into the per-slot output buffers using the
// fixed-width PullCopy kernel, then waits for the stream to finish.
//
// place            : device whose CUDA stream and allocator are used.
// gpu_keys         : table of per-slot key arrays, passed to the kernel.
// values           : per-slot output pointers; the table is copied to the
//                    device and the pointers are dereferenced by the kernel,
//                    so they must be device-accessible.
// total_values_gpu : flat array of pulled FeatureValue entries.
// gpu_len          : cumulative slot end offsets, read by the kernel.
// slot_num         : number of slots.
// hidden_size      : number of floats per output row.
// total_length     : total number of feature values to copy.
void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place,
                               uint64_t** gpu_keys,
                               const std::vector<float*>& values,
                               const FeatureValue* total_values_gpu,
                               const int64_t* gpu_len, const int slot_num,
                               const int hidden_size,
                               const int64_t total_length) {
  auto stream = dynamic_cast<platform::CUDADeviceContext*>(
                    platform::DeviceContextPool::Instance().Get(place))
                    ->stream();
  // Stage the host-side table of output pointers on the device so the kernel
  // can index it.
  auto buf_value = memory::Alloc(place, values.size() * sizeof(float*));
  float** gpu_values = reinterpret_cast<float**>(buf_value->ptr());
  cudaMemcpy(gpu_values, values.data(), values.size() * sizeof(float*),
             cudaMemcpyHostToDevice);

  // A grid with zero blocks is an invalid launch configuration, so skip the
  // launch when there is nothing to copy.
  if (total_length > 0) {
    PullCopy<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
        gpu_values, total_values_gpu, gpu_len, hidden_size, slot_num,
        total_length, gpu_keys);
  }
  cudaStreamSynchronize(stream);
}

Y
yaoxuefeng 已提交
199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218
// Copies pulled feature values into the per-slot output buffers using the
// PullCopy kernel variant that takes per-slot row widths, then waits for the
// stream to finish.
//
// place            : device whose CUDA stream and allocator are used.
// gpu_keys         : table of per-slot key arrays, passed to the kernel.
// values           : per-slot output pointers; the table is copied to the
//                    device and the pointers are dereferenced by the kernel,
//                    so they must be device-accessible.
// total_values_gpu : base address of the packed pulled values; the kernel
//                    strides through it by val_type_size_ bytes.
// gpu_len          : cumulative slot end offsets, read by the kernel.
// slot_num         : number of slots.
// hidden_size      : not used by this overload.
// total_length     : total number of feature values to copy.
// gpu_dim          : per-slot output row width, read by the kernel.
void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place,
                               uint64_t** gpu_keys,
                               const std::vector<float*>& values,
                               const FeatureValue* total_values_gpu,
                               const int64_t* gpu_len, const int slot_num,
                               const int hidden_size,
                               const int64_t total_length, int* gpu_dim) {
  auto stream = dynamic_cast<platform::CUDADeviceContext*>(
                    platform::DeviceContextPool::Instance().Get(place))
                    ->stream();
  // Stage the host-side table of output pointers on the device so the kernel
  // can index it.
  auto buf_value = memory::Alloc(place, values.size() * sizeof(float*));
  float** gpu_values = reinterpret_cast<float**>(buf_value->ptr());
  cudaMemcpy(gpu_values, values.data(), values.size() * sizeof(float*),
             cudaMemcpyHostToDevice);

  // A grid with zero blocks is an invalid launch configuration, so skip the
  // launch when there is nothing to copy.
  if (total_length > 0) {
    PullCopy<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
        gpu_values, total_values_gpu, gpu_len, slot_num, total_length, gpu_keys,
        val_type_size_, gpu_dim);
  }
  cudaStreamSynchronize(stream);
}

T
Thunderbrook 已提交
219 220 221 222 223
// Flattens the per-slot key arrays into `total_keys` with CopyKeysKernel and
// waits for the stream to finish.
//
// place       : device whose CUDA stream is used.
// origin_keys : table of per-slot key arrays, read by the kernel.
// total_keys  : output array receiving total_len keys.
// gpu_len     : cumulative slot end offsets, read by the kernel.
// slot_num    : number of slots.
// total_len   : total number of keys across all slots.
void PSGPUWrapper::CopyKeys(const paddle::platform::Place& place,
                            uint64_t** origin_keys, uint64_t* total_keys,
                            const int64_t* gpu_len, int slot_num,
                            int total_len) {
  auto stream = dynamic_cast<platform::CUDADeviceContext*>(
                    platform::DeviceContextPool::Instance().Get(place))
                    ->stream();
  // A grid with zero blocks is an invalid launch configuration, so skip the
  // launch when there are no keys to copy.
  if (total_len > 0) {
    CopyKeysKernel<<<(total_len + 1024 - 1) / 1024, 1024, 0, stream>>>(
        origin_keys, total_keys, gpu_len, slot_num, total_len);
  }
  cudaStreamSynchronize(stream);
}

// Packs per-slot gradient buffers into the flat push buffer with the
// fixed-width PushCopy kernel, then waits for the stream to finish.
//
// place                 : device whose CUDA stream and allocator are used.
// grad_values           : per-slot gradient pointers; the table is copied to
//                         the device and the pointers are dereferenced by the
//                         kernel, so they must be device-accessible.
// total_grad_values_gpu : output array of FeaturePushValue entries.
// slot_lengths          : number of elements in each slot.
// hidden_size           : number of floats per gradient row.
// total_length          : total number of elements across all slots.
// batch_size            : forwarded to the kernel as the gradient scale.
void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place,
                               const std::vector<const float*>& grad_values,
                               FeaturePushValue* total_grad_values_gpu,
                               const std::vector<int64_t>& slot_lengths,
                               const int hidden_size,
                               const int64_t total_length,
                               const int batch_size) {
  auto stream = dynamic_cast<platform::CUDADeviceContext*>(
                    platform::DeviceContextPool::Instance().Get(place))
                    ->stream();
  const size_t slot_num = slot_lengths.size();

  // Inclusive prefix sum of the slot lengths: entry s is the end offset of
  // slot s in the flattened index space.
  std::vector<int64_t> slot_lengths_lod(slot_num);
  std::partial_sum(slot_lengths.begin(), slot_lengths.end(),
                   slot_lengths_lod.begin());

  auto buf_grad_value =
      memory::Alloc(place, grad_values.size() * sizeof(float*));
  auto buf_length = memory::Alloc(place, slot_num * sizeof(int64_t));
  auto buf_slot_vector = memory::Alloc(place, slot_num * sizeof(int));

  float** gpu_values = reinterpret_cast<float**>(buf_grad_value->ptr());
  int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
  int* d_slot_vector = reinterpret_cast<int*>(buf_slot_vector->ptr());

  cudaMemcpy(gpu_values, grad_values.data(),
             grad_values.size() * sizeof(float*), cudaMemcpyHostToDevice);
  cudaMemcpy(gpu_len, slot_lengths_lod.data(), slot_num * sizeof(int64_t),
             cudaMemcpyHostToDevice);
  // NOTE(review): reads slot_num ints from slot_vector_; assumes
  // slot_vector_ holds at least slot_lengths.size() entries - confirm.
  cudaMemcpy(d_slot_vector, slot_vector_.data(), slot_num * sizeof(int),
             cudaMemcpyHostToDevice);

  // A grid with zero blocks is an invalid launch configuration, so skip the
  // launch when there is nothing to push.
  if (total_length > 0) {
    PushCopy<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
        total_grad_values_gpu, gpu_values, gpu_len, hidden_size, slot_num,
        total_length, batch_size, d_slot_vector);
  }
  cudaStreamSynchronize(stream);
}
Y
yaoxuefeng 已提交
267

Y
yaoxuefeng 已提交
268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306
// Packs per-slot gradient buffers into a byte-strided push buffer with the
// PushCopyWithPool kernel (per-slot embedding widths), then waits for the
// stream to finish.
//
// place                 : device whose CUDA stream and allocator are used.
// grad_values           : per-slot gradient pointers; the table is copied to
//                         the device and the pointers are dereferenced by the
//                         kernel, so they must be device-accessible.
// total_grad_values_gpu : base address of the output push buffer.
// slot_lengths          : number of elements in each slot.
// total_length          : total number of elements across all slots.
// batch_size            : forwarded to the kernel as the gradient scale.
// grad_value_size       : byte stride between consecutive push values.
void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place,
                               const std::vector<const float*>& grad_values,
                               FeaturePushValue* total_grad_values_gpu,
                               const std::vector<int64_t>& slot_lengths,
                               const uint64_t total_length,
                               const int batch_size, size_t grad_value_size) {
  auto stream = dynamic_cast<platform::CUDADeviceContext*>(
                    platform::DeviceContextPool::Instance().Get(place))
                    ->stream();
  const size_t slot_num = slot_lengths.size();

  // Inclusive prefix sum of the slot lengths: entry s is the end offset of
  // slot s in the flattened index space.
  std::vector<int64_t> slot_lengths_lod(slot_num);
  std::partial_sum(slot_lengths.begin(), slot_lengths.end(),
                   slot_lengths_lod.begin());

  auto buf_grad_value =
      memory::Alloc(place, grad_values.size() * sizeof(float*));
  auto buf_length = memory::Alloc(place, slot_num * sizeof(int64_t));
  auto buf_slot_vector = memory::Alloc(place, slot_num * sizeof(int));
  auto buf_mf_dim_vector = memory::Alloc(place, slot_num * sizeof(int));

  float** gpu_values = reinterpret_cast<float**>(buf_grad_value->ptr());
  int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
  int* d_slot_vector = reinterpret_cast<int*>(buf_slot_vector->ptr());
  int* d_mf_dim_vector = reinterpret_cast<int*>(buf_mf_dim_vector->ptr());

  cudaMemcpy(gpu_values, grad_values.data(),
             grad_values.size() * sizeof(float*), cudaMemcpyHostToDevice);
  cudaMemcpy(gpu_len, slot_lengths_lod.data(), slot_num * sizeof(int64_t),
             cudaMemcpyHostToDevice);
  // NOTE(review): both copies read slot_num ints from member vectors; assumes
  // slot_vector_ and slot_mf_dim_vector_ each hold at least
  // slot_lengths.size() entries - confirm.
  cudaMemcpy(d_slot_vector, slot_vector_.data(), slot_num * sizeof(int),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_mf_dim_vector, slot_mf_dim_vector_.data(),
             slot_num * sizeof(int), cudaMemcpyHostToDevice);

  // A grid with zero blocks is an invalid launch configuration, so skip the
  // launch when there is nothing to push.
  if (total_length > 0) {
    PushCopyWithPool<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
        total_grad_values_gpu, gpu_values, gpu_len, slot_num, total_length,
        batch_size, d_slot_vector, d_mf_dim_vector, grad_value_size);
  }
  cudaStreamSynchronize(stream);
}

Y
yaoxuefeng 已提交
307 308 309 310
// Builds an OptimizerConfig from the given sparse-SGD parameters and hands it
// to HeterPs_ via set_sparse_sgd.
void PSGPUWrapper::SetSparseSGD(float nonclk_coeff, float clk_coeff,
                                float min_bound, float max_bound,
                                float learning_rate, float initial_g2sum,
                                float initial_range) {
  OptimizerConfig sparse_config;
  sparse_config.set_sparse_sgd(nonclk_coeff, clk_coeff, min_bound, max_bound,
                               learning_rate, initial_g2sum, initial_range);
  HeterPs_->set_sparse_sgd(sparse_config);
}

// Builds an OptimizerConfig from the given embedx-SGD parameters and hands it
// to HeterPs_ via set_embedx_sgd.
void PSGPUWrapper::SetEmbedxSGD(float mf_create_thresholds,
                                float mf_learning_rate, float mf_initial_g2sum,
                                float mf_initial_range, float mf_min_bound,
                                float mf_max_bound) {
  OptimizerConfig embedx_config;
  embedx_config.set_embedx_sgd(mf_create_thresholds, mf_learning_rate,
                               mf_initial_g2sum, mf_initial_range,
                               mf_min_bound, mf_max_bound);
  HeterPs_->set_embedx_sgd(embedx_config);
}

T
Thunderbrook 已提交
328 329 330
}  // end namespace framework
}  // end namespace paddle
#endif