/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <limits>
#include <vector>

#include "cub/cub.cuh"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/top_k_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/float16.h"

// Specialize cub::NumericTraits for paddle::platform::float16 so that the cub
// sort used by SortTopk below can be instantiated for float16 (the float16
// kernel is registered at the bottom of this file).
// NOTE(review): the BaseTraits arguments presumably declare float16 as a
// primitive floating-point type whose raw bit pattern is a uint16_t -- confirm
// against the cub version in use.
namespace cub {
template <>
struct NumericTraits<paddle::platform::float16>
    : BaseTraits<FLOATING_POINT, true, false, uint16_t,
                 paddle::platform::float16> {};
}  // namespace cub

namespace paddle {
namespace operators {

// Shorthand for framework::Tensor, used by SortTopk and TopkOpCUDAKernel.
using Tensor = framework::Tensor;

/**
 * A (value, index) candidate used by the top-k selection code.
 *
 * Ordering between two pairs: the pair with the larger value is greater; when
 * the values are equal, the pair with the smaller id is greater.
 */
template <typename T>
struct Pair {
  // Leaves both members uninitialized.
  __device__ __forceinline__ Pair() {}

  __device__ __forceinline__ Pair(T value, int64_t id) : v(value), id(id) {}

  // Overwrites both the value and the index of this pair.
  __device__ __forceinline__ void set(T value, int64_t id) {
    v = value;
    // The parameter shadows the member of the same name, so the member has to
    // be qualified; a plain `id = id` only assigns the parameter to itself and
    // leaves the stored index untouched.
    this->id = id;
  }

  __device__ __forceinline__ void operator=(const Pair<T>& in) {
    v = in.v;
    id = in.id;
  }

  // Compares only the stored value against a bare value (the index is
  // ignored).
  __device__ __forceinline__ bool operator<(const T value) const {
    return (v < value);
  }

  // Smaller value, or equal value with a larger index.
  __device__ __forceinline__ bool operator<(const Pair<T>& in) const {
    return (v < in.v) || ((v == in.v) && (id > in.id));
  }

  // Larger value, or equal value with a smaller index.
  __device__ __forceinline__ bool operator>(const Pair<T>& in) const {
    return (v > in.v) || ((v == in.v) && (id < in.id));
  }

  T v;         // candidate value
  int64_t id;  // index associated with the value
};

/**
 * Inserts `p` into the first `beam_size` slots of `topk`.
 *
 * Walking up from the last slot, every entry that compares less than `p` is
 * moved one slot down; `p` is stored in the slot that opens up. The entry
 * that was in the last slot is overwritten.
 */
template <typename T>
__device__ __forceinline__ void AddTo(Pair<T> topk[], const Pair<T>& p,
                                      int beam_size) {
  // For beam_size <= 1 nothing is shifted and only slot 0 is written.
  int slot = beam_size > 1 ? beam_size - 1 : 0;
  while (slot > 0 && topk[slot - 1] < p) {
    topk[slot] = topk[slot - 1];
    --slot;
  }
  topk[slot] = p;
}

/**
 * Same insertion as the runtime-sized AddTo, with the number of slots given
 * as the compile-time constant `beam_size`.
 */
template <typename T, int beam_size>
__device__ __forceinline__ void AddTo(Pair<T> topk[], const Pair<T>& p) {
  // For beam_size <= 1 nothing is shifted and only slot 0 is written.
  int slot = beam_size > 1 ? beam_size - 1 : 0;
  while (slot > 0 && topk[slot - 1] < p) {
    topk[slot] = topk[slot - 1];
    --slot;
  }
  topk[slot] = p;
}

/**
 * Scans src[idx], src[idx + BlockSize], ... (while the position is < dim) and
 * inserts every element larger than the current last kept entry into `topk`.
 * The element's position in `src` is recorded as its id.
 */
template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[], const T* src, int idx,
                                        int dim, int beam_size) {
  for (int pos = idx; pos < dim; pos += BlockSize) {
    // Skip anything that does not beat the last kept entry.
    if (!(topk[beam_size - 1] < src[pos])) continue;
    Pair<T> candidate(src[pos], pos);
    AddTo<T>(topk, candidate, beam_size);
  }
}

/**
 * Like the plain strided scan over `src`, but only accepts candidates that
 * compare strictly less than `max` (by Pair ordering).
 */
template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[], const T* src, int idx,
                                        int dim, const Pair<T>& max,
                                        int beam_size) {
  for (int pos = idx; pos < dim; pos += BlockSize) {
    // Skip anything that does not beat the last kept entry.
    if (!(topk[beam_size - 1] < src[pos])) continue;
    Pair<T> candidate(src[pos], pos);
    // Only candidates below the upper bound are inserted.
    if (candidate < max) {
      AddTo<T>(topk, candidate, beam_size);
    }
  }
}

