// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifdef PADDLE_WITH_BOX_PS
#include <algorithm>
#include <ctime>
#include <memory>
#include <numeric>
#include "paddle/fluid/framework/fleet/box_wrapper.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/gpu_info.h"

namespace paddle {
namespace framework {
// Grid-stride loop helper for 1-D kernels: declares `int i`, starting at this
// thread's global index (blockIdx.x * blockDim.x + threadIdx.x) and advancing
// by the total number of launched threads (blockDim.x * gridDim.x) while
// i < n. Every index in [0, n) is therefore visited by exactly one thread,
// whatever grid size the kernel was launched with.
// NOTE(review): the index is a plain `int`, so `n` must fit in an int.
#define CUDA_KERNEL_LOOP(i, n)                                 \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

// Scatters pulled feature values from the flat array `src` into the per-slot
// output buffers in `dest`.
//
// Launch layout: 1-D grid of 1-D blocks; items are distributed with
// CUDA_KERNEL_LOOP, so any grid size is valid. No shared memory is used.
//
// Parameters:
//   dest       - per-slot output row buffers; dest[s] is written with `hidden`
//                floats per key. When expand_dim > 0, dest[s + slot_num] is
//                additionally written with `expand_dim` floats per key.
//   src        - flat array of `total_len` feature values, one per key.
//   len        - per-slot end offsets into the flat index space: flat index i
//                belongs to the first slot s with i < len[s].
//   hidden     - floats per output row: show, clk, embed_w, then hidden - 3
//                embedx values.
//   expand_dim - floats per expand-embedding row (0 disables that output).
//   slot_num   - number of slots described by `len`.
//   total_len  - total number of keys to process.
//   keys       - per-slot key arrays; a key equal to 0 produces an all-zero
//                output row.
template <size_t EMBEDX_DIM, size_t EXPAND_EMBED_DIM>
__global__ void PullCopy(
    float** dest,
    const boxps::FeatureValueGpu<EMBEDX_DIM, EXPAND_EMBED_DIM>* src,
    const int64_t* len, int hidden, int expand_dim, int slot_num, int total_len,
    uint64_t** keys) {
  CUDA_KERNEL_LOOP(i, total_len) {
    // Binary search for the first slot whose end offset is greater than i.
    int lo = 0;
    int hi = slot_num - 1;
    while (lo < hi) {
      const int mid = (lo + hi) / 2;
      if (i < len[mid]) {
        hi = mid;
      } else {
        lo = mid + 1;
      }
    }
    const int slot = lo;
    // Position of this key inside its slot.
    const int row = i - (slot ? len[slot - 1] : 0);

    const auto& feature = src[i];
    const bool zero_key = (keys[slot][row] == 0);
    float* out = dest[slot] + row * hidden;

    // Leading three scalars: show, clk, embed_w.
    if (zero_key) {
      out[0] = 0;
      out[1] = 0;
      out[2] = 0;
    } else {
      out[0] = feature.show;
      out[1] = feature.clk;
      out[2] = feature.embed_w;
    }

    // Remaining hidden - 3 values come from embedx, skipping its element 0.
    float* embedx_out = out + 3;
    const int embedx_len = hidden - 3;
    if (zero_key || feature.embedding_size == 0) {
      for (int j = 0; j < embedx_len; ++j) {
        embedx_out[j] = 0;
      }
    } else {
      for (int j = 0; j < embedx_len; ++j) {
        embedx_out[j] = feature.embedx[j + 1];
      }
    }

    // Optional expand embedding, written to the second bank of buffers
    // (dest[slot + slot_num]); element 0 of embed_expand is skipped as well.
    if (expand_dim > 0) {
      float* expand_out = dest[slot + slot_num] + row * expand_dim;
      if (zero_key || feature.embed_expand_size[0] == 0) {
        for (int j = 0; j < expand_dim; ++j) {
          expand_out[j] = 0;
        }
      } else {
        for (int j = 0; j < expand_dim; ++j) {
          expand_out[j] = feature.embed_expand[j + 1];
        }
      }
    }
  }  // end kernel loop
}

// Flattens the per-slot key arrays in `src_keys` into the contiguous array
// `dest_total_keys`.
//
// Launch layout: 1-D grid of 1-D blocks; items are distributed with
// CUDA_KERNEL_LOOP, so any grid size is valid. No shared memory is used.
//
// Parameters:
//   src_keys        - per-slot key arrays.
//   dest_total_keys - output array receiving `total_len` keys.
//   len             - per-slot end offsets into the flat index space: flat
//                     index i belongs to the first slot s with i < len[s].
//   slot_num        - number of slots described by `len`.
//   total_len       - total number of keys to copy.
__global__ void CopyKeysKernel(uint64_t** src_keys, uint64_t* dest_total_keys,
                               const int64_t* len, int slot_num,
                               int total_len) {
  CUDA_KERNEL_LOOP(i, total_len) {
    // Binary search for the first slot whose end offset is greater than i.
    int first = 0;
    int last = slot_num - 1;
    while (first < last) {
      const int probe = (first + last) / 2;
      if (len[probe] > i) {
        last = probe;
      } else {
        first = probe + 1;
      }
    }
    // Start of the owning slot in the flat index space (0 for slot 0).
    const int64_t slot_begin = (first != 0) ? len[first - 1] : 0;
    const int offset = i - slot_begin;
    dest_total_keys[i] = *(src_keys[first] + offset);
  }
}

// Gathers per-slot gradient rows from `src` into the flat push-value array
// `dest`, scaling every gradient by -bs.
//
// Launch layout: 1-D grid of 1-D blocks; items are distributed with
// CUDA_KERNEL_LOOP, so any grid size is valid. No shared memory is used.
//
// Parameters:
//   dest        - flat output array of `total_len` push values.
//   src         - per-slot gradient row buffers; src[s] holds `hidden` floats
//                 per key (show, clk, embed gradient, then hidden - 3 embedx
//                 gradients). When expand_dim > 0, src[s + slot_num] holds
//                 `expand_dim` floats per key.
//   len         - per-slot end offsets into the flat index space: flat index i
//                 belongs to the first slot s with i < len[s].
//   hidden      - floats per gradient row in src[s].
//   expand_dim  - floats per expand-embedding gradient row (0 disables it).
//   slot_num    - number of slots described by `len`.
//   total_len   - total number of keys to process.
//   bs          - batch size; gradients are multiplied by -bs.
//   slot_vector - slot id written to dest[i].slot, indexed by slot position.
template <size_t EMBEDX_DIM, size_t EXPAND_EMBED_DIM>
__global__ void PushCopy(
    boxps::FeaturePushValueGpu<EMBEDX_DIM, EXPAND_EMBED_DIM>* dest, float** src,
    int64_t* len, int hidden, int expand_dim, int slot_num, int total_len,
    int bs, int* slot_vector) {
  // Single-precision scale, hoisted out of the loop. The previous form
  // (`value * -1. * bs`) used a double literal and promoted every multiply to
  // double. For bs < 2^24 the float product rounds to the same value.
  // NOTE(review): assumes the *_g fields of FeaturePushValueGpu are float.
  const float grad_scale = -static_cast<float>(bs);

  CUDA_KERNEL_LOOP(i, total_len) {
    // Binary search for the first slot whose end offset is greater than i.
    int low = 0;
    int high = slot_num - 1;
    while (low < high) {
      const int mid = (low + high) / 2;
      if (i < len[mid]) {
        high = mid;
      } else {
        low = mid + 1;
      }
    }
    const int x = low;
    // Position of this key inside its slot.
    const int y = i - (x ? len[x - 1] : 0);

    auto* out = dest + i;
    const float* in = src[x] + y * hidden;

    out->slot = slot_vector[x];
    // show and clk are forwarded unscaled.
    out->show = in[0];
    out->clk = in[1];
    out->embed_g = in[2] * grad_scale;
    for (int j = 0; j < hidden - 3; j++) {
      out->embedx_g[j] = in[3 + j] * grad_scale;
    }

    // Optional expand-embedding gradient, read from the second bank of
    // buffers (src[x + slot_num]).
    if (expand_dim > 0) {
      const float* expand_in = src[x + slot_num] + y * expand_dim;
      for (int j = 0; j < expand_dim; j++) {
        out->embed_expand_g[j] = expand_in[j] * grad_scale;
      }
    }
  }
}

// Copies pulled feature values from the flat buffer `total_values_gpu` into
// the per-slot output buffers listed in `values`, by launching PullCopy on the
// CUDA stream of the device context for `place` and waiting for it to finish.
//
// Parameters:
//   place            - must hold a platform::CUDAPlace.
//   gpu_keys         - per-slot key arrays, passed to the kernel as-is.
//   values           - host vector of output buffer pointers; it is copied to
//                      the device and indexed by the kernel as dest[slot]
//                      (and dest[slot + slot_num] when expand_embed_dim > 0).
//   total_values_gpu - flat array of boxps::FeatureValueGpu<E, X> where
//                      E = hidden_size - 3 and X = expand_embed_dim.
//   gpu_len          - per-slot end offsets, passed to the kernel as-is.
//   slot_num         - number of slots.
//   hidden_size      - floats per output row; hidden_size - 3 must be 8 or 16.
//   expand_embed_dim - 0, 8 or 64 when hidden_size - 3 == 8; 0 when it is 16.
//   total_length     - total number of keys. The kernel receives it as `int`.
//
// Throws InvalidArgument for an unsupported dimension combination and
// External when a CUDA call fails.
void BoxWrapper::CopyForPull(const paddle::platform::Place& place,
                             uint64_t** gpu_keys,
                             const std::vector<float*>& values,
                             void* total_values_gpu, const int64_t* gpu_len,
                             const int slot_num, const int hidden_size,
                             const int expand_embed_dim,
                             const int64_t total_length) {
  auto stream = dynamic_cast<platform::CUDADeviceContext*>(
                    platform::DeviceContextPool::Instance().Get(
                        BOOST_GET_CONST(platform::CUDAPlace, place)))
                    ->stream();

  // Turns a failed CUDA status into an exception instead of dropping it.
  auto check_cuda = [](cudaError_t status, const char* what) {
    if (status != cudaSuccess) {
      PADDLE_THROW(platform::errors::External(
          "CUDA error in BoxWrapper::CopyForPull while %s: %s", what,
          cudaGetErrorString(status)));
    }
  };

  // The kernel dereferences the pointer table on the device, so it has to be
  // copied there first.
  auto buf_value = memory::AllocShared(place, values.size() * sizeof(float*));
  float** gpu_values = reinterpret_cast<float**>(buf_value->ptr());
  check_cuda(cudaMemcpy(gpu_values, values.data(),
                        values.size() * sizeof(float*),
                        cudaMemcpyHostToDevice),
             "copying output pointers to device");

// Dispatches on the runtime dimensions to a PullCopy instantiation. The
// launch is skipped when there is nothing to copy, because a grid of zero
// blocks is an invalid launch configuration; the dimension validation still
// runs in that case.
#define EMBEDX_CASE(i, ...)                                                  \
  case i: {                                                                  \
    constexpr size_t EmbedxDim = i;                                          \
    switch (expand_embed_dim) {                                              \
      __VA_ARGS__                                                            \
      default:                                                               \
        PADDLE_THROW(platform::errors::InvalidArgument(                      \
            "Unsupport this expand embedding size [%d]", expand_embed_dim)); \
    }                                                                        \
  } break

#define EXPAND_EMBED_PULL_CASE(i, ...)                                         \
  case i: {                                                                    \
    constexpr size_t ExpandDim = i;                                            \
    if (total_length > 0) {                                                    \
      PullCopy<EmbedxDim, ExpandDim>                                           \
          <<<(total_length + 512 - 1) / 512, 512, 0, stream>>>(                \
              gpu_values,                                                      \
              reinterpret_cast<boxps::FeatureValueGpu<EmbedxDim, ExpandDim>*>( \
                  total_values_gpu),                                           \
              gpu_len, hidden_size, expand_embed_dim, slot_num, total_length,  \
              gpu_keys);                                                       \
      check_cuda(cudaGetLastError(), "launching PullCopy");                    \
    }                                                                          \
  } break

  switch (hidden_size - 3) {
    EMBEDX_CASE(8, EXPAND_EMBED_PULL_CASE(0); EXPAND_EMBED_PULL_CASE(8);
                EXPAND_EMBED_PULL_CASE(64););
    EMBEDX_CASE(16, EXPAND_EMBED_PULL_CASE(0););
    default:
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unsupport this embedding size [%d]", hidden_size - 3));
  }

  // Waits for the kernel; asynchronous execution errors surface here.
  check_cuda(cudaStreamSynchronize(stream), "synchronizing the stream");
#undef EXPAND_EMBED_PULL_CASE
#undef EMBEDX_CASE
}

// Flattens the per-slot key arrays in `origin_keys` into `total_keys` by
// launching CopyKeysKernel on the CUDA stream of the device context for
// `place`, then waits for the stream to finish.
//
// Parameters:
//   place       - must hold a platform::CUDAPlace.
//   origin_keys - per-slot key arrays, passed to the kernel as-is.
//   total_keys  - output array receiving `total_len` keys.
//   gpu_len     - per-slot end offsets, passed to the kernel as-is.
//   slot_num    - number of slots.
//   total_len   - total number of keys; nothing is launched when it is <= 0.
//
// Throws External when a CUDA call fails.
void BoxWrapper::CopyKeys(const paddle::platform::Place& place,
                          uint64_t** origin_keys, uint64_t* total_keys,
                          const int64_t* gpu_len, int slot_num, int total_len) {
  auto stream = dynamic_cast<platform::CUDADeviceContext*>(
                    platform::DeviceContextPool::Instance().Get(
                        BOOST_GET_CONST(platform::CUDAPlace, place)))
                    ->stream();

  // Turns a failed CUDA status into an exception instead of dropping it.
  auto check_cuda = [](cudaError_t status, const char* what) {
    if (status != cudaSuccess) {
      PADDLE_THROW(platform::errors::External(
          "CUDA error in BoxWrapper::CopyKeys while %s: %s", what,
          cudaGetErrorString(status)));
    }
  };

  // A grid of zero blocks is an invalid launch configuration, so skip the
  // launch when there are no keys to copy.
  if (total_len > 0) {
    CopyKeysKernel<<<(total_len + 512 - 1) / 512, 512, 0, stream>>>(
        origin_keys, total_keys, gpu_len, slot_num, total_len);
    check_cuda(cudaGetLastError(), "launching CopyKeysKernel");
  }

  // Waits for the kernel; asynchronous execution errors surface here.
  check_cuda(cudaStreamSynchronize(stream), "synchronizing the stream");
}

// Packs per-slot gradient buffers into the flat push-value buffer
// `total_grad_values_gpu` by launching PushCopy on the CUDA stream of the
// device context for `place`, then waits for the stream to finish.
//
// Parameters:
//   place                 - must hold a platform::CUDAPlace.
//   grad_values           - host vector of gradient buffer pointers; it is
//                           copied to the device and indexed by the kernel as
//                           src[slot] (and src[slot + slot_count] when
//                           expand_embed_dim > 0).
//   total_grad_values_gpu - flat array of boxps::FeaturePushValueGpu<E, X>
//                           where E = hidden_size - 3, X = expand_embed_dim.
//   slot_lengths          - number of keys in each slot; converted here to
//                           running totals for the kernel.
//   hidden_size           - floats per gradient row; hidden_size - 3 must be
//                           8 or 16.
//   expand_embed_dim      - 0, 8 or 64 when hidden_size - 3 == 8; 0 when 16.
//   total_length          - total number of keys. The kernel receives it as
//                           `int`.
//   batch_size            - gradients are scaled by -batch_size in the kernel.
//
// Throws InvalidArgument for an unsupported dimension combination or when
// slot_vector_ has fewer entries than slot_lengths, and External when a CUDA
// call fails.
void BoxWrapper::CopyForPush(const paddle::platform::Place& place,
                             const std::vector<const float*>& grad_values,
                             void* total_grad_values_gpu,
                             const std::vector<int64_t>& slot_lengths,
                             const int hidden_size, const int expand_embed_dim,
                             const int64_t total_length, const int batch_size) {
  auto stream = dynamic_cast<platform::CUDADeviceContext*>(
                    platform::DeviceContextPool::Instance().Get(
                        BOOST_GET_CONST(platform::CUDAPlace, place)))
                    ->stream();

  // Turns a failed CUDA status into an exception instead of dropping it.
  auto check_cuda = [](cudaError_t status, const char* what) {
    if (status != cudaSuccess) {
      PADDLE_THROW(platform::errors::External(
          "CUDA error in BoxWrapper::CopyForPush while %s: %s", what,
          cudaGetErrorString(status)));
    }
  };

  const size_t slot_count = slot_lengths.size();

  // One slot id per slot is copied from slot_vector_ below; reading past its
  // end would be an out-of-bounds host read.
  if (slot_vector_.size() < slot_count) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "The number of slot lengths [%d] exceeds the number of slots [%d]",
        static_cast<int>(slot_count),
        static_cast<int>(slot_vector_.size())));
  }

  // Running totals: entry s is the end offset of slot s in the flat index
  // space, which is what the kernel's binary search expects.
  std::vector<int64_t> slot_lengths_lod(slot_count);
  std::partial_sum(slot_lengths.begin(), slot_lengths.end(),
                   slot_lengths_lod.begin());

  auto buf_grad_value =
      memory::AllocShared(place, grad_values.size() * sizeof(float*));
  auto buf_length = memory::AllocShared(place, slot_count * sizeof(int64_t));
  auto buf_slot_vector = memory::AllocShared(place, slot_count * sizeof(int));

  float** gpu_values = reinterpret_cast<float**>(buf_grad_value->ptr());
  int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
  int* d_slot_vector = reinterpret_cast<int*>(buf_slot_vector->ptr());

  check_cuda(cudaMemcpy(gpu_values, grad_values.data(),
                        grad_values.size() * sizeof(float*),
                        cudaMemcpyHostToDevice),
             "copying gradient pointers to device");
  check_cuda(cudaMemcpy(gpu_len, slot_lengths_lod.data(),
                        slot_count * sizeof(int64_t), cudaMemcpyHostToDevice),
             "copying slot offsets to device");
  check_cuda(cudaMemcpy(d_slot_vector, slot_vector_.data(),
                        slot_count * sizeof(int), cudaMemcpyHostToDevice),
             "copying slot ids to device");

