feature_value.cu 17.4 KB
Newer Older
D
danleifeng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
  http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
16
#include "paddle/phi/backends/gpu/gpu_primitives.h"
D
danleifeng 已提交
17 18 19 20

namespace paddle {
namespace framework {

21
// Threads per block used by the CUDA_BLOCK launch helper below.
const int CUDA_NUM_THREADS = phi::PADDLE_CUDA_NUM_THREADS;
// Number of CUDA_NUM_THREADS-sized blocks needed to cover N work items
// (ceiling division). N is parenthesized so that compound argument
// expressions expand with the intended precedence.
// NOTE: evaluates to 0 when N == 0, which is not a valid grid size, so
// callers must skip the launch in that case.
#define GET_BLOCK(N) (((N) + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS)
// Expands to the first three launch-configuration arguments
// (grid size, block size, dynamic shared-memory bytes); the caller appends
// the stream, e.g. `kernel<<<CUDA_BLOCK(n), stream>>>(...)`.
#define CUDA_BLOCK(N) GET_BLOCK(N), CUDA_NUM_THREADS, 0

// Copies pulled feature values from the packed value buffer `src` into the
// per-slot output buffers `dest`.
//
// `len` is searched as a non-decreasing array of cumulative slot lengths: the
// slot owning flat key index i is the smallest slot whose len[] entry exceeds
// i, and the key's row inside that slot is i minus the previous slot's entry.
// Records in `src` are `max_val_size` bytes apart. Each row of dest[slot] is
// gpu_dim[slot] floats wide, and `gpu_dim[slot] - 3` is passed to the accessor
// as mf_dim. The per-field copy is delegated to GPUAccessor::Select.
template <typename GPUAccessor>
__global__ void PullCopy(float** dest,
                         const float* src,
                         const int64_t* len,
                         int slot_num,
                         int total_len,
                         uint64_t** keys,
                         uint64_t max_val_size,
                         int* gpu_dim,
                         GPUAccessor gpu_accessor) {
  CUDA_KERNEL_LOOP(i, total_len) {
    // Lower-bound search for the slot that owns key i.
    int lo = 0;
    int hi = slot_num - 1;
    while (lo < hi) {
      const int probe = (lo + hi) / 2;
      if (i >= len[probe]) {
        lo = probe + 1;
      } else {
        hi = probe;
      }
    }
    const int slot = lo;
    const int row = i - (slot > 0 ? len[slot - 1] : 0);

    float* record =
        (float*)((char*)src + uint64_t(i) * max_val_size);  // NOLINT
    const int row_width = gpu_dim[slot];
    const int mf_dim = row_width - 3;
    gpu_accessor.Select(
        dest[slot] + row * row_width, record, keys[slot] + row, mf_dim);
  }
}

D
danleifeng 已提交
55 56 57 58 59 60 61 62
// Expands deduplicated pull results back to every (possibly repeated) key.
//
// One work item per output float: N == number of keys * hidden. Work item
// idx handles column `off = idx % hidden` of key `i = idx / hidden`.
// key2slot[i] is the key's slot x, and y = i - slot_lens[x] is used as the
// key's row in dest[x]. That presumably makes slot_lens[x] the flat index of
// slot x's first key (TODO confirm against the caller).
// restore_idx[i] selects the deduplicated record in `src`; records are
// `max_val_size` bytes apart. Columns 0, 1 and 2 receive the accessor's
// show, click and embed_w fields. The remaining columns receive embedx
// values and are zero-filled past the record's stored mf size. Keys equal
// to 0 produce an all-zero row. `slot_dims` is not read by this kernel.
template <typename TAccess>
__global__ void PullDedupCopy(const size_t N,
                              const uint64_t* total_keys,
                              float** dest,
                              const float* src,
                              const int64_t* slot_lens,
                              uint64_t max_val_size,
                              const int* slot_dims,
                              const size_t hidden,
                              const int* key2slot,
                              const uint32_t* restore_idx,
                              TAccess accessor) {
  CUDA_KERNEL_LOOP_TYPE(idx, N, size_t) {
    int i = idx / hidden;
    int off = idx % hidden;

    int x = key2slot[i];
    int y = i - slot_lens[x];

    float* dest_ptr = dest[x] + y * hidden;
    // Key 0 gets a zero-filled row. Use `continue`, not `return`: returning
    // from inside the kernel loop would also drop every later iteration this
    // thread was assigned.
    if (total_keys[i] == 0) {
      *(dest_ptr + off) = 0;
      continue;
    }

    float* src_ptr = (float*)((char*)src + uint64_t(restore_idx[i]) *  // NOLINT
                                               uint64_t(max_val_size));
    switch (off) {
      case 0:
        *(dest_ptr + off) = src_ptr[accessor.ShowIndex()];
        break;
      case 1:
        *(dest_ptr + off) = src_ptr[accessor.ClickIndex()];
        break;
      case 2:
        *(dest_ptr + off) = src_ptr[accessor.EmbedWIndex()];
        break;
      default: {
        // Embedding columns beyond the record's stored mf size read as 0.
        int embedx_id = off - 3;
        if (embedx_id >= static_cast<int>(src_ptr[accessor.MfSizeIndex()])) {
          *(dest_ptr + off) = 0;
        } else {
          *(dest_ptr + off) = src_ptr[accessor.EmbedxWIndex() + embedx_id];
        }
        break;
      }
    }
  }
}

// Packs per-slot gradient rows into the flat push buffer `dest`.
//
// `len` is searched as cumulative slot lengths to map flat key index i to its
// slot and row, as in PullCopy. Output record i starts at byte
// i * grad_value_size of `dest`. The kernel writes:
// - the slot id and mf_dim;
// - show and click, copied unchanged;
// - embed_g and the mf_dim embedx_g values, each multiplied by -1 * bs.
// The source row in src[slot] is laid out as show, click, embed_g, then
// mf_dim embedx values.
template <typename GPUAccessor>
__global__ void PushCopyWithPool(float* dest,
                                 float** src,
                                 int64_t* len,
                                 int slot_num,
                                 uint64_t total_len,
                                 int bs,
                                 int* slot_vector,
                                 int* mf_dim_vector,
                                 size_t grad_value_size,
                                 GPUAccessor gpu_accessor) {
  CUDA_KERNEL_LOOP(i, total_len) {
    // Lower-bound search for the slot that owns key i.
    int lo = 0;
    int hi = slot_num - 1;
    while (lo < hi) {
      const int probe = (lo + hi) / 2;
      if (i >= len[probe]) {
        lo = probe + 1;
      } else {
        hi = probe;
      }
    }
    const int slot = lo;
    const int row = i - (slot > 0 ? len[slot - 1] : 0);

    float* out = (float*)((char*)dest + i * grad_value_size);  // NOLINT
    const auto& layout = gpu_accessor.common_push_value;

    out[layout.SlotIndex()] = static_cast<float>(slot_vector[slot]);
    const int mf_dim = mf_dim_vector[slot];
    out[layout.MfDimIndex()] = static_cast<float>(mf_dim);

    // Start of this key's gradient row: show, click, embed_g, embedx_g[...].
    const float* grad = src[slot] + row * (mf_dim + 3);
    out[layout.ShowIndex()] = grad[0];
    out[layout.ClickIndex()] = grad[1];
    out[layout.EmbedGIndex()] = grad[2] * -1. * bs;
    for (int k = 0; k < mf_dim; ++k) {
      out[layout.EmbedxGIndex() + k] = grad[3 + k] * -1. * bs;
    }
  }
}

D
danleifeng 已提交
149 150 151 152 153 154 155 156 157 158 159 160 161 162
// Merges the gradients of duplicate keys into deduplicated push records using
// atomic adds.
//
// One work item per (key, column): the caller launches N = total_length *
// hidden. Key i's gradient row is src[x] + y * hidden with x = key2slot[i]
// and y = i - slot_lens[x]. It is accumulated into the record at byte
// d_restore_idx[i] * grad_value_size of `dest`, which the caller
// zero-initializes first. Show and click are summed as-is. embed_g and
// embedx_g (only the first mf_dim = slot_dims[x] - 3 columns) are summed
// after multiplication by -1 * bs. Keys equal to 0 are skipped.
template <typename TAccess>
__global__ void PushMergeCopyAtomic(const size_t N,
                                    const uint64_t* total_keys,
                                    float* dest,
                                    float** src,
                                    const int hidden,
                                    const int bs,
                                    const int* slot_vector,
                                    const int* slot_dims,
                                    const int64_t* slot_lens,
                                    const int* key2slot,
                                    const uint32_t* d_restore_idx,
                                    size_t grad_value_size,
                                    TAccess accessor) {
  CUDA_KERNEL_LOOP_TYPE(idx, N, size_t) {
    int i = idx / hidden;
    int off = idx % hidden;
    // Filter 0 keys. Use `continue`, not `return`, so this thread still
    // processes any later iterations of the kernel loop.
    if (total_keys[i] == 0) {
      continue;
    }

    int x = key2slot[i];
    int y = i - slot_lens[x];

    const float* ptr = src[x] + y * hidden;
    float* cur =
        (float*)((char*)dest + d_restore_idx[i] * grad_value_size);  // NOLINT
    int mf_dim = slot_dims[x] - 3;
    switch (off) {
      case 0:
        // Plain stores: every duplicate of a key writes the same slot/mf_dim
        // values into the shared record.
        cur[accessor.SlotIndex()] = static_cast<float>(slot_vector[x]);
        cur[accessor.MfDimIndex()] = static_cast<float>(mf_dim);
        phi::CudaAtomicAdd(&cur[accessor.ShowIndex()], *(ptr + off));
        break;
      case 1:
        phi::CudaAtomicAdd(&cur[accessor.ClickIndex()], *(ptr + off));
        break;
      case 2:
        phi::CudaAtomicAdd(&cur[accessor.EmbedGIndex()],
                           *(ptr + off) * -1. * bs);
        break;
      default: {
        int embedx_idx = off - 3;
        if (embedx_idx < mf_dim) {
          phi::CudaAtomicAdd(&cur[accessor.EmbedxGIndex() + embedx_idx],
                             *(ptr + off) * -1. * bs);
        }
        break;
      }
    }
  }
}

// Sums column `off` of the gradient rows of all duplicates of one
// deduplicated key into the accumulator `val`.
//
// This macro expands in place (used by PushMergeCopy) and relies on these
// names being in scope at the expansion site:
// - count, start: the range [start, start + count) of d_sort_idx to visit;
// - d_sort_idx, key2slot, slot_lens, src, hidden, off: used to locate each
//   duplicate's row and column;
// - y: an existing variable that is overwritten on every iteration;
// - val: the accumulator, which must be initialized by the caller.
// For each j, pos = d_sort_idx[start + j] is a duplicate key's position.
// Its slot is x = key2slot[pos] and its row within src[x] is
// pos - slot_lens[x]. The loop-local `pos` and `x` shadow any outer
// variables of the same names.
#define SUM_GRAD_VALUE                                             \
  for (uint32_t j = 0; j < count; ++j) {                           \
    const uint32_t& pos = d_sort_idx[start + j];                   \
    const int& x = key2slot[pos];                                  \
    y = pos - slot_lens[x];                                        \
    val += *(reinterpret_cast<float*>(src[x] + y * hidden + off)); \
  }

// Merges the gradients of duplicate keys into deduplicated push records
// without atomics, using a precomputed grouping of duplicates.
//
// One work item per (dedup key, column): the caller launches
// N = dedup_length * hidden. Record i is written at byte i * grad_value_size
// of `dest`. Duplicates of key i are visited as
// d_sort_idx[d_sort_offset[i] .. d_sort_offset[i] + d_sort_cnt[i]), and
// their values are summed in double via SUM_GRAD_VALUE. Show and click are
// stored as sums. embed_g and embedx_g (only the first mf_dim =
// slot_dims[x] - 3 columns) are stored as sum * -1 * bs; the remaining
// embedx columns are zeroed. Keys equal to 0 produce an all-zero record.
template <typename TAccess>
__global__ void PushMergeCopy(const size_t N,
                              const uint64_t* total_keys,
                              float* dest,
                              float** src,
                              const int hidden,
                              const int bs,
                              const int* slot_vector,
                              const int* slot_dims,
                              const int64_t* slot_lens,
                              const int* key2slot,
                              const uint32_t* d_sort_idx,
                              const uint32_t* d_sort_offset,
                              const uint32_t* d_sort_cnt,
                              size_t grad_value_size,
                              TAccess accessor) {
  CUDA_KERNEL_LOOP_TYPE(idx, N, size_t) {
    int i = idx / hidden;
    int off = idx % hidden;
    float* cur = (float*)((char*)dest + i * grad_value_size);  // NOLINT

    // Key 0 produces an all-zero record. Use `continue`, not `return`, so
    // this thread still processes any later iterations of the kernel loop.
    if (total_keys[i] == 0) {
      switch (off) {
        case 0:
          cur[accessor.SlotIndex()] = static_cast<float>(0);
          cur[accessor.MfDimIndex()] = static_cast<float>(0);
          cur[accessor.ShowIndex()] = 0.0;
          break;
        case 1:
          cur[accessor.ClickIndex()] = 0.0;
          break;
        case 2:
          cur[accessor.EmbedGIndex()] = 0.0;
          break;
        default:
          cur[accessor.EmbedxGIndex() + off - 3] = 0.0;
          break;
      }
      continue;
    }

    const uint32_t& start = d_sort_offset[i];
    const uint32_t& count = d_sort_cnt[i];
    const uint32_t& pos = d_sort_idx[start];

    // The first duplicate determines the slot (and therefore mf_dim).
    const int& x = key2slot[pos];
    int y = pos - slot_lens[x];  // also reused as scratch by SUM_GRAD_VALUE
    int mf_dim = slot_dims[x] - 3;

    double val = 0.0;

    switch (off) {
      case 0:
        cur[accessor.SlotIndex()] = static_cast<float>(slot_vector[x]);
        cur[accessor.MfDimIndex()] = static_cast<float>(mf_dim);
        SUM_GRAD_VALUE
        cur[accessor.ShowIndex()] = val;
        break;
      case 1:
        SUM_GRAD_VALUE
        cur[accessor.ClickIndex()] = val;
        break;
      case 2:
        SUM_GRAD_VALUE
        cur[accessor.EmbedGIndex()] = val * -1. * bs;
        break;
      default: {
        int embedx_idx = off - 3;
        if (embedx_idx < mf_dim) {
          SUM_GRAD_VALUE
          cur[accessor.EmbedxGIndex() + embedx_idx] = val * -1. * bs;
        } else {
          cur[accessor.EmbedxGIndex() + embedx_idx] = 0.0;
        }
        break;
      }
    }
  }
}

D
danleifeng 已提交
290 291 292 293 294 295 296 297 298 299 300 301
// Copies pulled values for `total_length` keys into the per-slot host-provided
// device buffers listed in `values` by launching PullCopy on the device
// context's stream, then blocks until that stream is idle.
//
// `feature_value_size` is forwarded to PullCopy as the byte stride between
// records of `total_values_gpu`. `hidden_size` is not used here.
template <typename GPUAccessor>
void AccessorWrapper<GPUAccessor>::CopyForPullImpl(
    const paddle::platform::Place& place,
    uint64_t** gpu_keys,
    const std::vector<float*>& values,
    const float* total_values_gpu,
    const int64_t* gpu_len,
    const int slot_num,
    const int hidden_size,
    const int64_t total_length,
    int* gpu_dim,
    int feature_value_size) {
  auto stream = dynamic_cast<phi::GPUContext*>(
                    paddle::platform::DeviceContextPool::Instance().Get(place))
                    ->stream();
  auto buf_value = memory::Alloc(place, values.size() * sizeof(float*));
  float** gpu_values = reinterpret_cast<float**>(buf_value->ptr());
  // Enqueue the pointer-table upload on `stream` so it is ordered before
  // PullCopy regardless of how `stream` interacts with the default stream.
  // `values` outlives the cudaStreamSynchronize below, so the host source
  // stays valid until the copy completes.
  cudaMemcpyAsync(gpu_values,
                  values.data(),
                  values.size() * sizeof(float*),
                  cudaMemcpyHostToDevice,
                  stream);
  // A grid of zero blocks is an invalid launch configuration.
  if (total_length > 0) {
    PullCopy<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
        gpu_values,
        total_values_gpu,
        gpu_len,
        slot_num,
        total_length,
        gpu_keys,
        feature_value_size,
        gpu_dim,
        gpu_accessor_);
  }
  cudaStreamSynchronize(stream);
}

// Packs per-slot gradients into the flat push buffer `total_grad_values_gpu`
// by launching PushCopyWithPool, then blocks until the stream is idle.
//
// `slot_lengths` is converted to cumulative (inclusive prefix-sum) lengths
// before upload, which is what PushCopyWithPool searches. The gradient
// pointer table, cumulative lengths, slot ids and per-slot mf dims are
// uploaded to temporary device buffers. Both `slot_vector` and
// `slot_mf_dim_vector` are read for slot_lengths.size() entries, so they are
// assumed to be at least that long (TODO confirm at call sites).
template <typename GPUAccessor>
void AccessorWrapper<GPUAccessor>::CopyForPushImpl(
    const paddle::platform::Place& place,
    const std::vector<const float*>& grad_values,
    float* total_grad_values_gpu,
    const std::vector<int64_t>& slot_lengths,
    const uint64_t total_length,
    const int batch_size,
    size_t grad_value_size,
    std::vector<int>& slot_vector,           // NOLINT
    std::vector<int>& slot_mf_dim_vector) {  // NOLINT
  auto stream = dynamic_cast<phi::GPUContext*>(
                    paddle::platform::DeviceContextPool::Instance().Get(place))
                    ->stream();
  // Inclusive prefix sums of the slot lengths.
  auto slot_lengths_lod = slot_lengths;
  for (size_t i = 1; i < slot_lengths_lod.size(); ++i) {
    slot_lengths_lod[i] += slot_lengths_lod[i - 1];
  }
  auto buf_grad_value =
      memory::Alloc(place, grad_values.size() * sizeof(float*));
  auto buf_length = memory::Alloc(place, slot_lengths.size() * sizeof(int64_t));
  auto buf_slot_vector =
      memory::Alloc(place, slot_lengths_lod.size() * sizeof(int));
  auto buf_mf_dim_vector =
      memory::Alloc(place, slot_lengths_lod.size() * sizeof(int));
  float** gpu_values = reinterpret_cast<float**>(buf_grad_value->ptr());
  int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
  int* d_slot_vector = reinterpret_cast<int*>(buf_slot_vector->ptr());
  int* d_mf_dim_vector = reinterpret_cast<int*>(buf_mf_dim_vector->ptr());
  // Uploads are enqueued on `stream` so they are ordered before the kernel.
  // All host sources (including the local slot_lengths_lod) outlive the
  // cudaStreamSynchronize below.
  cudaMemcpyAsync(gpu_values,
                  grad_values.data(),
                  grad_values.size() * sizeof(float*),
                  cudaMemcpyHostToDevice,
                  stream);
  cudaMemcpyAsync(gpu_len,
                  slot_lengths_lod.data(),
                  slot_lengths.size() * sizeof(int64_t),
                  cudaMemcpyHostToDevice,
                  stream);
  cudaMemcpyAsync(d_slot_vector,
                  slot_vector.data(),
                  slot_lengths_lod.size() * sizeof(int),
                  cudaMemcpyHostToDevice,
                  stream);
  cudaMemcpyAsync(d_mf_dim_vector,
                  slot_mf_dim_vector.data(),
                  slot_lengths_lod.size() * sizeof(int),
                  cudaMemcpyHostToDevice,
                  stream);
  // A grid of zero blocks is an invalid launch configuration.
  if (total_length > 0) {
    PushCopyWithPool<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
        total_grad_values_gpu,
        gpu_values,
        gpu_len,
        slot_lengths.size(),
        total_length,
        batch_size,
        d_slot_vector,
        d_mf_dim_vector,
        grad_value_size,
        gpu_accessor_);
  }
  cudaStreamSynchronize(stream);
}