/**
 * Strided scan over val[idx], val[idx + BlockSize], ... (positions < dim)
 * where the id stored with each candidate is read from `col` at the same
 * position instead of being the position itself.
 */
template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[], const T* val, int* col,
                                        int idx, int dim, int beam_size) {
  for (int pos = idx; pos < dim; pos += BlockSize) {
    // Skip anything that does not beat the last kept entry.
    if (!(topk[beam_size - 1] < val[pos])) continue;
    Pair<T> candidate(val[pos], col[pos]);
    AddTo<T>(topk, candidate, beam_size);
  }
}

/**
 * Strided scan over `val` with ids taken from `col`, accepting only
 * candidates that compare strictly less than `max` (by Pair ordering).
 */
template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[], const T* val, int* col,
                                        int idx, int dim, const Pair<T>& max,
                                        int beam_size) {
  for (int pos = idx; pos < dim; pos += BlockSize) {
    // Skip anything that does not beat the last kept entry.
    if (!(topk[beam_size - 1] < val[pos])) continue;
    Pair<T> candidate(val[pos], col[pos]);
    // Only candidates below the upper bound are inserted.
    if (candidate < max) {
      AddTo<T>(topk, candidate, beam_size);
    }
  }
}

/**
 * Refills the calling thread's private candidate list `topk` (MaxLength
 * slots) from `src`. Does nothing unless *beam > 0.
 *
 * First call (*firstStep == true): scans the thread's strided positions
 * (tid, tid + BlockSize, ...) of `src` and keeps the best
 * min(*beam, beam_size) values.
 * Later calls: drops the first *beam entries, shifts the rest to the front,
 * pads the tail with the (-INFINITY, -1) sentinel and, unless the thread is
 * already marked empty, rescans `src` for values below *max to fill the
 * freed slots.
 *
 * On exit *max holds the last slot of `topk`, *is_empty is set when that
 * slot is still the sentinel (the thread could not fill its list), and *beam
 * is reset to 0.
 *
 * NOTE(review): on the first call the sentinel test relies on the caller
 * having pre-filled `topk` with -INFINITY values (KeMatrixTopK does).
 */
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int* beam,
                                              int beam_size, const T* src,
                                              bool* firstStep, bool* is_empty,
                                              Pair<T>* max, int dim,
                                              const int tid) {
  if (*beam > 0) {
    int length = (*beam) < beam_size ? *beam : beam_size;
    if (*firstStep) {
      *firstStep = false;
      GetTopK<T, BlockSize>(topk, src, tid, dim, length);
    } else {
      // Discard the *beam consumed entries and pad the tail with sentinels.
      for (int k = 0; k < MaxLength; k++) {
        if (k < MaxLength - (*beam)) {
          topk[k] = topk[k + *beam];
        } else {
          topk[k].set(-static_cast<T>(INFINITY), -1);
        }
      }
      if (!(*is_empty)) {
        GetTopK<T, BlockSize>(topk + MaxLength - *beam, src, tid, dim, *max,
                              length);
      }
    }

    *max = topk[MaxLength - 1];
    // The thread is exhausted when its last slot still holds the sentinel
    // value. Comparing against -1 instead would misfire on a genuine -1 in
    // the data and stop the thread from ever refilling again.
    if ((*max).v == -static_cast<T>(INFINITY)) *is_empty = true;
    *beam = 0;
  }
}