// Dispatches on the runtime dimensions to a PushCopy instantiation. The
// launch is skipped when there is nothing to copy, because a grid of zero
// blocks is an invalid launch configuration; the dimension validation still
// runs in that case.
#define EMBEDX_CASE(i, ...)                                                  \
  case i: {                                                                  \
    constexpr size_t EmbedxDim = i;                                          \
    switch (expand_embed_dim) {                                              \
      __VA_ARGS__                                                            \
      default:                                                               \
        PADDLE_THROW(platform::errors::InvalidArgument(                      \
            "Unsupport this expand embedding size [%d]", expand_embed_dim)); \
    }                                                                        \
  } break

#define EXPAND_EMBED_PUSH_CASE(i, ...)                                        \
  case i: {                                                                   \
    constexpr size_t ExpandDim = i;                                           \
    if (total_length > 0) {                                                   \
      PushCopy<EmbedxDim, ExpandDim>                                          \
          <<<(total_length + 512 - 1) / 512, 512, 0, stream>>>(               \
              reinterpret_cast<                                               \
                  boxps::FeaturePushValueGpu<EmbedxDim, ExpandDim>*>(         \
                  total_grad_values_gpu),                                     \
              gpu_values, gpu_len, hidden_size, expand_embed_dim,             \
              static_cast<int>(slot_count), total_length, batch_size,         \
              d_slot_vector);                                                 \
      check_cuda(cudaGetLastError(), "launching PushCopy");                   \
    }                                                                         \
  } break

  switch (hidden_size - 3) {
    EMBEDX_CASE(8, EXPAND_EMBED_PUSH_CASE(0); EXPAND_EMBED_PUSH_CASE(8);
                EXPAND_EMBED_PUSH_CASE(64););
    EMBEDX_CASE(16, EXPAND_EMBED_PUSH_CASE(0););
    default:
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unsupport this embedding size [%d]", hidden_size - 3));
  }

  // Waits for the kernel; asynchronous execution errors surface here.
  check_cuda(cudaStreamSynchronize(stream), "synchronizing the stream");
#undef EXPAND_EMBED_PUSH_CASE
#undef EMBEDX_CASE
}
}  // end namespace framework
}  // end namespace paddle
#endif