// argsort_op.cu: CUDA implementation of the argsort operator.
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <thrust/execution_policy.h>
#include <thrust/sort.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/argsort_op.h"
#include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using platform::PADDLE_CUDA_NUM_THREADS;

// The max rank of a tensor allowed in Fluid. ComputeTargetIdx keeps
// per-thread arrays of kMaxRank - 1 entries (one per non-sorted axis), so
// inputs of higher rank are not supported by this kernel.
constexpr int kMaxRank = 9;
// Maps every input element (row-major order) to its slot in the "mediate"
// layout, where each group of elements sharing the same coordinates on all
// non-sort axes is stored contiguously:
//   slot = group * dim_axis + pos_in_axis
// with `group` the row-major linearization of the non-sort coordinates and
// `pos_in_axis` the element's coordinate along `axis`. The slot is written to
// trg_idx[index], and med_ids[slot] is seeded with pos_in_axis, i.e. the
// original position that argsort reports once the group has been sorted.
//
// Launch: 1-D grid, one thread per element; threads with index >= n exit.
// Preconditions: dims_size <= kMaxRank (the local arrays hold kMaxRank - 1
// entries), 0 <= axis < dims_size, in_dims points to dims_size extents in
// device memory, and trg_idx / med_ids each hold n entries.
__global__ void ComputeTargetIdx(const int64_t* in_dims, int dims_size,
                                 int axis, int64_t n, int64_t* trg_idx,
                                 int64_t* med_ids) {
  // Widen before multiplying: blockDim.x * blockIdx.x is otherwise computed in
  // 32-bit unsigned arithmetic and wraps for very large grids.
  int64_t index = static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  if (index < n) {
    // Coordinates and extents on the non-sort axes, outermost first.
    int64_t shape_out_axis[kMaxRank - 1] = {0};
    int64_t dims_out_axis[kMaxRank - 1] = {0};
    int64_t tmp = index;
    int64_t pos_in_axis = 0;
    int64_t i = dims_size - 2;
    int64_t dim_axis = 0;
    // Peel coordinates off from the innermost dimension outward.
    for (int64_t j = dims_size - 1; j >= 0; --j) {
      int64_t dim = in_dims[j];
      if (j != axis) {
        shape_out_axis[i] = tmp % dim;
        dims_out_axis[i] = dim;
        i--;
      } else {
        dim_axis = dim;
        pos_in_axis = tmp % dim_axis;
      }
      tmp /= dim;
    }
    // Row-major linearization of the non-sort coordinates.
    int64_t group = (dims_size > 1) ? shape_out_axis[0] : 0;
    for (int64_t j = 0; j < dims_size - 2; ++j) {
      group = group * dims_out_axis[j + 1] + shape_out_axis[j + 1];
    }

    int64_t target_idx = group * dim_axis + pos_in_axis;
    trg_idx[index] = target_idx;
    med_ids[target_idx] = pos_in_axis;
  }
}

// Scatters the input into the mediate layout: med_out[trg_idx[i]] = in[i], so
// that each group's values along the sort axis become contiguous.
// Launch: 1-D grid, one thread per element; threads with index >= n exit.
template <typename T>
__global__ void PermuteInData(const T* in, const int64_t* trg_idx, int64_t n,
                              T* med_out) {
  // 64-bit index: n is int64_t, and a 32-bit int would overflow past 2^31.
  int64_t index = static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  if (index < n) {
    med_out[trg_idx[index]] = in[index];
  }
}

// Sorts each group's contiguous run of axis_dim values in med_out in
// ascending order (thrust's default comparator), permuting the matching
// med_ids entries along with the keys.
// Launch: 1-D grid, one thread per group; threads with index >= groups exit.
// NOTE(review): calling thrust::sort_by_key with thrust::device from device
// code depends on how the Thrust/CUDA version in use handles device-side
// dispatch (dynamic parallelism with relocatable device code, or a sequential
// fallback); confirm the build configuration supports it.
template <typename T>
__global__ void Sort(int64_t axis_dim, int64_t groups, T* med_out,
                     int64_t* med_ids) {
  // 64-bit index: groups is int64_t, and a 32-bit int would overflow past
  // 2^31.
  int64_t index = static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  if (index < groups) {
    T* keys_begin = med_out + index * axis_dim;
    thrust::sort_by_key(thrust::device, keys_begin, keys_begin + axis_dim,
                        med_ids + index * axis_dim);
  }
}

