/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/top_k_op.h"
#include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

// Local shorthand for framework::Tensor, used by the kernel class below.
using Tensor = framework::Tensor;

/**
 * A (value, index) pair; the element type of the per-thread candidate lists
 * and of the shared-memory reduction buffer.
 *
 * Pair-to-pair ordering compares the value first. On equal values the pair
 * with the smaller index is the greater one.
 */
template <typename T>
struct Pair {
  // Leaves both members uninitialized.
  __device__ __forceinline__ Pair() {}
  __device__ __forceinline__ Pair(T value, int64_t id) : v(value), id(id) {}

  // Overwrites both members. The explicit `this->` is required: the parameter
  // `id` shadows the member, so a bare `id = id` only assigns the parameter
  // to itself and leaves the stored index untouched.
  __device__ __forceinline__ void set(T value, int64_t id) {
    this->v = value;
    this->id = id;
  }

  // Member-wise copy. Returns void, so assignments cannot be chained.
  __device__ __forceinline__ void operator=(const Pair<T>& in) {
    v = in.v;
    id = in.id;
  }

  // Compares the stored value against a bare value; the index is ignored.
  __device__ __forceinline__ bool operator<(const T value) const {
    return (v < value);
  }

  // Smaller value, or equal value with a larger index.
  __device__ __forceinline__ bool operator<(const Pair<T>& in) const {
    return (v < in.v) || ((v == in.v) && (id > in.id));
  }

  // Larger value, or equal value with a smaller index.
  __device__ __forceinline__ bool operator>(const Pair<T>& in) const {
    return (v > in.v) || ((v == in.v) && (id < in.id));
  }

  T v;         // the value
  int64_t id;  // the index associated with the value
};

/**
 * Inserts `p` into the first `beam_size` slots of `topk`.
 *
 * Walking from slot `beam_size - 2` towards slot 0, every entry that compares
 * less than `p` is copied one slot towards the tail (the previous content of
 * the last slot is overwritten). `p` is stored right behind the first entry
 * that is not less than it, or in slot 0 when no such entry exists.
 */
template <typename T>
__device__ __forceinline__ void AddTo(Pair<T> topk[], const Pair<T>& p,
                                      int beam_size) {
  int slot = 0;  // destination when every inspected entry is less than p
  for (int k = beam_size - 2; k >= 0; --k) {
    if (!(topk[k] < p)) {
      slot = k + 1;
      break;
    }
    topk[k + 1] = topk[k];
  }
  topk[slot] = p;
}

/**
 * Same insertion as the run-time overload, with the list length supplied as
 * the template argument `beam_size`.
 *
 * Entries less than `p` are moved one slot towards the tail, then `p` is
 * written behind the first entry that is not less than it (slot 0 if none).
 */
template <typename T, int beam_size>
__device__ __forceinline__ void AddTo(Pair<T> topk[], const Pair<T>& p) {
  int slot = 0;  // destination when every inspected entry is less than p
  for (int k = beam_size - 2; k >= 0; --k) {
    if (!(topk[k] < p)) {
      slot = k + 1;
      break;
    }
    topk[k + 1] = topk[k];
  }
  topk[slot] = p;
}

/**
 * Scans src[idx], src[idx + BlockSize], ... (positions below `dim`) and
 * inserts every value that is greater than the value in the last of the
 * `beam_size` slots of `topk`. The position is stored as the pair's index.
 */
template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[], const T* src, int idx,
                                        int dim, int beam_size) {
  const int last = beam_size - 1;
  for (int pos = idx; pos < dim; pos += BlockSize) {
    if (!(topk[last] < src[pos])) continue;
    Pair<T> candidate(src[pos], pos);
    AddTo<T>(topk, candidate, beam_size);
  }
}

/**
 * Bounded variant of the scan: a value from src[idx], src[idx + BlockSize],
 * ... is inserted only if it is greater than the value in the last slot of
 * `topk` and its (value, position) pair compares less than `max`.
 */
