diff --git a/paddle/fluid/operators/argsort_op.cc b/paddle/fluid/operators/argsort_op.cc
index 57c2a3ba7b8b65e67ea95642d17c8ab85f33b62c..999e2510a8db5859b248cf6582106411a02bb55f 100644
--- a/paddle/fluid/operators/argsort_op.cc
+++ b/paddle/fluid/operators/argsort_op.cc
@@ -73,6 +73,13 @@ Output(Indices) gives the sorted order along the given axis Attr(axis).
                  "When axis < 0, the actual axis will be the |axis|'th "
                  "counting backwards. Default -1, the last dimension.")
         .SetDefault(-1);
+    AddAttr<bool>(
+        "descending",
+        "(bool, default false) The descending attribute is a flag telling the "
+        "algorithm how to sort the input data. "
+        "If descending is true, the data are sorted in descending order; "
+        "if false, in ascending order. Default value is false.")
+        .SetDefault(false);
   }
 };
 
diff --git a/paddle/fluid/operators/argsort_op.cu b/paddle/fluid/operators/argsort_op.cu
index 7a11bf7808e4fed1fc997f863065537995b59a4b..8eb56e70b59ca486601862b501eca0b5731d2b31 100644
--- a/paddle/fluid/operators/argsort_op.cu
+++ b/paddle/fluid/operators/argsort_op.cu
@@ -58,11 +58,12 @@ static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
   }
 }
 
-// Default use ascending sort
+// Sort by the flag `descending`: true sorts in descending order,
+// false (the default) sorts in ascending order.
template <typename T, typename IndType>
-void ArgFullSortAscending(const platform::CUDADeviceContext& ctx,
-                          const Tensor* input, Tensor* output, Tensor* indices,
-                          const IndType num_rows, const IndType num_cols) {
+void ArgFullSort(const platform::CUDADeviceContext& ctx, const Tensor* input,
+                 Tensor* output, Tensor* indices, const IndType num_rows,
+                 const IndType num_cols, const bool descending) {
   auto cu_stream = ctx.stream();
 
   Tensor input_indices;
@@ -113,12 +114,20 @@ void ArgFullSortAscending(const platform::CUDADeviceContext& ctx,
                               cub::CountingInputIterator<IndType>>
       segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));
 
-  auto err = cub::DeviceSegmentedRadixSort::SortPairs(
-      nullptr, temp_storage_bytes, inp, sorted_out_ptr,
-      input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
-      num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
-      cu_stream);
-
+  cudaError_t err;
+  if (descending) {
+    err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
+        nullptr, temp_storage_bytes, inp, sorted_out_ptr,
+        input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
+        num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
+        cu_stream);
+  } else {
+    err = cub::DeviceSegmentedRadixSort::SortPairs(
+        nullptr, temp_storage_bytes, inp, sorted_out_ptr,
+        input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
+        num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
+        cu_stream);
+  }
   PADDLE_ENFORCE_CUDA_SUCCESS(
       err,
       "ArgSortOP failed as could not launch "
@@ -129,11 +138,19 @@ void ArgFullSortAscending(const platform::CUDADeviceContext& ctx,
 
   Tensor temp_storage;
   temp_storage.mutable_data<uint8_t>(ctx.GetPlace(), temp_storage_bytes);
-  err = cub::DeviceSegmentedRadixSort::SortPairs(
-      temp_storage.data<uint8_t>(), temp_storage_bytes, inp, sorted_out_ptr,
-      input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
-      num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
-      cu_stream);
+  if (descending) {
+    err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
+        temp_storage.data<uint8_t>(), temp_storage_bytes, inp, sorted_out_ptr,
+        input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
+        num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
+        cu_stream);
+  } else {
+    err = cub::DeviceSegmentedRadixSort::SortPairs(
+        temp_storage.data<uint8_t>(), temp_storage_bytes, inp, sorted_out_ptr,
+        input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
+        num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
+        cu_stream);
+  }
 
   PADDLE_ENFORCE_CUDA_SUCCESS(
       err,
@@ -151,6 +168,7 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
     auto* output = ctx.Output<Tensor>("Out");
     auto* indices = ctx.Output<Tensor>("Indices");
     int axis = ctx.Attr<int>("axis");
+    bool descending = ctx.Attr<bool>("descending");
 
     auto in_dims = input->dims();
     axis = (axis < 0) ? (in_dims.size() + axis) : axis;
@@ -164,14 +182,8 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
           framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
       const int64_t input_width = in_dims[in_dims.size() - 1];
       const auto& dev_ctx = ctx.cuda_device_context();
-      if (input_width < INT_MAX && input_height < INT_MAX) {
-        ArgFullSortAscending<T, int>(dev_ctx, input, output, indices,
-                                     static_cast<int>(input_height),
-                                     static_cast<int>(input_width));
-      } else {
-        ArgFullSortAscending<T, int64_t>(dev_ctx, input, output, indices,
-                                         input_height, input_width);
-      }
+      ArgFullSort<T, int64_t>(dev_ctx, input, output, indices, input_height,
+                              input_width, descending);
     } else {
       // if not full sort, do transpose first
       std::vector<int> trans;
@@ -205,29 +217,15 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
       T* out_data = output->mutable_data<T>(ctx.GetPlace());
 
       Tensor tmp_indices;
-      if (input_height < INT_MAX && input_width < INT_MAX) {
-        // temp indices for sorting
-        tmp_indices.mutable_data<int>(trans_dims, ctx.GetPlace());
-        indices->mutable_data<int>(ctx.GetPlace());
-
-        ArgFullSortAscending<T, int>(
-            dev_ctx, &trans_inp, &tmp_out, &tmp_indices,
-            static_cast<int>(input_height), static_cast<int>(input_width));
-
-        TransCompute<platform::CUDADeviceContext, int>(
-            ndims, dev_ctx, tmp_indices, indices, trans);
-      } else {
-        // temp indices for sorting
-        tmp_indices.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
-        indices->mutable_data<int64_t>(ctx.GetPlace());
-
-        ArgFullSortAscending<T, int64_t>(dev_ctx, &trans_inp, &tmp_out,
-                                         &tmp_indices, input_height,
-                                         input_width);
-
-        TransCompute<platform::CUDADeviceContext, int64_t>(
-            ndims, dev_ctx, tmp_indices, indices, trans);
-      }
+      // temp indices for sorting
+      tmp_indices.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
+      indices->mutable_data<int64_t>(ctx.GetPlace());
+
+      ArgFullSort<T, int64_t>(dev_ctx, &trans_inp, &tmp_out, &tmp_indices,
+                              input_height, input_width, descending);
+
+      TransCompute<platform::CUDADeviceContext, int64_t>(
+          ndims, dev_ctx, tmp_indices, indices, trans);
       // transpose back
       TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, tmp_out,
                                                    output, trans);
diff --git a/paddle/fluid/operators/argsort_op.h b/paddle/fluid/operators/argsort_op.h
index 7e9112cfb7cbe5f783b04729fb4dff3676c922bc..cd9f9a9ea2007b60863e6f70e621069e5851614b 100644
--- a/paddle/fluid/operators/argsort_op.h
+++ b/paddle/fluid/operators/argsort_op.h
@@ -16,11 +16,58 @@ limitations under the License. */
 #include <algorithm>
 #include <utility>
 #include <vector>
+#include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/transpose_op.h"
 
 namespace paddle {
 namespace operators {
 
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
+
+using Tensor = framework::Tensor;
+
+template <typename T, typename Type>
+static void FullSort(Type input_height, Type input_width, int input_dim,
+                     const framework::Tensor* input, T* t_out, Type* t_indices,
+                     bool descending) {
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for
+#endif
+  for (Type i = 0; i < input_height; ++i) {
+    std::vector<std::pair<T, Type>> col_vec;
+    col_vec.reserve(input_width);
+    if (input_dim == 1) {
+      auto e_input = EigenVector<T>::Flatten(*input);
+      for (Type j = 0; j < input_width; ++j) {
+        col_vec.push_back(std::pair<T, Type>(e_input(j), j));
+      }
+    } else {
+      auto e_input = EigenMatrix<T>::Reshape(*input, input_dim - 1);
+      for (Type j = 0; j < input_width; ++j) {
+        col_vec.push_back(std::pair<T, Type>(e_input(i, j), j));
+      }
+    }
+    std::sort(col_vec.begin(), col_vec.end(),
+              [&](const std::pair<T, Type>& l, const std::pair<T, Type>& r) {
+                if (descending)
+                  return l.first > r.first;
+                else
+                  return l.first < r.first;
+              });
+
+    for (Type j = 0; j < input_width; ++j) {
+      t_out[i * input_width + j] = col_vec[j].first;
+      t_indices[i * input_width + j] = col_vec[j].second;
+    }
+  }
+}
 template <typename DeviceContext, typename T>
 class ArgsortKernel : public framework::OpKernel<T> {
  public:
@@ -29,50 +76,68 @@ class ArgsortKernel : public framework::OpKernel<T> {
     auto* output = ctx.Output<framework::Tensor>("Out");
     auto* indices = ctx.Output<framework::Tensor>("Indices");
     int axis = ctx.Attr<int>("axis");
+    bool descending = ctx.Attr<bool>("descending");
 
     auto in_dims = input->dims();
     axis = (axis < 0) ? (in_dims.size() + axis) : axis;
 
-    const T* in_data = input->data<T>();
     T* out_data = output->mutable_data<T>(ctx.GetPlace());
-    int64_t* ids_data = indices->mutable_data<int64_t>(ctx.GetPlace());
-
-    int64_t groups = input->numel() / in_dims[axis];
-    int64_t stride = (axis == in_dims.size() - 1)
-                         ? 1
-                         : framework::product(framework::slice_ddim(
-                               in_dims, axis + 1, in_dims.size()));
-
-    for (int64_t i = 0; i < groups; ++i) {
-      int64_t idx = i;
-      std::vector<int64_t> shape_vec(in_dims.size(), 0);
-      for (int64_t dim = in_dims.size() - 1; dim >= 0; --dim) {
-        if (dim != axis) {
-          shape_vec[dim] = idx % in_dims[dim];
-          idx /= in_dims[dim];
-        }
-      }
-
-      int64_t start_index = shape_vec[0];
-      for (int64_t dim = 0; dim < in_dims.size() - 1; ++dim) {
-        start_index = start_index * in_dims[dim + 1] + shape_vec[dim + 1];
-      }
+    // Do full sort
+    if (axis == -1 || axis + 1 == in_dims.size()) {
+      const int64_t input_height = framework::product(
+          framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
+      const int64_t input_width = in_dims[in_dims.size() - 1];
 
-      std::vector<int64_t> org_index_vec(in_dims[axis], start_index);
-      for (int64_t j = 1; j < in_dims[axis]; ++j) {
-        org_index_vec[j] += j * stride;
+      int64_t* ids_data = indices->mutable_data<int64_t>(ctx.GetPlace());
+      FullSort<T, int64_t>(input_height, input_width, in_dims.size(), input,
+                           out_data, ids_data, descending);
+    } else {
+      // If not full sort, do transpose first
+      std::vector<int> trans;
+      for (int i = 0; i < axis; i++) {
+        trans.push_back(i);
+      }
+      trans.push_back(in_dims.size() - 1);
+      for (int i = axis + 1; i < in_dims.size() - 1; i++) {
+        trans.push_back(i);
+      }
+      trans.push_back(axis);
+      framework::DDim trans_dims(in_dims);
+      for (int i = 0; i < trans.size(); i++) {
+        trans_dims[i] = in_dims[trans[i]];
       }
 
-      std::sort(org_index_vec.begin(), org_index_vec.end(),
-                [in_data](const int64_t v1, const int64_t v2) {
-                  return in_data[v1] < in_data[v2];
-                });
+      Tensor trans_inp;
+      trans_inp.mutable_data<T>(trans_dims, ctx.GetPlace());
+      int ndims = trans.size();
+      auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
+      // Do transpose
+      TransCompute<platform::CPUDeviceContext, T>(ndims, dev_ctx, *input,
+                                                  &trans_inp, trans);
 
-      for (size_t j = 0; j < org_index_vec.size(); ++j) {
-        int64_t index = start_index + j * stride;
-        out_data[index] = in_data[org_index_vec[j]];
-        ids_data[index] = (org_index_vec[j] - start_index) / stride;
-      }
+      const int64_t input_height = framework::product(
+          framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
+      const int64_t input_width = trans_dims[trans_dims.size() - 1];
+
+      Tensor tmp_out;
+      T* t_out = tmp_out.mutable_data<T>(trans_dims, ctx.GetPlace());
+      output->mutable_data<T>(ctx.GetPlace());
+
+      Tensor tmp_indices;
+
+      auto* t_ind =
+          tmp_indices.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
+
+      FullSort<T, int64_t>(input_height, input_width, in_dims.size(),
+                           &trans_inp, t_out, t_ind, descending);
+
+      indices->mutable_data<int64_t>(ctx.GetPlace());
+      TransCompute<platform::CPUDeviceContext, int64_t>(
+          ndims, dev_ctx, tmp_indices, indices, trans);
+      // transpose back
+      TransCompute<platform::CPUDeviceContext, T>(ndims, dev_ctx, tmp_out,
+                                                  output, trans);
     }
   }
 };
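Note on the two argsort kernels above: when the sort axis is not the last dimension, both the CUDA and CPU paths swap that axis with the last dimension, run the full row-wise sort, and then transpose the sorted values and indices back. A rough NumPy sketch of the same idea, for illustration only (it is not the operator code; it rotates the axis to the end rather than swapping, and the helper name argsort_ref is made up here):

    import numpy as np

    def argsort_ref(x, axis=-1, descending=False):
        # Move the sort axis to the end, full-sort each row, then move it back.
        axis = axis % x.ndim
        perm = [i for i in range(x.ndim) if i != axis] + [axis]
        inv = np.argsort(perm)          # inverse permutation
        xt = np.transpose(x, perm)
        ids = np.argsort(xt, kind='quicksort', axis=-1)
        if descending:
            ids = np.flip(ids, -1)      # reverse each sorted row
        out = np.take_along_axis(xt, ids, axis=-1)
        return np.transpose(out, inv), np.transpose(ids, inv)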
diff --git a/paddle/fluid/operators/fused/multihead_matmul_op.cu b/paddle/fluid/operators/fused/multihead_matmul_op.cu
index 8728fd9d21db6a13ee98e46ea331221b88a6d813..74bc7731a93dec045d03bde46627bbd57d11daca 100644
--- a/paddle/fluid/operators/fused/multihead_matmul_op.cu
+++ b/paddle/fluid/operators/fused/multihead_matmul_op.cu
@@ -53,7 +53,7 @@ __inline__ __device__ T blockReduceSum(T val, unsigned mask) {
 
   // align block_span to warpSize
   int block_span = (blockDim.x + warpSize - 1) >> 5;
-  val = (threadIdx.x < block_span) ? shared[lane] : (T)(0.0f);
+  val = (threadIdx.x < block_span) ? shared[lane] : static_cast<T>(0.0f);
 
   val = warpReduceSum<T>(val, mask);
   return val;
diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py
index 3366851711f70515805a4fcd75b3660a5828baff..c4f247236286e4d74a37f67def80ab38dc7c46c8 100644
--- a/python/paddle/fluid/layers/tensor.py
+++ b/python/paddle/fluid/layers/tensor.py
@@ -802,7 +802,7 @@ def argmax(x, axis=0):
     return out
 
 
-def argsort(input, axis=-1, name=None):
+def argsort(input, axis=-1, descending=False, name=None):
     """
     This OP sorts the input along the given axis, and returns sorted output
     data Varibale and its corresponding index Variable with the same shape as
@@ -814,6 +814,9 @@ def argsort(input, axis=-1, name=None):
         axis(int, optional): Axis to compute indices along. The effective range
             is [-R, R), where R is Rank(x). when axis<0, it works the same way
            as axis+R. Default is 0.
+        descending(bool, optional): Descending is a flag. If set to True, the
+            algorithm will sort in descending order; otherwise it sorts in
+            ascending order. Default is False.
         name(str, optional): The default value is None. Normally there is no
             need for user to set this property. For more information, please
             refer to :ref:`api_guide_Name`.
@@ -879,7 +882,8 @@ def argsort(input, axis=-1, name=None):
         inputs={'X': input},
         outputs={'Out': out,
                  'Indices': ids},
-        attrs={'axis': axis})
+        attrs={'axis': axis,
+               'descending': descending})
     return out, ids
 
 
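For reference, a minimal usage sketch of the new descending argument exposed above. The fluid.data/Executor boilerplate is ordinary Paddle 1.x usage assumed for illustration and is not part of this patch:

    import numpy as np
    import paddle.fluid as fluid

    x = fluid.data(name='x', shape=[2, 4], dtype='float32')
    out, ids = fluid.layers.argsort(x, axis=-1, descending=True)

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())
    feed = {'x': np.array([[3., 1., 2., 7.], [5., 6., 4., 0.]], 'float32')}
    out_np, ids_np = exe.run(fluid.default_main_program(),
                             feed=feed, fetch_list=[out, ids])
    # Each row of out_np is sorted high-to-low; ids_np holds the original
    # positions of those values, matching the Out/Indices outputs of the op.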
diff --git a/python/paddle/fluid/tests/unittests/test_argsort_op.py b/python/paddle/fluid/tests/unittests/test_argsort_op.py
old mode 100644
new mode 100755
index 81c9984badc5c271d2f9fb4c594ffdaff878dced..89ff5d7101a9aeaa529acf26f00670d29665ebdf
--- a/python/paddle/fluid/tests/unittests/test_argsort_op.py
+++ b/python/paddle/fluid/tests/unittests/test_argsort_op.py
@@ -24,14 +24,24 @@ class TestArgsortOp(OpTest):
     def setUp(self):
         self.init_axis()
         self.init_datatype()
+        self.init_direction()
         x = np.random.random((2, 3, 4, 5, 10)).astype(self.dtype)
+        self.attrs = {'axis': self.axis, 'descending': self.descending}
         if self.axis < 0:
             self.axis = self.axis + len(x.shape)
-        self.indices = np.argsort(x, kind='quicksort', axis=self.axis)
-        self.out = np.sort(x, kind='quicksort', axis=self.axis)
+        if self.descending:
+            self.indices = np.flip(
+                np.argsort(
+                    x, kind='quicksort', axis=self.axis), self.axis)
+            self.out = np.flip(
+                np.sort(
+                    x, kind='quicksort', axis=self.axis), self.axis)
+        else:
+            self.indices = np.argsort(x, kind='quicksort', axis=self.axis)
+            self.out = np.sort(x, kind='quicksort', axis=self.axis)
+
         self.op_type = "argsort"
         self.inputs = {'X': x}
-        self.attrs = {'axis': self.axis}
         self.outputs = {'Indices': self.indices, 'Out': self.out}
 
     def init_axis(self):
@@ -40,6 +50,9 @@ class TestArgsortOp(OpTest):
     def init_datatype(self):
         self.dtype = "float32"
 
+    def init_direction(self):
+        self.descending = False
+
     def test_check_output(self):
         self.check_output()
 
@@ -103,5 +116,35 @@ class TestArgsortOpFP16Axis4Neg4(TestArgsortOpFP16):
         self.axis = -4
 
 
+class TestArgsortOpDescendingAxis(TestArgsortOp):
+    def init_direction(self):
+        self.descending = True
+
+
+class TestArgsortOpDescendingAxis0(TestArgsortOpAxis0):
+    def init_direction(self):
+        self.descending = True
+
+
+class TestArgsortOpDescendingAxis1(TestArgsortOpAxis1):
+    def init_direction(self):
+        self.descending = True
+
+
+class TestArgsortOpDescendingAxis2(TestArgsortOpAxis2):
+    def init_direction(self):
+        self.descending = True
+
+
+class TestArgsortOpDescendingAxisNeg1(TestArgsortOpAxisNeg1):
+    def init_direction(self):
+        self.descending = True
+
+
+class TestArgsortOpDescendingAxisNeg2(TestArgsortOpAxisNeg2):
+    def init_direction(self):
+        self.descending = True
+
+
 if __name__ == "__main__":
     unittest.main()
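A quick check of the reference construction used in the updated test: reversing an ascending np.argsort/np.sort along the sort axis yields the descending result. With the random float inputs used here, ties are effectively impossible, so the flipped indices are exactly what a descending sort would produce; with ties, the flipped index order could differ from a stable descending sort even though the sorted values agree. For example:

    import numpy as np

    x = np.random.random((2, 3, 4, 5, 10)).astype('float32')
    axis = 2
    ids_desc = np.flip(np.argsort(x, kind='quicksort', axis=axis), axis)
    out_desc = np.flip(np.sort(x, kind='quicksort', axis=axis), axis)
    # The flipped values equal a direct descending sort ...
    assert np.array_equal(out_desc, -np.sort(-x, axis=axis))
    # ... and the flipped indices gather exactly those values.
    assert np.array_equal(np.take_along_axis(x, ids_desc, axis=axis), out_desc)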