// Gathers the sorted values and ids back into the input's layout:
//   out[i] = med_out[trg_idx[i]], indices[i] = med_ids[trg_idx[i]].
// Launch: 1-D grid, one thread per element; threads with index >= n exit.
template <typename T>
__global__ void PermuteMediateData(const T* med_out, const int64_t* med_ids,
                                   const int64_t* trg_idx, int64_t n, T* out,
                                   int64_t* indices) {
  // 64-bit index: n is int64_t, and a 32-bit int would overflow past 2^31.
  int64_t index = static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  if (index < n) {
    const int64_t src = trg_idx[index];  // one load serves both gathers
    out[index] = med_out[src];
    indices[index] = med_ids[src];
  }
}

// CUDA kernel for the argsort operator.
//
// Sorts X along attribute `axis` (a negative value counts from the last
// dimension). Out receives the sorted values and Indices receives, for each
// of them, its original position along that axis. Steps:
//   1. ComputeTargetIdx: map each element to a slot in a mediate layout where
//      every group (fixed coordinates on all non-sort axes) is a contiguous
//      run of in_dims[axis] elements; seed the mediate ids with each
//      element's position along the axis.
//   2. PermuteInData: scatter X into that layout.
//   3. Sort: sort each group's run, carrying the ids along.
//   4. PermuteMediateData: gather values and ids back to X's layout.
// All four kernels are queued on the device context's stream.
template <typename T>
class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<Tensor>("X");
    auto* output = ctx.Output<Tensor>("Out");
    auto* indices = ctx.Output<Tensor>("Indices");
    int axis = ctx.Attr<int>("axis");

    auto in_dims = input->dims();
    // Normalize a negative axis to its non-negative equivalent.
    axis = (axis < 0) ? (in_dims.size() + axis) : axis;

    const T* in_data = input->data<T>();
    T* out_data = output->mutable_data<T>(ctx.GetPlace());
    int64_t* ids_data = indices->mutable_data<int64_t>(ctx.GetPlace());

    int64_t numel = input->numel();
    // Nothing to sort. Returning here also avoids dividing by a zero-sized
    // sort axis when computing `groups` below.
    if (numel == 0) {
      return;
    }
    const int64_t axis_dim = in_dims[axis];
    const int64_t groups = numel / axis_dim;

    // Copy the extents to device memory for ComputeTargetIdx.
    // NOTE(review): ComputeTargetIdx supports at most kMaxRank dimensions;
    // confirm the op's shape inference enforces that limit. Also confirm that
    // thrust::device_vector's copy/free (not issued on `stream`) is ordered
    // correctly with respect to the launches below.
    std::vector<int64_t> in_dims_vec = vectorize(in_dims);
    thrust::device_vector<int64_t> in_dims_dev(in_dims_vec.begin(),
                                               in_dims_vec.end());
    int64_t* in_dims_data = thrust::raw_pointer_cast(in_dims_dev.data());

    // Mediate tensors holding values and ids in the sort-friendly layout.
    Tensor mediate_output, mediate_indices;
    T* med_out_data = mediate_output.mutable_data<T>(in_dims, ctx.GetPlace());
    int64_t* med_ids_data =
        mediate_indices.mutable_data<int64_t>(in_dims, ctx.GetPlace());
    // Target slot of each input element in the mediate tensors.
    Tensor trg_idx_t;
    int64_t* trg_idx = trg_idx_t.mutable_data<int64_t>(in_dims, ctx.GetPlace());

    auto stream = ctx.cuda_device_context().stream();
    const int num_threads = PADDLE_CUDA_NUM_THREADS;
    // Ceil-div grid sizes (numel and groups are both > 0 here).
    const int64_t elem_blocks = (numel - 1) / num_threads + 1;
    const int64_t group_blocks = (groups - 1) / num_threads + 1;

    ComputeTargetIdx<<<elem_blocks, num_threads, 0, stream>>>(
        in_dims_data, in_dims.size(), axis, numel, trg_idx, med_ids_data);

    PermuteInData<<<elem_blocks, num_threads, 0, stream>>>(
        in_data, trg_idx, numel, med_out_data);

    Sort<<<group_blocks, num_threads, 0, stream>>>(axis_dim, groups,
                                                   med_out_data, med_ids_data);

    PermuteMediateData<<<elem_blocks, num_threads, 0, stream>>>(
        med_out_data, med_ids_data, trg_idx, numel, out_data, ids_data);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// Register the CUDA argsort kernel for float and double element types.
REGISTER_OP_CUDA_KERNEL(argsort, ops::ArgsortOpCUDAKernel<float>,
                        ops::ArgsortOpCUDAKernel<double>);