argsort_kernel.cu 10.7 KB
Newer Older
L
Linjie Chen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16
#include "paddle/phi/kernels/argsort_kernel.h"

L
Linjie Chen 已提交
17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
#include <thrust/copy.h>
#include <thrust/execution_policy.h>
#include <thrust/reverse.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
33
#include "paddle/phi/kernels/funcs/math_function.h"
L
Linjie Chen 已提交
34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
#include "paddle/phi/kernels/primitive/functor_primitives.h"
#include "paddle/phi/kernels/transpose_kernel.h"

// phi::dtype::float16 is not a type the device radix-sort libraries know how
// to order, so specialize their key traits for it before
// DeviceSegmentedRadixSort is instantiated with T = float16 further down.
#ifdef __HIPCC__
// ROCm: give rocprim a radix key codec for float16 that operates on its
// 16-bit pattern through uint16_t.
// NOTE(review): this derives from rocprim's *integral* codec rather than a
// floating-point one; confirm that negative float16 values (sign bit set) are
// ordered correctly by it.
namespace rocprim {
namespace detail {
template <>
struct radix_key_codec_base<phi::dtype::float16>
    : radix_key_codec_integral<phi::dtype::float16, uint16_t> {};
}  // namespace detail
}  // namespace rocprim
#else
// CUDA: set cub base traits in order to handle float16 -- declare it a
// primitive FLOATING_POINT type whose bits are reinterpreted as uint16_t for
// radix sorting.
namespace cub {
template <>
struct NumericTraits<phi::dtype::float16>
    : BaseTraits<FLOATING_POINT, true, false, uint16_t, phi::dtype::float16> {};
}  // namespace cub
#endif

namespace phi {

// Functor mapping a segment (row) index to the flat offset of that row's
// first element in a row-major [num_rows, num_cols] buffer: idx * num_cols.
// Wrapped in a cub::TransformInputIterator over a CountingInputIterator it
// yields the begin offset of each row; the same iterator advanced by one
// yields the end offsets.
//
// All arithmetic is done in int64_t: the sort is instantiated with int64_t
// indices, and an int-only functor would narrow both the column count and the
// segment index and compute the product in 32 bits.
struct SegmentOffsetIter {
  EIGEN_DEVICE_FUNC
  explicit SegmentOffsetIter(int64_t num_cols) : num_cols_(num_cols) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int64_t operator()(int64_t idx) const {
    return idx * num_cols_;
  }

  int64_t num_cols_;
};

template <typename T>
// Fills a row-major [num_rows, num_cols] buffer with each element's column
// position, i.e. indices[r * num_cols + c] = c.
// Blocks walk the rows (start blockIdx.x, stride gridDim.x) and the threads of
// a block walk the columns (start threadIdx.x, stride blockDim.x), so any
// 1-D grid / 1-D block launch covers the whole buffer.
static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
  const T first_col = static_cast<T>(threadIdx.x);
  const T col_step = static_cast<T>(blockDim.x);

  for (T row = static_cast<T>(blockIdx.x); row < num_rows; row += gridDim.x) {
    T* row_ptr = indices + row * num_cols;
    for (T col = first_col; col < num_cols; col += col_step) {
      row_ptr[col] = col;
    }
  }
}

