// box_wrapper.cu
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifdef PADDLE_WITH_BOX_PS
#include <algorithm>
#include <ctime>
#include <memory>
#include <numeric>
#include "paddle/fluid/framework/fleet/box_wrapper.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/gpu_info.h"

namespace paddle {
namespace framework {

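// Scatters the values pulled from BoxPS (one FeatureValueGpu per key in the
// flat `src` buffer) into the per-slot output tensors `dest`. Each row of
// `dest[x]` is laid out as [show, clk, embed_w, embedx...]; rows whose key is
// 0 are zero-filled, and rows whose embedding has size 0 get a zeroed embedx
// part. When expand_dim > 0, the expanded embedding is written through the
// second half of `dest` (pointers slot_num .. 2 * slot_num - 1).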
template <size_t EMBEDX_DIM, size_t EXPAND_EMBED_DIM>
__global__ void PullCopy(
    float** dest,
    const boxps::FeatureValueGpu<EMBEDX_DIM, EXPAND_EMBED_DIM>* src,
    const int64_t* len, int hidden, int expand_dim, int slot_num, int total_len,
    uint64_t** keys) {
  CUDA_KERNEL_LOOP(i, total_len) {
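    // Binary-search the prefix-summed slot lengths to find the slot x this
    // flattened index i belongs to and its offset y within that slot.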
    int low = 0;
    int high = slot_num - 1;
    while (low < high) {
      int mid = (low + high) / 2;
      if (i < len[mid])
        high = mid;
      else
        low = mid + 1;
    }
    int x = low;
    int y = i - (x ? len[x - 1] : 0);
    if (*(keys[x] + y) == 0) {
      *(dest[x] + y * hidden) = 0;
      *(dest[x] + y * hidden + 1) = 0;
      *(dest[x] + y * hidden + 2) = 0;
    } else {
      *(dest[x] + y * hidden) = (src + i)->show;
      *(dest[x] + y * hidden + 1) = (src + i)->clk;
      *(dest[x] + y * hidden + 2) = (src + i)->embed_w;
    }
    if ((src + i)->embedding_size == 0 || *(keys[x] + y) == 0) {
      for (int j = 0; j < hidden - 3; j++) {
        *(dest[x] + y * hidden + 3 + j) = 0;
      }
    } else {
      for (int j = 0; j < hidden - 3; j++) {
        *(dest[x] + y * hidden + 3 + j) = (src + i)->embedx[1 + j];
      }
    }
    // process embed_expand
    if (expand_dim > 0) {
      int z = x + slot_num;
      if ((src + i)->embed_expand_size[0] == 0 || *(keys[x] + y) == 0) {
        for (int j = 0; j < expand_dim; j++) {
          *(dest[z] + y * expand_dim + j) = 0;
        }
      } else {
        for (int j = 0; j < expand_dim; j++) {
          *(dest[z] + y * expand_dim + j) = (src + i)->embed_expand[1 + j];
        }
      }
    }
  }  // end kernel loop
}

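// Flattens the per-slot key arrays into one contiguous buffer so that BoxPS
// can be queried with a single key list; the slot/offset of each key is
// recovered with the same binary search over the prefix-summed lengths.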
__global__ void CopyKeysKernel(uint64_t** src_keys, uint64_t* dest_total_keys,
                               const int64_t* len, int slot_num,
                               int total_len) {
  CUDA_KERNEL_LOOP(i, total_len) {
    int low = 0;
    int high = slot_num - 1;
    while (low < high) {
      int mid = (low + high) / 2;
      if (i < len[mid])
        high = mid;
      else
        low = mid + 1;
    }
    int x = low;
    int y = i - (x ? len[x - 1] : 0);
    dest_total_keys[i] = src_keys[x][y];
  }
}

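// Gathers gradients from the per-slot gradient tensors into the flat
// FeaturePushValueGpu buffer expected by BoxPS. Gradients are negated and
// scaled by the batch size, and the originating slot id is recorded from
// slot_vector. When expand_dim > 0, the expanded-embedding gradients are read
// from the second half of `src`.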
template <size_t EMBEDX_DIM, size_t EXPAND_EMBED_DIM>
__global__ void PushCopy(
    boxps::FeaturePushValueGpu<EMBEDX_DIM, EXPAND_EMBED_DIM>* dest, float** src,
    int64_t* len, int hidden, int expand_dim, int slot_num, int total_len,
    int bs, int* slot_vector) {
  CUDA_KERNEL_LOOP(i, total_len) {
    int low = 0;
    int high = slot_num - 1;
    while (low < high) {
      int mid = (low + high) / 2;
      if (i < len[mid])
        high = mid;
      else
        low = mid + 1;
    }
    int x = low;
    int y = i - (x ? len[x - 1] : 0);
    (dest + i)->slot = slot_vector[x];
    (dest + i)->show = *(src[x] + y * hidden);
    (dest + i)->clk = *(src[x] + y * hidden + 1);
    (dest + i)->embed_g = *(src[x] + y * hidden + 2) * -1. * bs;
    for (int j = 0; j < hidden - 3; j++) {
      (dest + i)->embedx_g[j] = *(src[x] + y * hidden + 3 + j) * -1. * bs;
    }
    if (expand_dim > 0) {
      int z = x + slot_num;
      for (int j = 0; j < expand_dim; j++) {
        (dest + i)->embed_expand_g[j] =
            *(src[z] + y * expand_dim + j) * -1. * bs;
      }
    }
  }
}

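// Host-side pull path: uploads the array of per-slot output pointers to the
// device, then launches PullCopy specialised on the embedx width
// (hidden_size - 3) and the expand dim via the macros below.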
void BoxWrapper::CopyForPull(const paddle::platform::Place& place,
                             uint64_t** gpu_keys,
                             const std::vector<float*>& values,
                             void* total_values_gpu, const int64_t* gpu_len,
                             const int slot_num, const int hidden_size,
                             const int expand_embed_dim,
                             const int64_t total_length) {
  auto stream = dynamic_cast<platform::CUDADeviceContext*>(
                    platform::DeviceContextPool::Instance().Get(
                        BOOST_GET_CONST(platform::CUDAPlace, place)))
                    ->stream();
  auto buf_value = memory::AllocShared(place, values.size() * sizeof(float*));
  float** gpu_values = reinterpret_cast<float**>(buf_value->ptr());
#ifdef PADDLE_WITH_HIP
  hipMemcpy(gpu_values, values.data(), values.size() * sizeof(float*),
            hipMemcpyHostToDevice);
#else
  cudaMemcpy(gpu_values, values.data(), values.size() * sizeof(float*),
             cudaMemcpyHostToDevice);
#endif
#define EMBEDX_CASE(i, ...)                                                  \
  case i: {                                                                  \
    constexpr size_t EmbedxDim = i;                                          \
    switch (expand_embed_dim) {                                              \
      __VA_ARGS__                                                            \
      default:                                                               \
        PADDLE_THROW(platform::errors::InvalidArgument(                      \
            "Unsupport this expand embedding size [%d]", expand_embed_dim)); \
    }                                                                        \
  } break

#ifdef PADDLE_WITH_HIP
#define EXPAND_EMBED_PULL_CASE(i, ...)                                        \
  case i: {                                                                   \
    constexpr size_t ExpandDim = i;                                           \
    hipLaunchKernelGGL(                                                       \
        PullCopy<EmbedxDim, ExpandDim>, dim3((total_length + 512 - 1) / 512), \
        dim3(512), 0, stream, gpu_values,                                     \
        reinterpret_cast<boxps::FeatureValueGpu<EmbedxDim, ExpandDim>*>(      \
            total_values_gpu),                                                \
        gpu_len, hidden_size, expand_embed_dim, slot_num, total_length,       \
        gpu_keys);                                                            \
  } break
#else
#define EXPAND_EMBED_PULL_CASE(i, ...)                                       \
  case i: {                                                                  \
    constexpr size_t ExpandDim = i;                                          \
    PullCopy<EmbedxDim,                                                      \
             ExpandDim><<<(total_length + 512 - 1) / 512, 512, 0, stream>>>( \
        gpu_values,                                                          \
        reinterpret_cast<boxps::FeatureValueGpu<EmbedxDim, ExpandDim>*>(     \
            total_values_gpu),                                               \
        gpu_len, hidden_size, expand_embed_dim, slot_num, total_length,      \
        gpu_keys);                                                           \
  } break
#endif

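  // hidden_size counts show/clk/embed_w plus the embedx weights, so the
  // embedx width used for dispatch is hidden_size - 3.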
  switch (hidden_size - 3) {
    EMBEDX_CASE(8, EXPAND_EMBED_PULL_CASE(0); EXPAND_EMBED_PULL_CASE(8);
                EXPAND_EMBED_PULL_CASE(64););
    EMBEDX_CASE(16, EXPAND_EMBED_PULL_CASE(0););
    default:
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unsupport this embedding size [%d]", hidden_size - 3));
  }
#ifdef PADDLE_WITH_HIP
  hipStreamSynchronize(stream);
#else
  cudaStreamSynchronize(stream);
#endif
#undef EXPAND_EMBED_PULL_CASE
#undef EMBEDX_CASE
}

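// Host-side wrapper around CopyKeysKernel: flattens the per-slot keys into
// total_keys on the stream bound to `place`.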
void BoxWrapper::CopyKeys(const paddle::platform::Place& place,
                          uint64_t** origin_keys, uint64_t* total_keys,
                          const int64_t* gpu_len, int slot_num, int total_len) {
  auto stream = dynamic_cast<platform::CUDADeviceContext*>(
                    platform::DeviceContextPool::Instance().Get(
                        BOOST_GET_CONST(platform::CUDAPlace, place)))
                    ->stream();
#ifdef PADDLE_WITH_HIP
  hipLaunchKernelGGL(CopyKeysKernel, dim3((total_len + 512 - 1) / 512),
                     dim3(512), 0, stream, origin_keys, total_keys, gpu_len,
                     slot_num, total_len);
  hipStreamSynchronize(stream);
#else
  CopyKeysKernel<<<(total_len + 512 - 1) / 512, 512, 0, stream>>>(
      origin_keys, total_keys, gpu_len, slot_num, total_len);
  cudaStreamSynchronize(stream);
#endif
}

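// Host-side push path: uploads the gradient pointers, the prefix-summed slot
// lengths and the slot ids, then launches PushCopy with the matching template
// instantiation.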
void BoxWrapper::CopyForPush(const paddle::platform::Place& place,
                             const std::vector<const float*>& grad_values,
                             void* total_grad_values_gpu,
                             const std::vector<int64_t>& slot_lengths,
                             const int hidden_size, const int expand_embed_dim,
                             const int64_t total_length, const int batch_size) {
  auto stream = dynamic_cast<platform::CUDADeviceContext*>(
                    platform::DeviceContextPool::Instance().Get(
                        BOOST_GET_CONST(platform::CUDAPlace, place)))
                    ->stream();
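  // Build an inclusive prefix sum of the per-slot lengths; the kernels
  // binary-search this array to map a flattened index back to its slot.
  // e.g. slot_lengths = {3, 0, 5} -> slot_lengths_lod = {3, 3, 8}, so index
  // i = 4 falls in slot x = 2 at offset y = 4 - 3 = 1.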
  auto slot_lengths_lod = slot_lengths;
  for (size_t i = 1; i < slot_lengths_lod.size(); i++) {
    slot_lengths_lod[i] += slot_lengths_lod[i - 1];
  }
  auto buf_grad_value =
      memory::AllocShared(place, grad_values.size() * sizeof(float*));
  auto buf_length =
      memory::AllocShared(place, slot_lengths.size() * sizeof(int64_t));
  auto buf_slot_vector =
      memory::AllocShared(place, slot_lengths_lod.size() * sizeof(int));

  float** gpu_values = reinterpret_cast<float**>(buf_grad_value->ptr());
  int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
  int* d_slot_vector = reinterpret_cast<int*>(buf_slot_vector->ptr());

#ifdef PADDLE_WITH_HIP
  hipMemcpy(gpu_values, grad_values.data(), grad_values.size() * sizeof(float*),
            hipMemcpyHostToDevice);
  hipMemcpy(gpu_len, slot_lengths_lod.data(),
            slot_lengths.size() * sizeof(int64_t), hipMemcpyHostToDevice);
  hipMemcpy(d_slot_vector, slot_vector_.data(),
            slot_lengths_lod.size() * sizeof(int), hipMemcpyHostToDevice);
#else
  cudaMemcpy(gpu_values, grad_values.data(),
             grad_values.size() * sizeof(float*), cudaMemcpyHostToDevice);
  cudaMemcpy(gpu_len, slot_lengths_lod.data(),
             slot_lengths.size() * sizeof(int64_t), cudaMemcpyHostToDevice);
  cudaMemcpy(d_slot_vector, slot_vector_.data(),
             slot_lengths_lod.size() * sizeof(int), cudaMemcpyHostToDevice);
#endif

#define EMBEDX_CASE(i, ...)                                                  \
  case i: {                                                                  \
    constexpr size_t EmbedxDim = i;                                          \
    switch (expand_embed_dim) {                                              \
      __VA_ARGS__                                                            \
      default:                                                               \
        PADDLE_THROW(platform::errors::InvalidArgument(                      \
            "Unsupport this expand embedding size [%d]", expand_embed_dim)); \
    }                                                                        \
  } break

#ifdef PADDLE_WITH_HIP
#define EXPAND_EMBED_PUSH_CASE(i, ...)                                       \
  case i: {                                                                  \
    constexpr size_t ExpandDim = i;                                          \
    hipLaunchKernelGGL(PushCopy<EmbedxDim, ExpandDim>,                       \
        dim3((total_length + 512 - 1) / 512), dim3(512), 0, stream,          \
        reinterpret_cast<boxps::FeaturePushValueGpu<EmbedxDim, ExpandDim>*>( \
            total_grad_values_gpu),                                          \
        gpu_values, gpu_len, hidden_size, expand_embed_dim,                  \
        slot_lengths.size(), total_length, batch_size, d_slot_vector);       \
  } break
#else
#define EXPAND_EMBED_PUSH_CASE(i, ...)                                       \
  case i: {                                                                  \
    constexpr size_t ExpandDim = i;                                          \
    PushCopy<EmbedxDim,                                                      \
             ExpandDim><<<(total_length + 512 - 1) / 512, 512, 0, stream>>>( \
        reinterpret_cast<boxps::FeaturePushValueGpu<EmbedxDim, ExpandDim>*>( \
            total_grad_values_gpu),                                          \
        gpu_values, gpu_len, hidden_size, expand_embed_dim,                  \
        slot_lengths.size(), total_length, batch_size, d_slot_vector);       \
  } break
#endif

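  // Dispatch on the embedx width (hidden_size - 3) and the expand dim,
  // mirroring the pull path above.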
  switch (hidden_size - 3) {
    EMBEDX_CASE(8, EXPAND_EMBED_PUSH_CASE(0); EXPAND_EMBED_PUSH_CASE(8);
                EXPAND_EMBED_PUSH_CASE(64););
    EMBEDX_CASE(16, EXPAND_EMBED_PUSH_CASE(0););
    default:
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unsupport this embedding size [%d]", hidden_size - 3));
  }

#ifdef PADDLE_WITH_HIP
  hipStreamSynchronize(stream);
#else
  cudaStreamSynchronize(stream);
#endif
#undef EXPAND_EMBED_PUSH_CASE
#undef EMBEDX_CASE
}

}  // end namespace framework
}  // end namespace paddle
#endif