D
danleifeng 已提交
383 384 385 386 387 388 389 390 391 392 393 394 395
// Expands deduplicated pull results back to all `total_length` keys by
// launching PullDedupCopy with one work item per output float
// (total_length * hidden_size), then blocks until the stream is idle.
// `pull_value_size` is the byte stride between records of
// `total_values_gpu`.
template <typename GPUAccessor>
void AccessorWrapper<GPUAccessor>::CopyForPullDedupImpl(
    const paddle::platform::Place& place,
    const uint64_t* total_keys,
    float** gpu_values,
    const float* total_values_gpu,
    const int64_t* slot_lens,
    const int* key2slot,
    const int hidden_size,
    const int64_t total_length,
    const int* slot_dims,
    const uint32_t* gpu_restore_idx,
    int pull_value_size) {
  auto stream = dynamic_cast<phi::GPUContext*>(
                    paddle::platform::DeviceContextPool::Instance().Get(place))
                    ->stream();
  size_t N = total_length * hidden_size;
  // GET_BLOCK(0) is 0, and a zero-block grid is an invalid configuration.
  if (N > 0) {
    PullDedupCopy<<<CUDA_BLOCK(N), stream>>>(N,
                                             total_keys,
                                             gpu_values,
                                             total_values_gpu,
                                             slot_lens,
                                             pull_value_size,
                                             slot_dims,
                                             hidden_size,
                                             key2slot,
                                             gpu_restore_idx,
                                             gpu_accessor_.common_pull_value);
  }
  cudaStreamSynchronize(stream);
}