// Sorts every row of a row-major [num_rows, num_cols] view of `input`
// independently, using cub::DeviceSegmentedRadixSort (one segment per row).
// The sorted values go to `output`. `indices` receives, for each output
// position, the column the value came from.
//
// `descending`: true sorts each row in descending order, false in ascending
// order.
//
// All work is enqueued on ctx.stream().
template <typename T, typename IndType>
void ArgFullSort(const phi::GPUContext& ctx,
                 const DenseTensor* input,
                 DenseTensor* output,
                 DenseTensor* indices,
                 const IndType num_rows,
                 const IndType num_cols,
                 const bool descending) {
  auto cu_stream = ctx.stream();

  const T* inp = input->data<T>();
  T* sorted_out_ptr = ctx.template Alloc<T>(output);
  IndType* sorted_indices_ptr = ctx.template Alloc<IndType>(indices);

  // Nothing to sort. Bail out after allocating the outputs: launching
  // FillIndex with a zero-sized grid is an invalid launch configuration.
  if (num_rows == 0 || num_cols == 0) {
    return;
  }

  // Per-row column positions 0..num_cols-1. These are the "values" that the
  // radix sort permutes together with the keys.
  DenseTensor input_indices;
  const std::vector<IndType> dims = {num_rows, num_cols};
  input_indices.Resize(phi::make_ddim(dims));
  IndType* input_indices_ptr = ctx.template Alloc<IndType>(&input_indices);

  // Smallest power of two >= col, clamped to [64, 1024]. FillIndex strides
  // over any columns beyond the block size.
  auto ComputeBlockSize = [](IndType col) {
    if (col > 512) return 1024;
    if (col > 256) return 512;
    if (col > 128) return 256;
    if (col > 64) return 128;
    return 64;
  };

  int block_size = ComputeBlockSize(num_cols);
  int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
  // FillIndex strides over rows, so capping the grid at the hardware limit is
  // safe.
  int grid_size = num_rows < maxGridDimX ? num_rows : maxGridDimX;
  FillIndex<<<grid_size, block_size, 0, cu_stream>>>(
      input_indices_ptr, num_rows, num_cols);

  // Segment i spans [i * num_cols, (i + 1) * num_cols) of the flattened input.
  cub::CountingInputIterator<IndType> counting_iter(0);
  cub::TransformInputIterator<IndType,
                              SegmentOffsetIter,
                              cub::CountingInputIterator<IndType>>
      segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));

  const IndType num_items = num_cols * num_rows;
  size_t temp_storage_bytes = 0;

  // CUB is invoked twice with identical arguments. The first call passes a
  // null temp-storage pointer and only writes the required size into
  // temp_storage_bytes. The second call performs the sort.
  auto segmented_sort = [&](void* d_temp_storage) -> gpuError_t {
    if (descending) {
      return cub::DeviceSegmentedRadixSort::SortPairsDescending(
          d_temp_storage,
          temp_storage_bytes,
          inp,
          sorted_out_ptr,
          input_indices_ptr,
          sorted_indices_ptr,
          num_items,
          num_rows,
          segment_offsets_t,
          segment_offsets_t + 1,
          0,
          sizeof(T) * 8,
          cu_stream);
    }
    return cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage,
                                                    temp_storage_bytes,
                                                    inp,
                                                    sorted_out_ptr,
                                                    input_indices_ptr,
                                                    sorted_indices_ptr,
                                                    num_items,
                                                    num_rows,
                                                    segment_offsets_t,
                                                    segment_offsets_t + 1,
                                                    0,
                                                    sizeof(T) * 8,
                                                    cu_stream);
  };

  gpuError_t err = segmented_sort(nullptr);
  PADDLE_ENFORCE_GPU_SUCCESS(err);

  DenseTensor temp_storage;
  int64_t temp_size = temp_storage_bytes;
  temp_storage.Resize({temp_size});
  uint8_t* temp_storage_ptr = ctx.template Alloc<uint8_t>(&temp_storage);

  err = segmented_sort(temp_storage_ptr);
  PADDLE_ENFORCE_GPU_SUCCESS(err);
}