template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[], const T* src, int idx,
                                        int dim, const Pair<T>& max,
                                        int beam_size) {
  const int last = beam_size - 1;
  for (int pos = idx; pos < dim; pos += BlockSize) {
    if (!(topk[last] < src[pos])) continue;
    Pair<T> candidate(src[pos], pos);
    if (candidate < max) {
      AddTo<T>(topk, candidate, beam_size);
    }
  }
}

/**
 * Scan over parallel arrays: for positions idx, idx + BlockSize, ... below
 * `dim`, inserts (val[pos], col[pos]) whenever val[pos] is greater than the
 * value in the last of the `beam_size` slots of `topk`.
 */
template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[], const T* val, int* col,
                                        int idx, int dim, int beam_size) {
  const int last = beam_size - 1;
  for (int pos = idx; pos < dim; pos += BlockSize) {
    if (!(topk[last] < val[pos])) continue;
    Pair<T> candidate(val[pos], col[pos]);
    AddTo<T>(topk, candidate, beam_size);
  }
}

/**
 * Bounded scan over parallel arrays: (val[pos], col[pos]) is inserted only if
 * val[pos] is greater than the value in the last slot of `topk` and the pair
 * compares less than `max`. Positions are idx, idx + BlockSize, ... below
 * `dim`.
 */
template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[], const T* val, int* col,
                                        int idx, int dim, const Pair<T>& max,
                                        int beam_size) {
  const int last = beam_size - 1;
  for (int pos = idx; pos < dim; pos += BlockSize) {
    if (!(topk[last] < val[pos])) continue;
    Pair<T> candidate(val[pos], col[pos]);
    if (candidate < max) {
      AddTo<T>(topk, candidate, beam_size);
    }
  }
}

/**
 * Fills or refills the calling thread's candidate list `topk` (MaxLength
 * entries) from `src`. Does nothing unless `*beam > 0`.
 *
 * @param topk       per-thread list, updated in place
 * @param beam       in: number of leading entries to drop; out: reset to 0
 * @param beam_size  upper bound on how many entries are (re)filled
 * @param src        values to scan, starting at `tid` with step BlockSize
 * @param firstStep  in: true on the first call for this list; out: false
 * @param is_empty   set to true once the last slot holds a placeholder
 * @param max        in: bound for the refill scan (only pairs less than it are
 *                   taken); out: copy of the last entry of `topk`
 * @param dim        number of elements in `src`
 * @param tid        first position this thread scans
 */
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int* beam,
                                              int beam_size, const T* src,
                                              bool* firstStep, bool* is_empty,
                                              Pair<T>* max, int dim,
                                              const int tid) {
  if (*beam > 0) {
    int length = (*beam) < beam_size ? *beam : beam_size;
    if (*firstStep) {
      *firstStep = false;
      GetTopK<T, BlockSize>(topk, src, tid, dim, length);
    } else {
      // Drop the first *beam entries by shifting the rest to the front, and
      // mark the vacated tail slots with the placeholder (-INFINITY, -1).
      for (int k = 0; k < MaxLength; k++) {
        if (k < MaxLength - (*beam)) {
          topk[k] = topk[k + *beam];
        } else {
          topk[k].set(-static_cast<T>(INFINITY), -1);
        }
      }
      // Refill the tail with pairs that compare less than the previous bound.
      if (!(*is_empty)) {
        GetTopK<T, BlockSize>(topk + MaxLength - *beam, src, tid, dim, *max,
                              length);
      }
    }

    *max = topk[MaxLength - 1];
    // The placeholder written above is identified by its id of -1. Testing
    // the value against -1 would misfire on real data equal to -1 and would
    // never match the -INFINITY placeholder value.
    if ((*max).id == -1) *is_empty = true;
    *beam = 0;
  }
}

