argsort_op.cu 8.5 KB
Newer Older
Y
Yibing Liu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <thrust/execution_policy.h>
#include <thrust/sort.h>
17
#include "cub/cub.cuh"
Y
Yibing Liu 已提交
18 19
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/argsort_op.h"
20
#include "paddle/fluid/operators/transpose_op.h"
Y
Yibing Liu 已提交
21 22 23
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"

24 25 26 27 28 29 30 31
// Set CUB base traits so that CUB (e.g. its radix sort) can handle float16.
// The specialization classifies paddle::platform::float16 as a primitive
// FLOATING_POINT type whose raw bits are viewed as a uint16_t.
namespace cub {
template <>
struct NumericTraits<paddle::platform::float16>
    : BaseTraits<FLOATING_POINT, true, false, uint16_t,
                 paddle::platform::float16> {};
}  // namespace cub

Y
Yibing Liu 已提交
32 33 34 35
namespace paddle {
namespace operators {

// Short alias for the framework tensor type used throughout this file.
using Tensor = framework::Tensor;
Y
Yibing Liu 已提交
36

37 38 39 40
// Functor mapping a segment (row) number to the offset of that row's first
// element in a row-major buffer with `num_cols` columns:
//   offset(idx) = idx * num_cols.
// It is wrapped in a cub::TransformInputIterator over a
// cub::CountingInputIterator to produce the begin/end segment offsets passed
// to cub::DeviceSegmentedRadixSort.
//
// Values are kept in int64_t because the iterators built on top of this
// functor in this file use IndType = int64_t. Storing/computing in `int`
// silently narrowed both the column count and the row number.
struct SegmentOffsetIter {
  EIGEN_DEVICE_FUNC
  explicit SegmentOffsetIter(int64_t num_cols) : num_cols_(num_cols) {}

  // Offset of the first element of row `idx`.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int64_t operator()(int64_t idx) const {
    return idx * num_cols_;
  }

  int64_t num_cols_;
};
Y
Yibing Liu 已提交
48 49

// Fills a row-major [num_rows, num_cols] buffer with each element's column
// number, i.e. indices[row * num_cols + col] = col.
// Launch layout (1-D grid and 1-D blocks): blocks step over rows by
// gridDim.x, and the threads of a block step over the columns of the current
// row by blockDim.x, so any grid/block size covers the whole buffer.
template <typename T>
static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
  for (T row = blockIdx.x; row < num_rows; row += gridDim.x) {
    T* row_ptr = indices + row * num_cols;
    for (T col = threadIdx.x; col < num_cols; col += blockDim.x) {
      row_ptr[col] = col;
    }
  }
}

61 62
// Sorts every row of `input` independently, viewing the data as a row-major
// [num_rows, num_cols] matrix where each row is one segment for
// cub::DeviceSegmentedRadixSort. The sorted values are written to `output`.
// For every sorted value, its original column index within its row is written
// to `indices`. Both outputs are allocated here on ctx's place; their dims are
// presumably already set by the caller (confirm).
//
// descending: true sorts each row in descending order, false in ascending
// order.
//
// NOTE(review): CUB's segmented radix sort takes `int` item/segment counts,
// so num_rows * num_cols is presumably limited to INT_MAX. Confirm against the
// bundled CUB version.
template <typename T, typename IndType>
void ArgFullSort(const platform::CUDADeviceContext& ctx, const Tensor* input,
                 Tensor* output, Tensor* indices, const IndType num_rows,
                 const IndType num_cols, const bool descending) {
  auto cu_stream = ctx.stream();

  T* sorted_out_ptr = output->mutable_data<T>(ctx.GetPlace());
  IndType* sorted_indices_ptr = indices->mutable_data<IndType>(ctx.GetPlace());

  // Nothing to sort. Returning here also avoids launching FillIndex with a
  // zero-sized grid, which is an invalid launch configuration.
  if (num_rows == 0 || num_cols == 0) {
    return;
  }

  const T* inp = input->data<T>();

  // Column indices 0..num_cols-1 for every row. These are the values carried
  // along with the keys by the radix sort.
  Tensor input_indices;
  const std::vector<IndType> dims = {num_rows, num_cols};
  input_indices.Resize(framework::make_ddim(dims));
  IndType* input_indices_ptr =
      input_indices.mutable_data<IndType>(ctx.GetPlace());

  // Smallest power of two >= num_cols, clamped to [64, 1024] threads.
  auto ComputeBlockSize = [](IndType col) {
    if (col > 512) return 1024;
    if (col > 256) return 512;
    if (col > 128) return 256;
    if (col > 64) return 128;
    return 64;
  };
  const int block_size = ComputeBlockSize(num_cols);

  // One block per row, capped at the device's maximum grid x-dimension.
  // FillIndex loops over any rows beyond the grid size.
  const int max_grid_dim_x = ctx.GetCUDAMaxGridDimSize().x;
  const int grid_size = num_rows < max_grid_dim_x
                            ? static_cast<int>(num_rows)
                            : max_grid_dim_x;
  FillIndex<<<grid_size, block_size, 0, cu_stream>>>(input_indices_ptr,
                                                     num_rows, num_cols);
  // Kernel launches do not return an error code; query it explicitly so a bad
  // launch is not misattributed to the cub calls below.
  cudaError_t err = cudaGetLastError();
  PADDLE_ENFORCE_CUDA_SUCCESS(
      err, "ArgSortOP failed as could not launch FillIndex, status:%s.",
      cudaGetErrorString(err));

  // Row r begins at offset r * num_cols. Its end offset is the begin offset
  // of row r + 1, hence `segment_offsets_t + 1` below.
  cub::CountingInputIterator<IndType> counting_iter(0);
  cub::TransformInputIterator<IndType, SegmentOffsetIter,
                              cub::CountingInputIterator<IndType>>
      segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));

  const char* sort_name =
      descending ? "cub::DeviceSegmentedRadixSort::SortPairsDescending"
                 : "cub::DeviceSegmentedRadixSort::SortPairs";

  // With a null temp-storage pointer the cub call only computes the required
  // temp_bytes. With real storage it performs the sort.
  auto segmented_sort = [&](void* d_temp_storage,
                            size_t& temp_bytes) -> cudaError_t {
    if (descending) {
      return cub::DeviceSegmentedRadixSort::SortPairsDescending(
          d_temp_storage, temp_bytes, inp, sorted_out_ptr, input_indices_ptr,
          sorted_indices_ptr, num_cols * num_rows, num_rows,
          segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
          cu_stream);
    }
    return cub::DeviceSegmentedRadixSort::SortPairs(
        d_temp_storage, temp_bytes, inp, sorted_out_ptr, input_indices_ptr,
        sorted_indices_ptr, num_cols * num_rows, num_rows, segment_offsets_t,
        segment_offsets_t + 1, 0, sizeof(T) * 8, cu_stream);
  };

  size_t temp_storage_bytes = 0;
  err = segmented_sort(nullptr, temp_storage_bytes);
  PADDLE_ENFORCE_CUDA_SUCCESS(
      err,
      "ArgSortOP failed as could not launch %s to calculate "
      "temp_storage_bytes, status:%s.",
      sort_name, cudaGetErrorString(err));

  Tensor temp_storage;
  void* temp_storage_ptr =
      temp_storage.mutable_data<uint8_t>(ctx.GetPlace(), temp_storage_bytes);

  err = segmented_sort(temp_storage_ptr, temp_storage_bytes);
  PADDLE_ENFORCE_CUDA_SUCCESS(
      err,
      "ArgSortOP failed as could not launch %s to sort input, "
      "temp_storage_bytes:%d status:%s.",
      sort_name, temp_storage_bytes, cudaGetErrorString(err));
}

