diff --git a/paddle/fluid/operators/argsort_op.cu b/paddle/fluid/operators/argsort_op.cu
index 1a0b303817a48ba50f7ce917f94251886c12d229..7a11bf7808e4fed1fc997f863065537995b59a4b 100644
--- a/paddle/fluid/operators/argsort_op.cu
+++ b/paddle/fluid/operators/argsort_op.cu
@@ -14,82 +14,133 @@ limitations under the License. */
 #include <thrust/execution_policy.h>
 #include <thrust/sort.h>
+#include "cub/cub.cuh"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/argsort_op.h"
+#include "paddle/fluid/operators/transpose_op.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
 
+// Specialize cub's base traits so the radix sort can handle float16
+namespace cub {
+template <>
+struct NumericTraits<paddle::platform::float16>
+    : BaseTraits<FLOATING_POINT, true, false, uint16_t,
+                 paddle::platform::float16> {};
+}  // namespace cub
+
 namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
-using platform::PADDLE_CUDA_NUM_THREADS;
-
-const int kMaxRank = 9;  // The max rank of a tensor allowed in Fluid
-
-__global__ void ComputeTargetIdx(const int64_t* in_dims, int dims_size,
-                                 int axis, int64_t n, int64_t* trg_idx,
-                                 int64_t* med_ids) {
-  int64_t index = threadIdx.x + blockDim.x * blockIdx.x;
-  if (index < n) {
-    int64_t shape_out_axis[kMaxRank - 1] = {0};
-    int64_t dims_out_axis[kMaxRank - 1] = {0};
-    int64_t tmp = index;
-    int64_t pos_in_axis = 0;
-    int64_t i = dims_size - 2;
-    int64_t dim_axis = 0;
-    for (int64_t j = dims_size - 1; j >= 0; --j) {
-      int64_t dim = in_dims[j];
-      if (j != axis) {
-        shape_out_axis[i] = tmp % dim;
-        dims_out_axis[i] = dim;
-        i--;
-      } else {
-        dim_axis = dim;
-        pos_in_axis = tmp % dim_axis;
-      }
-      tmp /= dim;
-    }
-    int64_t group = (dims_size > 1) ? shape_out_axis[0] : 0;
-    for (int64_t j = 0; j < dims_size - 2; ++j) {
-      group = group * dims_out_axis[j + 1] + shape_out_axis[j + 1];
-    }
-    int64_t traget_idx = group * dim_axis + pos_in_axis;
-    trg_idx[index] = traget_idx;
-    med_ids[traget_idx] = pos_in_axis;
-  }
-}
+// Iterator that maps a segment (row) index to the offset of its first element
+struct SegmentOffsetIter {
+  EIGEN_DEVICE_FUNC
+  explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}
 
-template <typename T>
-__global__ void PermuteInData(const T* in, const int64_t* trg_idx, int64_t n,
-                              T* med_out) {
-  int index = threadIdx.x + blockDim.x * blockIdx.x;
-  if (index < n) {
-    med_out[trg_idx[index]] = in[index];
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const {
+    return idx * num_cols_;
   }
-}
+
+  int num_cols_;
+};
 
 template <typename T>
-__global__ void Sort(int64_t axis_dim, int64_t groups, T* med_out,
-                     int64_t* med_ids) {
-  int index = threadIdx.x + blockDim.x * blockIdx.x;
-  if (index < groups) {
-    thrust::sort_by_key(thrust::device, med_out + index * axis_dim,
-                        med_out + axis_dim * (1 + index),
-                        med_ids + index * axis_dim);
+static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
+  int col_id = threadIdx.x;
+  int row_id = blockIdx.x;
+
+  for (T j = row_id; j < num_rows; j += gridDim.x) {
+    for (T i = col_id; i < num_cols; i += blockDim.x) {
+      indices[j * num_cols + i] = i;
+    }
   }
 }
 
-template <typename T>
-__global__ void PermuteMediateData(const T* med_out, const int64_t* med_ids,
-                                   const int64_t* trg_idx, int64_t n, T* out,
-                                   int64_t* indices) {
-  int index = threadIdx.x + blockDim.x * blockIdx.x;
-  if (index < n) {
-    out[index] = med_out[trg_idx[index]];
-    indices[index] = med_ids[trg_idx[index]];
-  }
+// Ascending sort is used by default
+template <typename T, typename IndType>
+void ArgFullSortAscending(const platform::CUDADeviceContext& ctx,
+                          const Tensor* input, Tensor* output, Tensor* indices,
+                          const IndType num_rows, const IndType num_cols) {
+  auto cu_stream = ctx.stream();
+
+  Tensor input_indices;
+
+  const std::vector<IndType> dims = {num_rows, num_cols};
+  auto dim = framework::make_ddim(dims);
+  input_indices.Resize(dim);
+  input_indices.mutable_data<IndType>(ctx.GetPlace());
+
+  size_t temp_storage_bytes = -1;
+
+  auto ComputeBlockSize = [](IndType col) {
+    if (col > 512)
+      return 1024;
+    else if (col > 256 && col <= 512)
+      return 512;
+    else if (col > 128 && col <= 256)
+      return 256;
+    else if (col > 64 && col <= 128)
+      return 128;
+    else
+      return 64;
+  };
+
+  int block_size = ComputeBlockSize(num_cols);
+
+  int maxGridDimX = ctx.GetCUDAMaxGridDimSize().x;
+  // cap grid_size at the max grid dimension; FillIndex strides over extra rows
+  int grid_size = num_rows < maxGridDimX ? num_rows : maxGridDimX;
+  // Initialize the index array
+  FillIndex<<<grid_size, block_size, 0, cu_stream>>>(
+      input_indices.data<IndType>(), num_rows, num_cols);
+
+  T* sorted_out_ptr;
+  IndType* sorted_indices_ptr;
+
+  const T* inp = input->data<T>();
+  T* out = output->mutable_data<T>(ctx.GetPlace());
+  IndType* ind = indices->mutable_data<IndType>(ctx.GetPlace());
+
+  sorted_out_ptr = out;
+  sorted_indices_ptr = ind;
+
+  // counting iterator over segment indices
+  cub::CountingInputIterator<IndType> counting_iter(0);
+  // segment_offsets_t yields the start offset of each row (segment)
+  cub::TransformInputIterator<IndType, SegmentOffsetIter,
+                              cub::CountingInputIterator<IndType>>
+      segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));
+
+  auto err = cub::DeviceSegmentedRadixSort::SortPairs(
+      nullptr, temp_storage_bytes, inp, sorted_out_ptr,
+      input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
+      num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
+      cu_stream);
+
+  PADDLE_ENFORCE_CUDA_SUCCESS(
+      err,
+      "ArgSortOP failed as could not launch "
+      "cub::DeviceSegmentedRadixSort::SortPairs to calculate "
+      "temp_storage_bytes, temp_storage_bytes:%d, status:%s.",
+      temp_storage_bytes, cudaGetErrorString(err));
+
+  Tensor temp_storage;
+  temp_storage.mutable_data<uint8_t>(ctx.GetPlace(), temp_storage_bytes);
+
+  err = cub::DeviceSegmentedRadixSort::SortPairs(
+      temp_storage.data<uint8_t>(), temp_storage_bytes, inp, sorted_out_ptr,
+      input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
+      num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
+      cu_stream);
+
+  PADDLE_ENFORCE_CUDA_SUCCESS(
+      err,
+      "ArgSortOP failed as could not launch "
+      "cub::DeviceSegmentedRadixSort::SortPairs to sort input, "
+      "temp_storage_bytes:%d status:%s.",
+      temp_storage_bytes, cudaGetErrorString(err));
 }
 
 template <typename T>
@@ -104,47 +155,91 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
     auto in_dims = input->dims();
     axis = (axis < 0) ? (in_dims.size() + axis) : axis;
 
-    const T* in_data = input->data<T>();
-    T* out_data = output->mutable_data<T>(ctx.GetPlace());
-    int64_t* ids_data = indices->mutable_data<int64_t>(ctx.GetPlace());
-
     int64_t numel = input->numel();
     int64_t groups = numel / in_dims[axis];
 
-    std::vector<int64_t> in_dims_vec = vectorize(in_dims);
-    thrust::device_vector<int64_t> in_dims_dev(in_dims_vec.begin(),
-                                               in_dims_vec.end());
-    int64_t* in_dims_data = thrust::raw_pointer_cast(in_dims_dev.data());
-    // Mediate tensor for sorting data and indices
-    Tensor mediate_output, mediate_indices;
-    T* med_out_data =
-        mediate_output.mutable_data<T>(input->dims(), ctx.GetPlace());
-    int64_t* med_ids_data =
-        mediate_indices.mutable_data<int64_t>(in_dims, ctx.GetPlace());
-    // Target index of each element along the given axis in the mediate tensors
-    Tensor trg_idx_t;
-    int64_t* trg_idx = trg_idx_t.mutable_data<int64_t>(in_dims, ctx.GetPlace());
-
-    auto stream = ctx.cuda_device_context().stream();
-    const int num_threads = PADDLE_CUDA_NUM_THREADS;
-
-    ComputeTargetIdx<<<(numel - 1) / num_threads + 1, num_threads, 0, stream>>>(
-        in_dims_data, in_dims.size(), axis, numel, trg_idx, med_ids_data);
-
-    PermuteInData<<<(numel - 1) / num_threads + 1, num_threads, 0, stream>>>(
-        in_data, trg_idx, numel, med_out_data);
-
-    Sort<<<(groups - 1) / num_threads + 1, num_threads, 0, stream>>>(
-        in_dims[axis], groups, med_out_data, med_ids_data);
-
-    PermuteMediateData<<<(numel - 1) / num_threads + 1, num_threads, 0,
-                         stream>>>(med_out_data, med_ids_data, trg_idx, numel,
-                                   out_data, ids_data);
+    // Special case for full sort along the last axis, speedup ~190x.
+    if (axis == -1 || axis + 1 == in_dims.size()) {
+      const int64_t input_height = framework::product(
+          framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
+      const int64_t input_width = in_dims[in_dims.size() - 1];
+      const auto& dev_ctx = ctx.cuda_device_context();
+      if (input_width < INT_MAX && input_height < INT_MAX) {
+        ArgFullSortAscending<T, int>(dev_ctx, input, output, indices,
+                                     static_cast<int>(input_height),
+                                     static_cast<int>(input_width));
+      } else {
+        ArgFullSortAscending<T, int64_t>(dev_ctx, input, output, indices,
+                                         input_height, input_width);
+      }
+    } else {
+      // If it is not a full sort, transpose the target axis to the end first
+      std::vector<int> trans;
+      for (int i = 0; i < axis; i++) {
+        trans.push_back(i);
+      }
+      trans.push_back(in_dims.size() - 1);
+      for (int i = axis + 1; i < in_dims.size() - 1; i++) {
+        trans.push_back(i);
+      }
+      trans.push_back(axis);
+      framework::DDim trans_dims(in_dims);
+      for (int i = 0; i < trans.size(); i++) {
+        trans_dims[i] = in_dims[trans[i]];
+      }
+
+      Tensor trans_inp;
+      T* trans_inp_data = trans_inp.mutable_data<T>(trans_dims, ctx.GetPlace());
+      int ndims = trans.size();
+      const auto& dev_ctx = ctx.cuda_device_context();
+      // Do transpose
+      TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, *input,
+                                                   &trans_inp, trans);
+
+      const int64_t input_height = framework::product(
+          framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
+      const int64_t input_width = trans_dims[trans_dims.size() - 1];
+
+      Tensor tmp_out;
+      tmp_out.mutable_data<T>(trans_dims, ctx.GetPlace());
+      T* out_data = output->mutable_data<T>(ctx.GetPlace());
+
+      Tensor tmp_indices;
+      if (input_height < INT_MAX && input_width < INT_MAX) {
+        // temp indices for sorting
+        tmp_indices.mutable_data<int>(trans_dims, ctx.GetPlace());
+        indices->mutable_data<int>(ctx.GetPlace());
+
+        ArgFullSortAscending<T, int>(
+            dev_ctx, &trans_inp, &tmp_out, &tmp_indices,
+            static_cast<int>(input_height), static_cast<int>(input_width));
+
+        TransCompute<platform::CUDADeviceContext, int>(
+            ndims, dev_ctx, tmp_indices, indices, trans);
+      } else {
+        // temp indices for sorting
+        tmp_indices.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
+        indices->mutable_data<int64_t>(ctx.GetPlace());
+
+        ArgFullSortAscending<T, int64_t>(dev_ctx, &trans_inp, &tmp_out,
+                                         &tmp_indices, input_height,
+                                         input_width);
+
+        TransCompute<platform::CUDADeviceContext, int64_t>(
+            ndims, dev_ctx, tmp_indices, indices, trans);
+      }
+      // transpose back
+      TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, tmp_out,
+                                                   output, trans);
+      return;
+    }
   }
 };
 
 }  // namespace operators
 }  // namespace paddle
 
-REGISTER_OP_CUDA_KERNEL(argsort, paddle::operators::ArgsortOpCUDAKernel<float>,
-                        paddle::operators::ArgsortOpCUDAKernel<double>);
+REGISTER_OP_CUDA_KERNEL(
+    argsort, paddle::operators::ArgsortOpCUDAKernel<float>,
+    paddle::operators::ArgsortOpCUDAKernel<double>,
+    paddle::operators::ArgsortOpCUDAKernel<paddle::platform::float16>);
diff --git a/python/paddle/fluid/tests/unittests/test_argsort_op.py b/python/paddle/fluid/tests/unittests/test_argsort_op.py
index 7bc6f2599d617b192908da9b57d0cd715019bd71..81c9984badc5c271d2f9fb4c594ffdaff878dced 100644
--- a/python/paddle/fluid/tests/unittests/test_argsort_op.py
+++ b/python/paddle/fluid/tests/unittests/test_argsort_op.py
@@ -17,12 +17,14 @@ from __future__ import print_function
 import unittest
 import numpy as np
 from op_test import OpTest
+import paddle.fluid.core as core
 
 
 class TestArgsortOp(OpTest):
     def setUp(self):
         self.init_axis()
-        x = np.random.random((2, 3, 4, 5, 10)).astype("float32")
+        self.init_datatype()
+        x = np.random.random((2, 3, 4, 5, 10)).astype(self.dtype)
         if self.axis < 0:
             self.axis = self.axis + len(x.shape)
         self.indices = np.argsort(x, kind='quicksort', axis=self.axis)
@@ -35,6 +37,9 @@ class TestArgsortOp(OpTest):
     def init_axis(self):
         self.axis = -1
 
+    def init_datatype(self):
+        self.dtype = "float32"
+
     def test_check_output(self):
         self.check_output()
 
@@ -49,10 +54,54 @@ class TestArgsortOpAxis1(TestArgsortOp):
         self.axis = 1
 
 
+class TestArgsortOpAxis2(TestArgsortOp):
+    def init_axis(self):
+        self.axis = 2
+
+
+class TestArgsortOpAxisNeg1(TestArgsortOp):
+    def init_axis(self):
+        self.axis = -1
+
+
 class TestArgsortOpAxisNeg2(TestArgsortOp):
     def init_axis(self):
         self.axis = -2
 
 
+class TestArgsortOpFP16(TestArgsortOp):
+    def init_datatype(self):
+        if core.is_compiled_with_cuda():
+            self.dtype = 'float16'
+
+    def test_check_output(self):
+        pass
+
+    def test_check_output_with_place(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            self.check_output_with_place(place, atol=1e-5)
+
+
+class TestArgsortOpFP16Axis0(TestArgsortOpFP16):
+    def init_axis(self):
+        self.axis = 0
+
+
+class TestArgsortOpFP16Axis2(TestArgsortOpFP16):
+    def init_axis(self):
+        self.axis = 2
+
+
+class TestArgsortOpFP16AxisNeg2(TestArgsortOpFP16):
+    def init_axis(self):
+        self.axis = -2
+
+
+class TestArgsortOpFP16Axis4Neg4(TestArgsortOpFP16):
+    def init_axis(self):
+        self.axis = -4
+
+
 if __name__ == "__main__":
     unittest.main()
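
Reviewer note (not part of the patch): for readers unfamiliar with the CUB primitive the new kernel is built on, below is a minimal standalone sketch of the same two-phase cub::DeviceSegmentedRadixSort::SortPairs pattern, i.e. one call with a null buffer to size the temp storage, then the real call. It sorts each row of a small host-defined matrix and carries the original column index along as the value, which is what argsort needs. The SegmentOffsetIter name is reused from the patch for illustration; main(), the matrix contents, and the raw cudaMalloc calls are assumptions for this sketch only and do not appear in the PR.

// minimal_segmented_sort.cu -- sketch of the CUB segmented radix sort pattern
#include <cub/cub.cuh>  // bundled with recent CUDA toolkits (assumption)
#include <cstdio>
#include <vector>

// Maps a row (segment) index to the offset of that row's first element.
struct SegmentOffsetIter {
  explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}
  __host__ __device__ int operator()(int idx) const { return idx * num_cols_; }
  int num_cols_;
};

int main() {
  const int num_rows = 2, num_cols = 4, n = num_rows * num_cols;
  std::vector<float> h_keys = {3.f, 1.f, 2.f, 0.f, 7.f, 5.f, 6.f, 4.f};
  std::vector<int> h_vals = {0, 1, 2, 3, 0, 1, 2, 3};  // column index per row

  float *d_keys_in, *d_keys_out;
  int *d_vals_in, *d_vals_out;
  cudaMalloc(&d_keys_in, n * sizeof(float));
  cudaMalloc(&d_keys_out, n * sizeof(float));
  cudaMalloc(&d_vals_in, n * sizeof(int));
  cudaMalloc(&d_vals_out, n * sizeof(int));
  cudaMemcpy(d_keys_in, h_keys.data(), n * sizeof(float),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_vals_in, h_vals.data(), n * sizeof(int),
             cudaMemcpyHostToDevice);

  // Segment offsets are generated on the fly: row i starts at i * num_cols.
  cub::CountingInputIterator<int> counting_iter(0);
  cub::TransformInputIterator<int, SegmentOffsetIter,
                              cub::CountingInputIterator<int>>
      segment_offsets(counting_iter, SegmentOffsetIter(num_cols));

  // Phase 1: null temp buffer -> only computes temp_storage_bytes.
  size_t temp_storage_bytes = 0;
  cub::DeviceSegmentedRadixSort::SortPairs(
      nullptr, temp_storage_bytes, d_keys_in, d_keys_out, d_vals_in,
      d_vals_out, n, num_rows, segment_offsets, segment_offsets + 1);

  void* d_temp_storage = nullptr;
  cudaMalloc(&d_temp_storage, temp_storage_bytes);

  // Phase 2: the actual per-row ascending sort.
  cub::DeviceSegmentedRadixSort::SortPairs(
      d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, d_vals_in,
      d_vals_out, n, num_rows, segment_offsets, segment_offsets + 1);

  std::vector<int> h_out(n);
  cudaMemcpy(h_out.data(), d_vals_out, n * sizeof(int),
             cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) printf("%d ", h_out[i]);  // prints: 3 1 2 0 3 1 2 0
  printf("\n");

  cudaFree(d_temp_storage);
  cudaFree(d_keys_in);
  cudaFree(d_keys_out);
  cudaFree(d_vals_in);
  cudaFree(d_vals_out);
  return 0;
}

The kernel in this PR follows the same shape, with the tensor data reinterpreted as an (outer dims x last dim) matrix, FillIndex producing the per-row index values, and a second ascending-bit range of 0 to sizeof(T) * 8.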