/**
 * Variant of ThreadGetTopK that reads candidate values from `val` and their
 * ids from `col` (same position) instead of using the position as the id.
 * Does nothing unless *beam > 0.
 *
 * First call (*firstStep == true): scans the thread's strided positions and
 * keeps the best min(*beam, beam_size) values.
 * Later calls: drops the first *beam entries, shifts the rest to the front,
 * pads the tail with the (-INFINITY, -1) sentinel and, unless the thread is
 * already marked empty, rescans for values below *max.
 *
 * On exit *max holds the last slot of `topk`, *is_empty is set when that
 * slot is still the sentinel, and *beam is reset to 0.
 *
 * NOTE(review): on the first call the sentinel test relies on the caller
 * having pre-filled `topk` with -INFINITY values.
 */
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int* beam,
                                              int beam_size, const T* val,
                                              int* col, bool* firstStep,
                                              bool* is_empty, Pair<T>* max,
                                              int dim, const int tid) {
  if (*beam > 0) {
    int length = (*beam) < beam_size ? *beam : beam_size;
    if (*firstStep) {
      *firstStep = false;
      GetTopK<T, BlockSize>(topk, val, col, tid, dim, length);
    } else {
      // Discard the *beam consumed entries and pad the tail with sentinels.
      for (int k = 0; k < MaxLength; k++) {
        if (k < MaxLength - *beam) {
          topk[k] = topk[k + *beam];
        } else {
          topk[k].set(-static_cast<T>(INFINITY), -1);
        }
      }
      if (!(*is_empty)) {
        // GetTopK takes the bound by reference, so the pointer has to be
        // dereferenced here.
        GetTopK<T, BlockSize>(topk + MaxLength - *beam, val, col, tid, dim,
                              *max, length);
      }
    }

    *max = topk[MaxLength - 1];
    // The thread is exhausted when its last slot still holds the sentinel
    // value; a comparison against -1 would misfire on a genuine -1 in the
    // data.
    if ((*max).v == -static_cast<T>(INFINITY)) *is_empty = true;
    *beam = 0;
  }
}

/**
 * Repeatedly selects the block-wide best candidate and emits it.
 *
 * @param sh_topk  per-thread current best candidate, indexed by tid
 *                 (entries 0 .. BlockSize - 1 are read).
 * @param maxid    scratch of BlockSize / 2 ints holding indices into sh_topk.
 * @param topk     the calling thread's private candidate list.
 * @param topVal   cursor into the value output; written and advanced by
 *                 thread 0 only.
 * @param topIds   cursor into the index output; written and advanced by
 *                 thread 0 only.
 * @param beam     number of entries of `topk` this thread has had selected.
 * @param k        remaining number of results; decremented by every thread.
 * @param tid      thread index inside the block.
 * @param warp     tid / 32.
 *
 * Each iteration runs a tournament over sh_topk that leaves the index of the
 * largest entry in maxid[0]. Thread 0 stores that entry through the output
 * cursors, and the owning thread increments its beam. The loop ends for all
 * threads once *k reaches 0. Otherwise the owner publishes its next candidate
 * (topk[*beam]) while it still has one; when the owner's beam has reached
 * MaxLength, the threads of the owner's warp leave the loop.
 *
 * NOTE(review): when only the owner's warp leaves, the remaining warps go on
 * to the barrier at the top of the loop; this presumably relies on the caller
 * re-entering BlockReduce from that warp so the barriers pair up -- confirm
 * before restructuring.
 */
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
                                            Pair<T> topk[], T** topVal,
                                            int64_t** topIds, int* beam, int* k,
                                            const int tid, const int warp) {
  while (true) {
    __syncthreads();
    // First round: compare the two halves of sh_topk.
    if (tid < BlockSize / 2) {
      if (sh_topk[tid] < sh_topk[tid + BlockSize / 2]) {
        maxid[tid] = tid + BlockSize / 2;
      } else {
        maxid[tid] = tid;
      }
    }
    __syncthreads();
    // Remaining rounds: halve the number of surviving indices each time.
    for (int stride = BlockSize / 4; stride > 0; stride = stride / 2) {
      if (tid < stride) {
        if (sh_topk[maxid[tid]] < sh_topk[maxid[tid + stride]]) {
          maxid[tid] = maxid[tid + stride];
        }
      }
      __syncthreads();
    }
    __syncthreads();

    // Emit the winner and advance the output cursors.
    if (tid == 0) {
      **topVal = sh_topk[maxid[0]].v;
      **topIds = sh_topk[maxid[0]].id;
      (*topVal)++;
      (*topIds)++;
    }
    if (tid == maxid[0]) (*beam)++;
    if (--(*k) == 0) break;
    __syncthreads();

    // The owner of the winner publishes its next candidate, if it has one.
    if (tid == maxid[0]) {
      if (*beam < MaxLength) {
        sh_topk[tid] = topk[*beam];
      }
    }
    // NOTE(zcd): temporary solution
    unsigned mask = 0u;
    CREATE_SHFL_MASK(mask, true);

    // Threads in the owner's warp read the owner's beam and leave the loop
    // when the owner has used up all MaxLength candidates.
    if (maxid[0] / 32 == warp) {
      if (platform::CudaShuffleSync(mask, *beam, (maxid[0]) % 32, 32) ==
          MaxLength)
        break;
    }
  }
}