// Merges gradients of duplicate keys into `dedup_length` push records using
// atomic accumulation, then blocks until the stream is idle.
//
// The output buffer (dedup_length records of grad_value_size bytes) is zeroed
// first. PushMergeCopyAtomic then adds each of the total_length keys'
// gradients into the record selected by d_restore_idx.
template <typename GPUAccessor>
void AccessorWrapper<GPUAccessor>::CopyForPushDedupImpl(
    const paddle::platform::Place& place,
    const uint64_t* total_keys,
    float** grad_values,
    float* total_grad_values_gpu,
    const int* slots,
    const int64_t* slot_lens,
    const int hidden_size,
    const int64_t total_length,
    const int64_t dedup_length,
    const int batch_size,
    const int* slot_dims,
    const int* key2slot,
    const uint32_t* d_restore_idx,
    const size_t grad_value_size) {
  auto stream = dynamic_cast<phi::GPUContext*>(
                    paddle::platform::DeviceContextPool::Instance().Get(place))
                    ->stream();
  // The kernel accumulates with atomic adds, so the records must start at 0.
  cudaMemsetAsync(
      total_grad_values_gpu, 0, dedup_length * grad_value_size, stream);
  size_t N = total_length * hidden_size;
  // GET_BLOCK(0) is 0, and a zero-block grid is an invalid configuration.
  if (N > 0) {
    PushMergeCopyAtomic<<<CUDA_BLOCK(N), stream>>>(
        N,
        total_keys,
        total_grad_values_gpu,
        grad_values,
        hidden_size,
        batch_size,
        slots,
        slot_dims,
        slot_lens,
        key2slot,
        d_restore_idx,
        grad_value_size,
        gpu_accessor_.common_push_value);
  }

  cudaStreamSynchronize(stream);
}