// CUDA kernel of the argsort operator.
// Reads input "X" and attributes "axis" (negative values count from the last
// dimension) and "descending". It writes the values sorted along `axis` to
// "Out" and, for each of them, its int64 position along `axis` in X to
// "Indices".
template <typename T>
class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<Tensor>("X");
    auto* output = ctx.Output<Tensor>("Out");
    auto* indices = ctx.Output<Tensor>("Indices");
    int axis = ctx.Attr<int>("axis");
    bool descending = ctx.Attr<bool>("descending");

    auto in_dims = input->dims();
    axis = (axis < 0) ? (in_dims.size() + axis) : axis;

    // Empty input: only materialize the (empty) outputs; there is nothing to
    // sort and no kernel should be launched.
    if (input->numel() == 0) {
      output->mutable_data<T>(ctx.GetPlace());
      indices->mutable_data<int64_t>(ctx.GetPlace());
      return;
    }

    const auto& dev_ctx = ctx.cuda_device_context();

    // Special case: sorting along the last axis needs no transpose
    // (the original authors report a ~190x speedup).
    if (axis == -1 || axis + 1 == in_dims.size()) {
      const int64_t input_height = framework::product(
          framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
      const int64_t input_width = in_dims[in_dims.size() - 1];
      ArgFullSort<T, int64_t>(dev_ctx, input, output, indices, input_height,
                              input_width, descending);
      return;
    }

    // General case: swap `axis` with the last dimension, sort along the last
    // dimension, then swap back. `trans` only exchanges two dimensions, so it
    // is its own inverse and is reused for the transposes back.
    const int ndims = in_dims.size();
    std::vector<int> trans;
    for (int i = 0; i < axis; i++) {
      trans.push_back(i);
    }
    trans.push_back(ndims - 1);
    for (int i = axis + 1; i < ndims - 1; i++) {
      trans.push_back(i);
    }
    trans.push_back(axis);

    framework::DDim trans_dims(in_dims);
    for (int i = 0; i < ndims; i++) {
      trans_dims[i] = in_dims[trans[i]];
    }

    Tensor trans_inp;
    trans_inp.mutable_data<T>(trans_dims, ctx.GetPlace());
    TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, *input,
                                                 &trans_inp, trans);

    const int64_t input_height = framework::product(
        framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
    const int64_t input_width = trans_dims[trans_dims.size() - 1];

    // Sorted values/indices in the transposed layout.
    Tensor tmp_out;
    tmp_out.mutable_data<T>(trans_dims, ctx.GetPlace());
    output->mutable_data<T>(ctx.GetPlace());

    Tensor tmp_indices;
    tmp_indices.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
    indices->mutable_data<int64_t>(ctx.GetPlace());

    ArgFullSort<T, int64_t>(dev_ctx, &trans_inp, &tmp_out, &tmp_indices,
                            input_height, input_width, descending);

    // Transpose indices and values back into the operator outputs.
    TransCompute<platform::CUDADeviceContext, int64_t>(
        ndims, dev_ctx, tmp_indices, indices, trans);
    TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, tmp_out,
                                                 output, trans);
  }
};

}  // namespace operators
}  // namespace paddle

240 241 242 243
// Register the CUDA argsort kernel for float, double and float16 inputs.
REGISTER_OP_CUDA_KERNEL(
    argsort, paddle::operators::ArgsortOpCUDAKernel<float>,
    paddle::operators::ArgsortOpCUDAKernel<double>,
    paddle::operators::ArgsortOpCUDAKernel<paddle::platform::float16>);