/**
 * Top-k of every row of a matrix; each block processes one row at a time.
 * Within a block:
 * 1. every thread collects its own best MaxLength values of the row;
 * 2. each thread's current best goes to sh_topk and a block reduction emits
 *    the overall maximum;
 * 3. step 2 repeats until some thread has used up all of its candidates;
 * 4. that thread goes back to step 1 to refill, until k values are emitted.
 *
 * Launch requirements visible in the code: blockDim.x must equal BlockSize
 * (sh_topk has BlockSize entries indexed by threadIdx.x, and threads stride
 * the row by BlockSize); `grid_dim` is the step between rows handled by one
 * block, so it should match gridDim.x. Uses static shared memory only.
 *
 * @param output         values, row i starts at output + i * output_stride.
 * @param output_stride  distance between output rows, in elements.
 * @param indices        indices, row i starts at indices + i * k.
 * @param src            input, row i starts at src + i * lds.
 * @param lds            distance between input rows, in elements.
 * @param dim            number of elements scanned per row.
 * @param k              number of results per row.
 * @param grid_dim       row step per block (see above).
 * @param num            number of rows.
 */
template <typename T, int MaxLength, int BlockSize>
__global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
                             const T* src, int lds, int dim, int k,
                             int grid_dim, int num) {
  __shared__ Pair<T> sh_topk[BlockSize];
  const int tid = threadIdx.x;
  const int warp = threadIdx.x / 32;

  const int bid = blockIdx.x;
  for (int i = bid; i < num; i += grid_dim) {
    int top_num = k;
    __shared__ int maxid[BlockSize / 2];
    // Row offsets are computed in 64 bits: row index times stride can exceed
    // the range of int for large inputs.
    const int64_t row = i;
    T* out = output + row * output_stride;
    int64_t* inds = indices + row * k;
    const T* row_src = src + row * lds;
    Pair<T> topk[MaxLength];
    int beam = MaxLength;
    Pair<T> max;
    bool is_empty = false;
    bool firststep = true;

    // Start from a list of sentinels so that any real value replaces them.
    for (int j = 0; j < MaxLength; j++) {
      topk[j].set(-static_cast<T>(INFINITY), -1);
    }
    while (top_num) {
      ThreadGetTopK<T, MaxLength, BlockSize>(
          topk, &beam, k, row_src, &firststep, &is_empty, &max, dim, tid);

      sh_topk[tid] = topk[0];
      BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &out, &inds,
                                           &beam, &top_num, tid, warp);
    }
  }
}

// Picks the thread-block size for a row of `dim` elements: the smallest of
// 32, 64 and 128 that is >= dim, or 256 for anything larger than 128.
inline static int GetDesiredBlockDim(int dim) {
  if (dim <= 32) return 32;
  if (dim <= 64) return 64;
  if (dim <= 128) return 128;
  return 256;
}

