top_k_op.cu
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cstdio>
#include "cub/cub.cuh"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/top_k_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/float16.h"
// set cub base traits in order to handle float16
namespace cub {
template <>
struct NumericTraits<paddle::platform::float16>
    : BaseTraits<FLOATING_POINT, true, false, uint16_t,
                 paddle::platform::float16> {};
}  // namespace cub
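// (This tells cub's radix sort that float16 is a floating-point type whose
// bit pattern lives in a uint16_t, which it needs to order the keys.)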

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T>
struct Pair {
  __device__ __forceinline__ Pair() {}
  __device__ __forceinline__ Pair(T value, int64_t id) : v(value), id(id) {}

  __device__ __forceinline__ void set(T value, int64_t id) {
    v = value;
    this->id = id;
  }

  __device__ __forceinline__ void operator=(const Pair<T>& in) {
    v = in.v;
    id = in.id;
  }

  __device__ __forceinline__ bool operator<(const T value) const {
    return (v < value);
  }

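  // On equal values the pair with the smaller id compares greater, so ties
  // are broken in favor of the smaller index.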
  __device__ __forceinline__ bool operator<(const Pair<T>& in) const {
    return (v < in.v) || ((v == in.v) && (id > in.id));
  }

  __device__ __forceinline__ bool operator>(const Pair<T>& in) const {
    return (v > in.v) || ((v == in.v) && (id < in.id));
  }

  T v;
  int64_t id;
};

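// Insert p into topk[], which is kept sorted in descending order; the
// smallest element is shifted off the end.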
template <typename T>
__device__ __forceinline__ void AddTo(Pair<T> topk[], const Pair<T>& p,
                                      int beam_size) {
  for (int k = beam_size - 2; k >= 0; k--) {
    if (topk[k] < p) {
      topk[k + 1] = topk[k];
    } else {
      topk[k + 1] = p;
      return;
    }
  }
  topk[0] = p;
}

template <typename T, int beam_size>
__device__ __forceinline__ void AddTo(Pair<T> topk[], const Pair<T>& p) {
  for (int k = beam_size - 2; k >= 0; k--) {
    if (topk[k] < p) {
      topk[k + 1] = topk[k];
    } else {
      topk[k + 1] = p;
      return;
    }
  }
  topk[0] = p;
}

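// The GetTopK overloads below let each thread stride over the input with step
// BlockSize and keep a private descending top-beam_size list; the variants
// taking `max` only accept candidates ranked strictly below the previous
// round's maximum.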
template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[], const T* src, int idx,
                                        int dim, int beam_size) {
  while (idx < dim) {
    if (topk[beam_size - 1] < src[idx]) {
      Pair<T> tmp(src[idx], idx);
      AddTo<T>(topk, tmp, beam_size);
    }
    idx += BlockSize;
  }
}

template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[], const T* src, int idx,
                                        int dim, const Pair<T>& max,
                                        int beam_size) {
  while (idx < dim) {
    if (topk[beam_size - 1] < src[idx]) {
      Pair<T> tmp(src[idx], idx);
      if (tmp < max) {
        AddTo<T>(topk, tmp, beam_size);
      }
    }
    idx += BlockSize;
  }
}

template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[], const T* val, int* col,
                                        int idx, int dim, int beam_size) {
  while (idx < dim) {
    if (topk[beam_size - 1] < val[idx]) {
      Pair<T> tmp(val[idx], col[idx]);
      AddTo<T>(topk, tmp, beam_size);
    }
    idx += BlockSize;
  }
}

template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[], const T* val, int* col,
                                        int idx, int dim, const Pair<T>& max,
                                        int beam_size) {
  while (idx < dim) {
    if (topk[beam_size - 1] < val[idx]) {
      Pair<T> tmp(val[idx], col[idx]);
      if (tmp < max) {
        AddTo<T>(topk, tmp, beam_size);
      }
    }
    idx += BlockSize;
  }
}

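// Refill a thread's private candidates after *beam of them were consumed by
// the block-wide reduction: shift the survivors forward, reset the freed
// slots to -inf, and rescan the input for values ranked below *max.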
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int* beam,
                                              int beam_size, const T* src,
                                              bool* firstStep, bool* is_empty,
                                              Pair<T>* max, int dim,
                                              const int tid) {
  if (*beam > 0) {
    int length = (*beam) < beam_size ? *beam : beam_size;
    if (*firstStep) {
      *firstStep = false;
      GetTopK<T, BlockSize>(topk, src, tid, dim, length);
    } else {
      for (int k = 0; k < MaxLength; k++) {
        if (k < MaxLength - (*beam)) {
          topk[k] = topk[k + *beam];
        } else {
          topk[k].set(-static_cast<T>(INFINITY), -1);
        }
      }
      if (!(*is_empty)) {
        GetTopK<T, BlockSize>(topk + MaxLength - *beam, src, tid, dim, *max,
                              length);
      }
    }

    *max = topk[MaxLength - 1];
    if ((*max).v == -static_cast<T>(INFINITY)) *is_empty = true;
    *beam = 0;
  }
}

template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int* beam,
                                              int beam_size, const T* val,
                                              int* col, bool* firstStep,
                                              bool* is_empty, Pair<T>* max,
                                              int dim, const int tid) {
  if (*beam > 0) {
    int length = (*beam) < beam_size ? *beam : beam_size;
    if (*firstStep) {
      *firstStep = false;
      GetTopK<T, BlockSize>(topk, val, col, tid, dim, length);
    } else {
      for (int k = 0; k < MaxLength; k++) {
        if (k < MaxLength - *beam) {
          topk[k] = topk[k + *beam];
        } else {
          topk[k].set(-static_cast<T>(INFINITY), -1);
        }
      }
      if (!(*is_empty)) {
        GetTopK<T, BlockSize>(topk + MaxLength - *beam, val, col, tid, dim, *max,
                              length);
      }
    }

    *max = topk[MaxLength - 1];
    if ((*max).v == -static_cast<T>(INFINITY)) *is_empty = true;
    *beam = 0;
  }
}

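// Tournament-reduce the per-thread maxima held in sh_topk, emitting one
// value/index per iteration; the winning thread then surfaces its next
// private candidate. Stops after k outputs, or early when the winner has
// already consumed all MaxLength of its candidates.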
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
                                            Pair<T> topk[], T** topVal,
                                            int64_t** topIds, int* beam, int* k,
                                            const int tid, const int warp) {
  while (true) {
    __syncthreads();
    if (tid < BlockSize / 2) {
      if (sh_topk[tid] < sh_topk[tid + BlockSize / 2]) {
        maxid[tid] = tid + BlockSize / 2;
      } else {
        maxid[tid] = tid;
      }
    }
    __syncthreads();
    for (int stride = BlockSize / 4; stride > 0; stride = stride / 2) {
      if (tid < stride) {
        if (sh_topk[maxid[tid]] < sh_topk[maxid[tid + stride]]) {
          maxid[tid] = maxid[tid + stride];
        }
      }
      __syncthreads();
    }
    __syncthreads();

    if (tid == 0) {
      **topVal = sh_topk[maxid[0]].v;
      **topIds = sh_topk[maxid[0]].id;
      (*topVal)++;
      (*topIds)++;
    }
    if (tid == maxid[0]) (*beam)++;
    if (--(*k) == 0) break;
    __syncthreads();

    if (tid == maxid[0]) {
      if (*beam < MaxLength) {
        sh_topk[tid] = topk[*beam];
      }
    }
    // NOTE(zcd): temporary solution
    unsigned mask = 0u;
    CREATE_SHFL_MASK(mask, true);

    if (maxid[0] / 32 == warp) {
      if (platform::CudaShuffleSync(mask, *beam, (maxid[0]) % 32, 32) ==
          MaxLength)
        break;
    }
  }
}

/**
 * Each block computes one sample.
 * In a block:
 * 1. every thread gets its top MaxLength values;
 * 2. merge them into sh_topk, block-reduce, and get the max value;
 * 3. go back to the second step until one thread's topk values are exhausted;
 * 4. go back to the first step until all k values have been found.
 */
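// [Illustrative sketch, not part of the original file; assumes <algorithm>,
// <utility> and <vector>] A host-side reference of the per-row result
// computed below: the k largest values in descending order, ties resolved in
// favor of the smaller index.
//
//   template <typename T>
//   std::vector<std::pair<T, int64_t>> RefTopKRow(const T* row, int dim,
//                                                 int k) {
//     std::vector<std::pair<T, int64_t>> v;
//     for (int64_t i = 0; i < dim; ++i) v.emplace_back(row[i], i);
//     // stable_sort keeps the smaller index first among equal values.
//     std::stable_sort(v.begin(), v.end(),
//                      [](const std::pair<T, int64_t>& a,
//                         const std::pair<T, int64_t>& b) {
//                        return a.first > b.first;
//                      });
//     v.resize(k);
//     return v;
//   }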

template <typename T, int MaxLength, int BlockSize>
__global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
                             const T* src, int lds, int dim, int k,
                             int grid_dim, int num) {
  __shared__ Pair<T> sh_topk[BlockSize];
  const int tid = threadIdx.x;
  const int warp = threadIdx.x / 32;

  const int bid = blockIdx.x;
  for (int i = bid; i < num; i += grid_dim) {
    int top_num = k;
    __shared__ int maxid[BlockSize / 2];
    T* out = output + i * output_stride;
    int64_t* inds = indices + i * k;
    Pair<T> topk[MaxLength];
    int beam = MaxLength;
    Pair<T> max;
    bool is_empty = false;
    bool firststep = true;

    for (int j = 0; j < MaxLength; j++) {
      topk[j].set(-static_cast<T>(INFINITY), -1);
    }
    while (top_num) {
      ThreadGetTopK<T, MaxLength, BlockSize>(
          topk, &beam, k, src + i * lds, &firststep, &is_empty, &max, dim, tid);

      sh_topk[tid] = topk[0];
      BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &out, &inds,
                                           &beam, &top_num, tid, warp);
    }
  }
}

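// Scatter out_grad back into x_grad: zero each input row, then write every
// gradient entry to the column recorded in indices. Note that, as written,
// each launched thread runs the full loops (threadIdx/blockIdx are unused),
// so the work is redundant, though all threads write identical values.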
template <typename T, int MaxLength, int BlockSize>
__global__ void AssignGrad(T* x_grad, const int64_t* indices, const T* out_grad,
                           size_t rows, size_t cols, size_t k) {
  for (size_t i = 0; i < rows; ++i) {
    for (size_t j = 0; j < cols; ++j) {
      x_grad[i * cols + j] = 0;
    }
    for (size_t j = 0; j < k; ++j) {
      size_t idx = indices[i * k + j];
      x_grad[i * cols + idx] = out_grad[i * k + j];
    }
  }
}

inline static int GetDesiredBlockDim(int dim) {
  if (dim > 128) {
    return 256;
  } else if (dim > 64) {
    return 128;
  } else if (dim > 32) {
    return 64;
  } else {
    return 32;
  }
}

// Iterator that yields the starting offset of each row (segment).
struct SegmentOffsetIter {
  EIGEN_DEVICE_FUNC
  explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const {
    return idx * num_cols_;
  }

  int num_cols_;
};

// Iterator that maps a flat element index to its column index.
struct ColumnIndexIter {
  explicit ColumnIndexIter(int num_cols) : num_cols_(num_cols) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(
      const Eigen::array<int, 1>& ix) const {
    return ix[0] % num_cols_;
  }

  int num_cols_;
};
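// Example: with num_cols = 4, SegmentOffsetIter maps 0, 1, 2, ... to
// 0, 4, 8, ..., exactly the begin/end offsets the segmented radix sort below
// expects; ColumnIndexIter maps a flat element index such as 6 to 6 % 4 = 2.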

__global__ void InitIndex(int64_t* indices, int64_t num_rows,
                          int64_t num_cols) {
  int col_id = threadIdx.x;
  int row_id = blockIdx.x;

  for (int64_t j = row_id; j < num_rows; j += gridDim.x) {
    for (int64_t i = col_id; i < num_cols; i += blockDim.x) {
      indices[j * num_cols + i] = i;
    }
  }
}
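// After InitIndex, indices holds 0, 1, ..., num_cols - 1 in every row, i.e.
// the original column id of each element before sorting.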

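// Sort every row of input_tensor in descending order with cub's segmented
// radix sort and keep the first k values/indices of each row. Returns false
// if a cub launch fails, so the caller can fall back to KeMatrixTopK.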
template <typename T>
bool SortTopk(const platform::CUDADeviceContext& ctx,
              const framework::Tensor* input_tensor, const int64_t num_cols,
              const int64_t num_rows, const int k,
              framework::Tensor* out_tensor,
              framework::Tensor* indices_tensor) {
  auto cu_stream = ctx.stream();

  Tensor input_indices;
  const std::vector<int64_t> dims = {num_rows, num_cols};
  auto dim = framework::make_ddim(dims);
  input_indices.Resize(dim);
  input_indices.mutable_data<int64_t>(ctx.GetPlace());
  size_t temp_storage_bytes = -1;

  auto ComputeBlockSize = [](int col) {
    if (col > 512)
      return 1024;
    else if (col > 256 && col <= 512)
      return 512;
    else if (col > 128 && col <= 256)
      return 256;
    else if (col > 64 && col <= 128)
      return 128;
    else
      return 64;
  };

  int block_size = ComputeBlockSize(num_cols);

  unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize().x;
  // In practice num_rows is smaller than the maximum grid size.
  unsigned int grid_size = num_rows < maxGridDimX
                               ? static_cast<unsigned int>(num_rows)
                               : maxGridDimX;
  // Initialize the index array.
  InitIndex<<<grid_size, block_size, 0, cu_stream>>>(
      input_indices.data<int64_t>(), num_rows, num_cols);

  // Create a counting iterator over the flattened input positions.
  cub::CountingInputIterator<int64_t> counting_iter(0);
  // segment_offsets_t yields the begin/end offset of each row (segment).
  cub::TransformInputIterator<int64_t, SegmentOffsetIter,
                              cub::CountingInputIterator<int64_t>>
      segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));

  T* sorted_values_ptr;
  int64_t* sorted_indices_ptr;

  Tensor temp_values;
  Tensor temp_indices;

  const T* input = input_tensor->data<T>();
  T* values = out_tensor->data<T>();
  int64_t* indices = indices_tensor->mutable_data<int64_t>(ctx.GetPlace());

  if (k == num_cols) {
    // Doing a full sort.
    sorted_values_ptr = values;
    sorted_indices_ptr = indices;
  } else {
    temp_values.Resize(dim);
    temp_indices.Resize(dim);
    sorted_values_ptr = temp_values.mutable_data<T>(ctx.GetPlace());
    sorted_indices_ptr = temp_indices.mutable_data<int64_t>(ctx.GetPlace());
  }

  // Query the required temp storage size by passing a null buffer (the usual
  // two-phase cub pattern); a fixed buffer could be cached here to save time.
  auto err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
      nullptr, temp_storage_bytes, input, sorted_values_ptr,
      input_indices.data<int64_t>(), sorted_indices_ptr, num_cols * num_rows,
      num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
      cu_stream);
  if (err != cudaSuccess) {
    LOG(ERROR)
        << "TopKOP failed: could not launch "
           "cub::DeviceSegmentedRadixSort::SortPairsDescending to calculate "
           "temp_storage_bytes, status: "
        << cudaGetErrorString(err);
    return false;
  }
  Tensor temp_storage;
  temp_storage.mutable_data<uint8_t>(ctx.GetPlace(), temp_storage_bytes);

  err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
      temp_storage.data<uint8_t>(), temp_storage_bytes, input,
      sorted_values_ptr, input_indices.data<int64_t>(), sorted_indices_ptr,
      num_cols * num_rows, num_rows, segment_offsets_t, segment_offsets_t + 1,
      0, sizeof(T) * 8, cu_stream);
  if (err != cudaSuccess) {
    LOG(ERROR)
        << "TopKOP failed: could not launch "
           "cub::DeviceSegmentedRadixSort::SortPairsDescending to sort input, "
           "temp_storage_bytes: "
        << temp_storage_bytes << ", status: " << cudaGetErrorString(err);
    return false;
  }
  auto& dev = *ctx.eigen_device();
  if (k < num_cols) {
    // copy sliced data to output.
    const Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{0, 0};
    const Eigen::DSizes<Eigen::DenseIndex, 2> slice_sizes{num_rows, k};
    auto e_indices = EigenMatrix<int64_t>::From(*indices_tensor, dim);
    auto e_tmp_indices = EigenMatrix<int64_t>::From(temp_indices);

    std::vector<int> odims = {static_cast<int>(num_rows), static_cast<int>(k)};
    auto dim = framework::make_ddim(odims);
    auto e_values = EigenMatrix<T>::From(*out_tensor, dim);
    auto e_tmp_values = EigenMatrix<T>::From(temp_values);

    e_indices.device(dev) = e_tmp_indices.slice(slice_indices, slice_sizes);
    e_values.device(dev) = e_tmp_values.slice(slice_indices, slice_sizes);
  }
  return true;
}

#define FIXED_BLOCK_DIM_BASE(dim, ...) \
  case (dim): {                        \
    constexpr auto kBlockDim = (dim);  \
    __VA_ARGS__;                       \
  } break

#define FIXED_BLOCK_DIM(...)                \
  FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
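// For reference, FIXED_BLOCK_DIM(expr) expands inside a switch to
//   case 256: { constexpr auto kBlockDim = 256; expr; } break;
//   case 128: { constexpr auto kBlockDim = 128; expr; } break;
//   ... (and likewise for 64 and 32)
// so a runtime block size selects a compile-time kBlockDim template argument.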

template <typename DeviceContext, typename T>
class TopkOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "It must use CUDAPlace.");
    auto* input = ctx.Input<Tensor>("X");
    auto* output = ctx.Output<Tensor>("Out");
    auto* indices = ctx.Output<Tensor>("Indices");
    int k = static_cast<int>(ctx.Attr<int>("k"));

    auto* k_t = ctx.Input<Tensor>("K");
    if (k_t) {
      Tensor k_host;
      framework::TensorCopySync(*k_t, platform::CPUPlace(), &k_host);
      k = k_host.data<int>()[0];
      framework::DDim output_dims = output->dims();
      output_dims[output_dims.size() - 1] = k;
      output->Resize(output_dims);
      indices->Resize(output_dims);
    }

    const T* input_data = input->data<T>();
    T* output_data = output->mutable_data<T>(ctx.GetPlace());
    // FIXME(typhoonzero): data is always converted to type T?

    framework::DDim inputdims = input->dims();
    const int64_t input_height = framework::product(
        framework::slice_ddim(inputdims, 0, inputdims.size() - 1));
    const int64_t input_width = inputdims[inputdims.size() - 1];
    const auto& dev_ctx = ctx.cuda_device_context();

    if (input_width <= 1024 || k >= 128 || k == input_width) {
      if (SortTopk<T>(dev_ctx, input, input_width, input_height, k, output,
                      indices)) {
        // Succeeded, return.
        return;
      } else {
        LOG(INFO) << "TopKOP: cub sorting failed, falling back to the "
                     "default top-k kernel.";
      }
    }
    int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
    if (k > input_width) k = input_width;

    // NOTE: lds and dim are both passed as the input width.
    // NOTE: the old matrix implementation's stride differs from eigen's.
    // TODO(typhoonzero): refine this kernel.
    const int kMaxHeight = 2048;
    int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
    switch (GetDesiredBlockDim(input_width)) {
      FIXED_BLOCK_DIM(
          KeMatrixTopK<T, 5,
                       kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
              output_data, k, indices_data, input_data, input_width,
              input_width, static_cast<int>(k), gridx, input_height));
      default:
        PADDLE_THROW("Error");
    }
  }
};

template <typename DeviceContext, typename T>
class TopkOpGradCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(context.GetPlace()), true,
        platform::errors::InvalidArgument("It must use CUDAPlace."));
    auto* x = context.Input<Tensor>("X");
    auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* indices = context.Input<Tensor>("Indices");
    auto* x_grad = context.Output<Tensor>(framework::GradVarName("X"));

    T* x_grad_data = x_grad->mutable_data<T>(context.GetPlace());
    const T* out_grad_data = out_grad->data<T>();
    const int64_t* indices_data = indices->data<int64_t>();
    size_t k = indices->dims()[indices->dims().size() - 1];

    framework::DDim xdims = x->dims();
    const size_t row =
        framework::product(framework::slice_ddim(xdims, 0, xdims.size() - 1));
    const size_t col = xdims[xdims.size() - 1];
    const auto& dev_ctx = context.cuda_device_context();

    const int kMaxHeight = 2048;
    int gridx = row < kMaxHeight ? row : kMaxHeight;
    switch (GetDesiredBlockDim(col)) {
      FIXED_BLOCK_DIM(
          AssignGrad<T, 5,
                     kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
              x_grad_data, indices_data, out_grad_data, row, col, k));
      default:
        PADDLE_THROW(platform::errors::Unavailable(
            "Error occurred when assigning top_k gradient."));
    }
  }
};
#undef FIXED_BLOCK_DIM_BASE
#undef FIXED_BLOCK_DIM

}  // namespace operators
}  // namespace paddle

REGISTER_OP_CUDA_KERNEL(
    top_k,
    paddle::operators::TopkOpCUDAKernel<paddle::platform::CUDADeviceContext,
                                        float>,
    paddle::operators::TopkOpCUDAKernel<paddle::platform::CUDADeviceContext,
                                        double>,
    paddle::operators::TopkOpCUDAKernel<paddle::platform::CUDADeviceContext,
                                        int>,
    paddle::operators::TopkOpCUDAKernel<paddle::platform::CUDADeviceContext,
                                        int64_t>,
    paddle::operators::TopkOpCUDAKernel<paddle::platform::CUDADeviceContext,
                                        paddle::platform::float16>);

REGISTER_OP_CUDA_KERNEL(
    top_k_grad,
    paddle::operators::TopkOpGradCUDAKernel<paddle::platform::CUDADeviceContext,
                                            float>,
    paddle::operators::TopkOpGradCUDAKernel<paddle::platform::CUDADeviceContext,
                                            double>,
    paddle::operators::TopkOpGradCUDAKernel<paddle::platform::CUDADeviceContext,
                                            int>,
    paddle::operators::TopkOpGradCUDAKernel<paddle::platform::CUDADeviceContext,
                                            int64_t>,
    paddle::operators::TopkOpGradCUDAKernel<paddle::platform::CUDADeviceContext,
                                            paddle::platform::float16>);