/**
 * Variant of ThreadGetTopK that scans parallel arrays: values come from `val`
 * and the stored indices from `col`. Does nothing unless `*beam > 0`.
 *
 * @param topk       per-thread list (MaxLength entries), updated in place
 * @param beam       in: number of leading entries to drop; out: reset to 0
 * @param beam_size  upper bound on how many entries are (re)filled
 * @param val        values to scan, starting at `tid` with step BlockSize
 * @param col        index stored for each position of `val`
 * @param firstStep  in: true on the first call for this list; out: false
 * @param is_empty   set to true once the last slot holds a placeholder
 * @param max        in: bound for the refill scan (only pairs less than it are
 *                   taken); out: copy of the last entry of `topk`
 * @param dim        number of elements in `val` / `col`
 * @param tid        first position this thread scans
 */
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int* beam,
                                              int beam_size, const T* val,
                                              int* col, bool* firstStep,
                                              bool* is_empty, Pair<T>* max,
                                              int dim, const int tid) {
  if (*beam > 0) {
    int length = (*beam) < beam_size ? *beam : beam_size;
    if (*firstStep) {
      *firstStep = false;
      GetTopK<T, BlockSize>(topk, val, col, tid, dim, length);
    } else {
      // Drop the first *beam entries by shifting the rest to the front, and
      // mark the vacated tail slots with the placeholder (-INFINITY, -1).
      for (int k = 0; k < MaxLength; k++) {
        if (k < MaxLength - *beam) {
          topk[k] = topk[k + *beam];
        } else {
          topk[k].set(-static_cast<T>(INFINITY), -1);
        }
      }
      if (!(*is_empty)) {
        // GetTopK takes the bound by reference, so the pointer must be
        // dereferenced; passing `max` itself matches no overload.
        GetTopK<T, BlockSize>(topk + MaxLength - *beam, val, col, tid, dim,
                              *max, length);
      }
    }

    *max = topk[MaxLength - 1];
    // The placeholder written above is identified by its id of -1. Testing
    // the value against -1 would misfire on real data equal to -1 and would
    // never match the -INFINITY placeholder value.
    if ((*max).id == -1) *is_empty = true;
    *beam = 0;
  }
}

/**
 * Repeatedly selects the largest pair in `sh_topk` (one entry per thread) and
 * appends its value and index to the output cursors `*topVal` / `*topIds`.
 *
 * Each round: `maxid` is reduced to the position of the largest entry; thread
 * 0 writes that entry out and advances both cursors; the owning thread bumps
 * its `*beam` and, if its local list `topk` still has an entry at that
 * position, publishes it to `sh_topk`. The loop ends when `*k` reaches zero.
 * The warp that owns the winner also leaves the loop once the winner's
 * `*beam` equals MaxLength.
 *
 * All threads of the block must call this together: it contains
 * __syncthreads(). `maxid` must hold at least BlockSize / 2 ints.
 */
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
                                            Pair<T> topk[], T** topVal,
                                            int64_t** topIds, int* beam, int* k,
                                            const int tid, const int warp) {
  constexpr int kHalf = BlockSize / 2;
  for (;;) {
    __syncthreads();
    // First level: compare each entry of the lower half with its partner in
    // the upper half and remember which position holds the larger pair.
    if (tid < kHalf) {
      const int partner = tid + kHalf;
      maxid[tid] = (sh_topk[tid] < sh_topk[partner]) ? partner : tid;
    }
    __syncthreads();
    // Remaining levels operate on the recorded positions.
    for (int stride = BlockSize / 4; stride > 0; stride >>= 1) {
      if (tid < stride && sh_topk[maxid[tid]] < sh_topk[maxid[tid + stride]]) {
        maxid[tid] = maxid[tid + stride];
      }
      __syncthreads();
    }
    __syncthreads();

    // maxid is not written again before the barrier at the top of the next
    // round, so the winning position can be read once.
    const int winner = maxid[0];

    if (tid == 0) {
      **topVal = sh_topk[winner].v;
      **topIds = sh_topk[winner].id;
      ++(*topVal);
      ++(*topIds);
    }
    if (tid == winner) ++(*beam);
    if (--(*k) == 0) break;
    __syncthreads();

    // The winning thread publishes its next local candidate, if it has one.
    if (tid == winner && *beam < MaxLength) {
      sh_topk[tid] = topk[*beam];
    }
    // NOTE(zcd): temporary solution
    unsigned shfl_mask = 0u;
    CREATE_SHFL_MASK(shfl_mask, true);

    // Only the winner's warp takes part: the winner lane's *beam is fetched
    // through platform::CudaShuffleSync (presumably a warp shuffle wrapper;
    // see cuda_device_function.h) and that warp exits when it is MaxLength.
    if (winner / 32 == warp &&
        platform::CudaShuffleSync(shfl_mask, *beam, winner % 32, 32) ==
            MaxLength) {
      break;
    }
  }
}