// Functor that maps a row index to the flat offset at which that row starts
// when every row holds `num_cols` elements. SortTopk wraps it in a cub
// transform iterator to describe the sort segments.
struct SegmentOffsetIter {
  EIGEN_DEVICE_FUNC
  explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const {
    const int row_start = num_cols_ * idx;
    return row_start;
  }

  int num_cols_;  // elements per row
};

// Functor that maps a flat element index (passed as a one-element Eigen
// array) to its column, i.e. the index modulo `num_cols`.
struct ColumnIndexIter {
  explicit ColumnIndexIter(int num_cols) : num_cols_(num_cols) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(
      const Eigen::array<int, 1>& ix) const {
    const int flat = ix[0];
    return flat % num_cols_;
  }

  int num_cols_;  // elements per row
};

// Fills a num_rows x num_cols buffer so that every element holds its own
// column number: indices[j * num_cols + i] = i.
// Blocks step over rows by gridDim.x and threads step over columns by
// blockDim.x, so any 1-D launch configuration covers the whole buffer.
__global__ void InitIndex(int64_t* indices, int num_rows, int num_cols) {
  int col_id = threadIdx.x;
  int row_id = blockIdx.x;

  for (int j = row_id; j < num_rows; j += gridDim.x) {
    // Compute the row start in 64 bits; j * num_cols can exceed the range of
    // int for large inputs.
    int64_t* row = indices + static_cast<int64_t>(j) * num_cols;
    for (int i = col_id; i < num_cols; i += blockDim.x) {
      row[i] = i;
    }
  }
}

/**
 * Computes the per-row top-k by sorting every row in descending order with
 * cub::DeviceSegmentedRadixSort::SortPairsDescending and keeping the first k
 * columns.
 *
 * @param ctx             CUDA device context; its stream is used for all work.
 * @param input_tensor    input, read as num_rows x num_cols values of type T.
 * @param num_cols        elements per row.
 * @param num_rows        number of rows.
 * @param k               number of results per row.
 * @param out_tensor      receives the values. Its buffer is obtained with
 *                        data<T>(), so it has to be allocated by the caller.
 * @param indices_tensor  receives the int64 column indices; allocated here.
 * @return true when both cub calls succeeded; false when the input is too
 *         large for the sort or a cub call reported an error, so that the
 *         caller can fall back to another implementation.
 *
 * When k == num_cols the sort writes straight into the outputs; when
 * k < num_cols it sorts into temporaries and the first k columns are copied
 * out. Nothing is written to the outputs when k > num_cols.
 */