// Merges gradients of duplicate keys into `dedup_length` push records without
// atomics, then blocks until the stream is idle.
//
// PushMergeCopy runs one work item per (dedup key, column),
// N = dedup_length * hidden_size. It sums each key's duplicates as listed by
// gpu_sort_idx / gpu_sort_offset / gpu_sort_lens.
template <typename GPUAccessor>
void AccessorWrapper<GPUAccessor>::CopyForPushDedupImpl(
    const paddle::platform::Place& place,
    const uint64_t* total_keys,
    float** grad_values,
    float* total_grad_values_gpu,
    const int* slots,
    const int64_t* slot_lens,
    const int hidden_size,
    const int64_t total_length,
    const int64_t dedup_length,
    const int batch_size,
    const int* slot_dims,
    const int* key2slot,
    const uint32_t* gpu_sort_idx,
    const uint32_t* gpu_sort_offset,
    const uint32_t* gpu_sort_lens,
    const size_t grad_value_size) {
  auto stream = dynamic_cast<phi::GPUContext*>(
                    paddle::platform::DeviceContextPool::Instance().Get(place))
                    ->stream();
  // merge all grad to one
  size_t N = dedup_length * hidden_size;
  // GET_BLOCK(0) is 0, and a zero-block grid is an invalid configuration.
  if (N > 0) {
    PushMergeCopy<<<CUDA_BLOCK(N), stream>>>(N,
                                             total_keys,
                                             total_grad_values_gpu,
                                             grad_values,
                                             hidden_size,
                                             batch_size,
                                             slots,
                                             slot_dims,
                                             slot_lens,
                                             key2slot,
                                             gpu_sort_idx,
                                             gpu_sort_offset,
                                             gpu_sort_lens,
                                             grad_value_size,
                                             gpu_accessor_.common_push_value);
  }
  cudaStreamSynchronize(stream);
}

D
danleifeng 已提交
495 496 497 498
// Explicit instantiation of AccessorWrapper for the accessor used by both the
// PSCORE and PSLIB builds. A single guarded instantiation avoids a duplicate
// explicit instantiation definition if both macros are ever defined together.
#if defined(PADDLE_WITH_PSCORE) || defined(PADDLE_WITH_PSLIB)
template class AccessorWrapper<CommonFeatureValueAccessor>;
#endif

D
danleifeng 已提交
503 504 505
}  // namespace framework
}  // namespace paddle
#endif