// GPU argsort along `axis` (negative values count from the end). Writes the
// sorted values to `output`. For each output position, `indices` (int64)
// receives the index along `axis` of the element that landed there.
// `descending` selects descending order.
//
// Paths:
//  * 0-D input: copy the scalar through and report index 0.
//  * Empty input: allocate the outputs and return.
//  * All dimensions other than `axis` have extent 1: a single thrust
//    sort_by_key on the context's stream.
//  * `axis` is the last dimension: segmented radix sort over the rows.
//  * Otherwise: transpose `axis` to the last position, sort, and transpose
//    back.
// NOTE(review): thrust::sort_by_key is not guaranteed to be stable, and the
// descending thrust path reverses the ascending result, so the relative order
// of equal elements is not specified.
template <typename T, typename Context>
void ArgsortKernel(const Context& dev_ctx,
                   const DenseTensor& input,
                   int axis,
                   bool descending,
                   DenseTensor* output,
                   DenseTensor* indices) {
  auto in_dims = input.dims();
  auto rank = in_dims.size();
  axis = (axis < 0) ? (in_dims.size() + axis) : axis;
  const T* in_data = input.data<T>();
  auto size = input.numel();
  T* out_data = dev_ctx.template Alloc<T>(output);
  int64_t* ids_data = dev_ctx.template Alloc<int64_t>(indices);

  // A 0-D tensor is trivially sorted: copy the value and report index 0.
  if (rank == 0) {
    phi::Copy<Context>(dev_ctx, input, dev_ctx.GetPlace(), false, output);
    phi::funcs::set_constant(dev_ctx, indices, 0);
    return;
  }

  // Empty input: the outputs are allocated with zero elements and there is
  // nothing to sort. Returning here also keeps ArgFullSort from launching a
  // kernel with a zero-sized grid.
  if (size == 0) {
    return;
  }

  // Use thrust when the whole tensor is one row along 'axis'. Compared with
  // the segmented-sort path below, ascending sort is 34x faster and descending
  // sort 31x faster.
  if (size == in_dims[axis]) {
    // Bind thrust to the context's stream. `thrust::device` would run on the
    // default stream, which is not ordered with work on dev_ctx.stream().
#ifdef __HIPCC__
    auto exec_policy = thrust::hip::par.on(dev_ctx.stream());
#else
    auto exec_policy = thrust::cuda::par.on(dev_ctx.stream());
#endif
    thrust::sequence(exec_policy, ids_data, ids_data + size);
    thrust::copy(exec_policy, in_data, in_data + size, out_data);
    thrust::sort_by_key(exec_policy, out_data, out_data + size, ids_data);
    if (descending) {
      thrust::reverse(exec_policy, out_data, out_data + size);
      thrust::reverse(exec_policy, ids_data, ids_data + size);
    }
    return;
  }

  // Special case for full sort (axis is the last dimension), speedup ~190x:
  // sort each row of the [prod(dims[:-1]), dims[-1]] view in place.
  if (axis == -1 || axis + 1 == in_dims.size()) {
    const int64_t input_height =
        phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1));
    const int64_t input_width = in_dims[in_dims.size() - 1];
    ArgFullSort<T, int64_t>(dev_ctx,
                            &input,
                            output,
                            indices,
                            input_height,
                            input_width,
                            descending);
    return;
  }

  // Otherwise swap `axis` with the last dimension. `trans` is a single
  // transposition, so it is its own inverse and is reused to transpose back.
  std::vector<int> trans;
  for (int i = 0; i < axis; i++) {
    trans.push_back(i);
  }
  trans.push_back(in_dims.size() - 1);
  for (int i = axis + 1; i < in_dims.size() - 1; i++) {
    trans.push_back(i);
  }
  trans.push_back(axis);
  phi::DDim trans_dims(in_dims);
  for (int i = 0; i < static_cast<int>(trans.size()); i++) {
    trans_dims[i] = in_dims[trans[i]];
  }

  DenseTensor trans_inp;
  trans_inp.Resize(trans_dims);
  dev_ctx.template Alloc<T>(&trans_inp);
  TransposeKernel<T, Context>(dev_ctx, input, trans, &trans_inp);

  const int64_t input_height =
      phi::product(phi::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
  const int64_t input_width = trans_dims[trans_dims.size() - 1];

  // Sorted values and indices in the transposed layout.
  DenseTensor tmp_out;
  tmp_out.Resize(trans_dims);
  dev_ctx.template Alloc<T>(&tmp_out);

  DenseTensor tmp_indices;
  tmp_indices.Resize(trans_dims);
  dev_ctx.template Alloc<int64_t>(&tmp_indices);

  ArgFullSort<T, int64_t>(dev_ctx,
                          &trans_inp,
                          &tmp_out,
                          &tmp_indices,
                          input_height,
                          input_width,
                          descending);

  // Transpose back into the caller's layout.
  TransposeKernel<int64_t, Context>(dev_ctx, tmp_indices, trans, indices);
  TransposeKernel<T, Context>(dev_ctx, tmp_out, trans, output);
}

}  // namespace phi

// Registers phi::ArgsortKernel as the GPU "argsort" kernel for float, double,
// int, int64_t and float16 inputs. Output 1 (`indices`) is always INT64,
// independent of the input dtype, matching the int64_t buffer the kernel
// writes.
PD_REGISTER_KERNEL(argsort,
                   GPU,
                   ALL_LAYOUT,
                   phi::ArgsortKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   phi::dtype::float16) {
  kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}