template <typename T>
bool SortTopk(const platform::CUDADeviceContext& ctx,
              const framework::Tensor* input_tensor, const size_t num_cols,
              const size_t num_rows, size_t k, framework::Tensor* out_tensor,
              framework::Tensor* indices_tensor) {
  // The item count, the segment count and the segment offsets produced by
  // SegmentOffsetIter are all passed around as int, so larger inputs cannot
  // be sorted here.
  const size_t num_items = num_cols * num_rows;
  if (num_items > static_cast<size_t>(std::numeric_limits<int>::max())) {
    LOG(WARNING) << "TopKOP: input has too many elements for "
                    "cub::DeviceSegmentedRadixSort, skip the sorting path.";
    return false;
  }

  auto cu_stream = ctx.stream();

  Tensor input_indices;
  const std::vector<int64_t> dims = {static_cast<int64_t>(num_rows),
                                     static_cast<int64_t>(num_cols)};
  auto dim = framework::make_ddim(dims);
  input_indices.Resize(dim);
  // input_indices.Resize(num_rows*num_cols);
  input_indices.mutable_data<int64_t>(ctx.GetPlace());
  // Filled in by the size query below.
  size_t temp_storage_bytes = 0;

  auto ComputeBlockSize = [](int col) {
    if (col > 512)
      return 1024;
    else if (col > 256 && col <= 512)
      return 512;
    else if (col > 128 && col <= 256)
      return 256;
    else if (col > 64 && col <= 128)
      return 128;
    else
      return 64;
  };

  int block_size = ComputeBlockSize(num_cols);

  int maxGridDimX = ctx.GetCUDAMaxGridDimSize().x;
  // actually, int num_rows < max_grid_size
  int grid_size = num_rows < static_cast<size_t>(maxGridDimX)
                      ? static_cast<int>(num_rows)
                      : maxGridDimX;
  // Init a index array
  InitIndex<<<grid_size, block_size, 0, cu_stream>>>(
      input_indices.data<int64_t>(), num_rows, num_cols);

  // create iter for counting input
  cub::CountingInputIterator<int> counting_iter(0);
  // segment_offset is used for move to next row
  cub::TransformInputIterator<int, SegmentOffsetIter,
                              cub::CountingInputIterator<int>>
      segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));

  T* sorted_values_ptr;
  int64_t* sorted_indices_ptr;

  Tensor temp_values;
  Tensor temp_indices;

  const T* input = input_tensor->data<T>();
  T* values = out_tensor->data<T>();
  int64_t* indices = indices_tensor->mutable_data<int64_t>(ctx.GetPlace());

  if (k == num_cols) {
    // Doing a full sort.
    sorted_values_ptr = values;
    sorted_indices_ptr = indices;
  } else {
    temp_values.Resize(dim);
    temp_indices.Resize(dim);
    sorted_values_ptr = temp_values.mutable_data<T>(ctx.GetPlace());
    sorted_indices_ptr = temp_indices.mutable_data<int64_t>(ctx.GetPlace());
  }

  // Get temp storage buffer size, maybe can allocate a fixed buffer to save
  // time.
  auto err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
      nullptr, temp_storage_bytes, input, sorted_values_ptr,
      input_indices.data<int64_t>(), sorted_indices_ptr,
      static_cast<int>(num_items), static_cast<int>(num_rows),
      segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8, cu_stream);
  if (err != cudaSuccess) {
    LOG(ERROR)
        << "TopKOP failed as could not launch "
           "cub::DeviceSegmentedRadixSort::SortPairsDescending to calculate "
           "temp_storage_bytes, status: "
        << cudaGetErrorString(err);
    return false;
  }
  Tensor temp_storage;
  temp_storage.mutable_data<uint8_t>(ctx.GetPlace(), temp_storage_bytes);

  err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
      temp_storage.data<uint8_t>(), temp_storage_bytes, input,
      sorted_values_ptr, input_indices.data<int64_t>(), sorted_indices_ptr,
      static_cast<int>(num_items), static_cast<int>(num_rows),
      segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8, cu_stream);
  if (err != cudaSuccess) {
    LOG(ERROR)
        << "TopKOP failed as could not launch "
           "cub::DeviceSegmentedRadixSort::SortPairsDescending to sort input, "
           "temp_storage_bytes: "
        << temp_storage_bytes << ", status: " << cudaGetErrorString(err);
    return false;
  }
  auto& dev = *ctx.eigen_device();
  if (k < num_cols) {
    // copy sliced data to output.
    const Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{0, 0};
    const Eigen::DSizes<Eigen::DenseIndex, 2> slice_sizes{num_rows, k};

    // Both outputs are viewed as num_rows x k so that the left-hand side of
    // each assignment has the same shape as the slice it receives. (Viewing
    // the indices with the num_rows x num_cols input shape would mismatch the
    // slice shape.)
    std::vector<int> odims = {static_cast<int>(num_rows), static_cast<int>(k)};
    auto out_dim = framework::make_ddim(odims);
    auto e_indices = EigenMatrix<int64_t>::From(*indices_tensor, out_dim);
    auto e_tmp_indices = EigenMatrix<int64_t>::From(temp_indices);
    auto e_values = EigenMatrix<T>::From(*out_tensor, out_dim);
    auto e_tmp_values = EigenMatrix<T>::From(temp_values);

    e_indices.device(dev) = e_tmp_indices.slice(slice_indices, slice_sizes);
    e_values.device(dev) = e_tmp_values.slice(slice_indices, slice_sizes);
  }
  return true;
}