/**
 * Each block computes the top-k of one sample (row) at a time.
 * In a block:
 * 1. every thread gets its own top MaxLength values;
 * 2. the heads are merged into sh_topk; a block reduce gets the max value;
 * 3. go to the second step, until one thread's topk values are used up;
 * 4. go to the first step, until k values have been produced.
 *
 * Launch layout: one-dimensional grid and block. blockDim.x must equal
 * BlockSize, because sh_topk has BlockSize entries indexed by threadIdx.x.
 * Block b handles rows b, b + grid_dim, b + 2 * grid_dim, ... below `num`.
 *
 * @param output         row i's values are written at output + i * output_stride
 * @param output_stride  distance in elements between output rows
 * @param indices        row i's indices are written at indices + i * k
 * @param src            row i is read from src + i * lds
 * @param lds            distance in elements between input rows
 * @param dim            number of elements scanned per row
 * @param k              number of results produced per row
 * @param grid_dim       row step between iterations of one block
 * @param num            number of rows
 */
template <typename T, int MaxLength, int BlockSize>
__global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
                             const T* src, int lds, int dim, int k,
                             int grid_dim, int num) {
  __shared__ Pair<T> sh_topk[BlockSize];
  const int tid = threadIdx.x;
  const int warp = threadIdx.x / 32;

  const int bid = blockIdx.x;
  for (int i = bid; i < num; i += grid_dim) {
    int top_num = k;
    __shared__ int maxid[BlockSize / 2];
    // Row offsets are formed in 64 bits: i * lds and friends can exceed the
    // range of int for large inputs.
    const int64_t row = i;
    T* out = output + row * output_stride;
    int64_t* inds = indices + row * k;
    const T* row_src = src + row * lds;
    Pair<T> topk[MaxLength];
    int beam = MaxLength;
    Pair<T> max;
    bool is_empty = false;
    bool firststep = true;

    for (int j = 0; j < MaxLength; j++) {
      topk[j].set(-static_cast<T>(INFINITY), -1);
    }
    while (top_num) {
      ThreadGetTopK<T, MaxLength, BlockSize>(topk, &beam, k, row_src,
                                             &firststep, &is_empty, &max, dim,
                                             tid);

      sh_topk[tid] = topk[0];
      BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &out, &inds,
                                           &beam, &top_num, tid, warp);
    }
    // BlockReduce exits right after thread 0 reads the winning sh_topk entry,
    // without a barrier. Synchronize before any thread overwrites sh_topk for
    // the next row. The loop bounds and top_num are the same for every thread
    // of the block, so all threads reach this barrier.
    __syncthreads();
  }
}

// Picks the block size for a row of `dim` elements: 256 for more than 128
// elements, 128 for more than 64, 64 for more than 32, otherwise 32.
inline static int GetDesiredBlockDim(int dim) {
  if (dim > 128) return 256;
  if (dim > 64) return 128;
  return dim > 32 ? 64 : 32;
}

// Expands to one `case (dim):` of a switch statement. Inside the case,
// `kBlockDim` is a constexpr equal to `dim`, so the statement passed as the
// variadic argument can use it as a template argument. The case ends with
// `break`; the caller supplies the terminating semicolon.
#define FIXED_BLOCK_DIM_BASE(dim, ...) \
  case (dim): {                        \
    constexpr auto kBlockDim = (dim);  \
    __VA_ARGS__;                       \
  } break

// Expands to the four cases 256, 128, 64 and 32, each running the same
// statement. Must be used inside a switch; the values match the results of
// GetDesiredBlockDim. `##__VA_ARGS__` is the comma-swallowing preprocessor
// extension, so this relies on the compiler supporting it.
#define FIXED_BLOCK_DIM(...)                \
  FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)

/**
 * CUDA kernel class for the top_k operator.
 *
 * Reads input "X" and attribute "k", and writes outputs "Out" (values, type
 * T) and "Indices" (int64_t). The input is treated as a matrix whose width is
 * the last dimension and whose height is the product of all leading
 * dimensions. `k` is clamped to the width.
 */
template <typename T>
class TopkOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "It must use CUDAPlace.");
    auto* input = ctx.Input<Tensor>("X");
    auto* output = ctx.Output<Tensor>("Out");
    auto* indices = ctx.Output<Tensor>("Indices");
    size_t k = static_cast<int>(ctx.Attr<int>("k"));

    const T* input_data = input->data<T>();
    T* output_data = output->mutable_data<T>(ctx.GetPlace());
    // FIXME(typhoonzero): data is always converted to type T?
    int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());

    framework::DDim inputdims = input->dims();
    const size_t input_height = framework::product(
        framework::slice_ddim(inputdims, 0, inputdims.size() - 1));
    const size_t input_width = inputdims[inputdims.size() - 1];

    if (k > input_width) k = input_width;

    // Nothing to compute for an input without rows or for k == 0. Returning
    // here also avoids launching the kernel with a grid of zero blocks.
    if (input_height == 0 || k == 0) return;

    // NOTE: pass lds and dim same to input width.
    // NOTE: old matrix implementation of stride is different to eigen.
    // TODO(typhoonzero): refine this kernel.
    const int kMaxHeight = 2048;
    // At most kMaxHeight blocks; the kernel steps through the remaining rows
    // with a stride of gridx.
    const int gridx = input_height < static_cast<size_t>(kMaxHeight)
                          ? static_cast<int>(input_height)
                          : kMaxHeight;
    const int width = static_cast<int>(input_width);
    const int height = static_cast<int>(input_height);
    const int top_k = static_cast<int>(k);
    auto& dev_ctx = ctx.cuda_device_context();
    switch (GetDesiredBlockDim(input_width)) {
      // Each thread keeps a local list of 5 candidates (MaxLength).
      FIXED_BLOCK_DIM(
          KeMatrixTopK<T, 5,
                       kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
              output_data, top_k, indices_data, input_data, width, width,
              top_k, gridx, height));
      default:
        PADDLE_THROW("Error");
    }
  }
};

// The dispatch macros are only needed by TopkOpCUDAKernel::Compute above;
// remove them so they do not leak past this point.
#undef FIXED_BLOCK_DIM_BASE
#undef FIXED_BLOCK_DIM

}  // namespace operators
}  // namespace paddle

// Registers TopkOpCUDAKernel as the CUDA kernel of the `top_k` operator for
// float, double and float16 element types.
REGISTER_OP_CUDA_KERNEL(
    top_k, paddle::operators::TopkOpCUDAKernel<float>,
    paddle::operators::TopkOpCUDAKernel<double>,
    paddle::operators::TopkOpCUDAKernel<paddle::platform::float16>);