// Expands to one `case (dim):` label for a switch over the block dimension.
// Inside the case, `kBlockDim` is a constexpr equal to `dim`, so the
// statement(s) passed as the variadic arguments can use it as a template
// argument.
#define FIXED_BLOCK_DIM_BASE(dim, ...) \
  case (dim): {                        \
    constexpr auto kBlockDim = (dim);  \
    __VA_ARGS__;                       \
  } break

// Emits the cases for block dimensions 256, 128, 64 and 32 (the values
// GetDesiredBlockDim can return), each running the same statement(s).
#define FIXED_BLOCK_DIM(...)                \
  FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)

/**
 * CUDA kernel of the top_k operator.
 *
 * Reads input "X", attribute "k" (overridden by the optional input tensor "K"
 * when present) and writes the k largest values of every row of X to "Out"
 * and their column indices to "Indices". Rows are taken along the last
 * dimension of X; all leading dimensions are flattened.
 */
template <typename T>
class TopkOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "It must use CUDAPlace.");
    auto* input = ctx.Input<Tensor>("X");
    auto* output = ctx.Output<Tensor>("Out");
    auto* indices = ctx.Output<Tensor>("Indices");
    size_t k = static_cast<int>(ctx.Attr<int>("k"));

    // A runtime "K" tensor takes precedence over the attribute; the outputs
    // are resized so that their last dimension matches it.
    auto* k_t = ctx.Input<Tensor>("K");
    if (k_t) {
      Tensor k_host;
      framework::TensorCopySync(*k_t, platform::CPUPlace(), &k_host);
      k = k_host.data<int>()[0];
      framework::DDim output_dims = output->dims();
      output_dims[output_dims.size() - 1] = k;
      output->Resize(output_dims);
      indices->Resize(output_dims);
    }

    const T* input_data = input->data<T>();
    T* output_data = output->mutable_data<T>(ctx.GetPlace());
    // FIXME(typhoonzero): data is always converted to type T?

    framework::DDim inputdims = input->dims();
    const size_t input_height = framework::product(
        framework::slice_ddim(inputdims, 0, inputdims.size() - 1));
    const size_t input_width = inputdims[inputdims.size() - 1];
    const auto& dev_ctx = ctx.cuda_device_context();

    // Clamp before choosing an implementation so that both paths see the same
    // k. SortTopk writes nothing to the outputs when k exceeds the row width,
    // which would leave them uninitialized.
    if (k > input_width) k = input_width;

    // Prefer the cub-based full sort for short rows, large k, or a full sort.
    if ((input_width <= 1024 || k >= 128 || k == input_width)) {
      if (SortTopk<T>(dev_ctx, input, input_width, input_height, k, output,
                      indices)) {
        // Successed, return.
        return;
      } else {
        LOG(INFO) << "TopKOP: Some errors happened when use cub sorting, use "
                     "default topk kernel.";
      }
    }
    int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());

    // NOTE: pass lds and dim same to input width.
    // NOTE: old matrix implementation of stride is different to eigen.
    // TODO(typhoonzero): refine this kernel.
    const int kMaxHeight = 2048;
    int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
    // The block size is a template argument of the kernel, so dispatch on the
    // value chosen for this row width.
    switch (GetDesiredBlockDim(input_width)) {
      FIXED_BLOCK_DIM(
          KeMatrixTopK<T, 5,
                       kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
              output_data, k, indices_data, input_data, input_width,
              input_width, static_cast<int>(k), gridx, input_height));
      default:
        PADDLE_THROW("Error");
    }
  }
};

// The dispatch macros are only used inside TopkOpCUDAKernel::Compute; drop
// them so they do not leak past this point.
#undef FIXED_BLOCK_DIM_BASE
#undef FIXED_BLOCK_DIM

}  // namespace operators
}  // namespace paddle

// Register the CUDA implementation of the top_k operator for float, double
// and float16 element types.
REGISTER_OP_CUDA_KERNEL(
    top_k, paddle::operators::TopkOpCUDAKernel<float>,
    paddle::operators::TopkOpCUDAKernel<double>,
    paddle::operators::TopkOpCUDAKernel<paddle::platform::float16>);