diff --git a/doc/fluid/api/layers.rst b/doc/fluid/api/layers.rst
index 264506a68ae17d081dd58ef4794bf7723f6d021c..d443c49657b92583e527035f49e74462cf41487d 100644
--- a/doc/fluid/api/layers.rst
+++ b/doc/fluid/api/layers.rst
@@ -1468,6 +1468,14 @@ argmax
 .. autofunction:: paddle.fluid.layers.argmax
     :noindex:
 
+.. _api_fluid_layers_argsort:
+
+argsort
+-------
+
+.. autofunction:: paddle.fluid.layers.argsort
+    :noindex:
+
 .. _api_fluid_layers_ones:
 
 ones
diff --git a/paddle/fluid/operators/argsort_op.cc b/paddle/fluid/operators/argsort_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a2f5a2545701991263c1ef842e9275b1edbfd2ca
--- /dev/null
+++ b/paddle/fluid/operators/argsort_op.cc
@@ -0,0 +1,93 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/argsort_op.h"
+
+namespace paddle {
+namespace operators {
+
+// Operator definition of argsort: checks its inputs/outputs and infers the
+// shapes of Output(Out) and Output(Indices) from Input(X).
+class ArgsortOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  // Requires X, Out and Indices to be present and Attr(axis) to lie in
+  // [-rank(X), rank(X)); both outputs then take the dims and LoD of X.
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of ArgsortOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of ArgsortOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Indices"),
+                   "Output(Indices) of ArgsortOp should not be null.");
+
+    auto in_dims = ctx->GetInputDim("X");
+    int axis = ctx->Attrs().Get<int>("axis");
+
+    auto num_dims = in_dims.size();
+    PADDLE_ENFORCE(axis < num_dims,
+                   "Attr(axis) %d of ArgsortOp is out of bounds for Input(X)'s "
+                   "rank %d.",
+                   axis, num_dims);
+    PADDLE_ENFORCE(axis >= -num_dims,
+                   "Attr(axis) %d of ArgsortOp must be not less than "
+                   "-rank(Input(X)) (%d).",
+                   axis, num_dims);
+
+    ctx->SetOutputDim("Out", in_dims);
+    ctx->SetOutputDim("Indices", in_dims);
+    ctx->ShareLoD("X", "Out");
+    ctx->ShareLoD("X", "Indices");
+  }
+};
+
+// Declares the input, outputs, attribute and documentation of argsort.
+class ArgsortOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(Tensor) The input of Argsort op.");
+    AddOutput("Out",
+              "(Tensor) The sorted tensor of Argsort op, with the same "
+              "shape as Input(X).");
+    AddOutput("Indices",
+              "(Tensor) The indices of a tensor giving the sorted order, with "
+              "the same shape as Input(X).");
+    AddComment(R"DOC(
+Argsort operator
+
+Performs sorting on the input tensor along the given axis and outputs two
+tensors, Output(Out) and Output(Indices). They reserve the same shape
+with Input(X), and Output(Out) represents the sorted tensor while
+Output(Indices) gives the sorted order along the given axis Attr(axis).
+
+ )DOC");
+    AddAttr<int>("axis",
+                 "(int, default -1) The axis along which to sort the tensor. "
+                 "When axis < 0, the actual axis will be the |axis|'th "
+                 "counting backwards. Default -1, the last dimension.")
+        .SetDefault(-1);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+// Registered with EmptyGradOpMaker; CPU kernels for float and double.
+REGISTER_OPERATOR(argsort, ops::ArgsortOp, ops::ArgsortOpMaker,
+                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OP_CPU_KERNEL(argsort,
+                       ops::ArgsortKernel<paddle::platform::CPUPlace, float>,
+                       ops::ArgsortKernel<paddle::platform::CPUPlace, double>);
diff --git a/paddle/fluid/operators/argsort_op.cu b/paddle/fluid/operators/argsort_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..7d5199aae7da4eed5afa6b8bd64c04a540b915d4
--- /dev/null
+++ b/paddle/fluid/operators/argsort_op.cu
@@ -0,0 +1,151 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/
+
+#include <thrust/execution_policy.h>
+#include <thrust/sort.h>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/argsort_op.h"
+#include "paddle/fluid/platform/assert.h"
+#include "paddle/fluid/platform/cuda_device_function.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using platform::PADDLE_CUDA_NUM_THREADS;
+
+const int kMaxRank = 9;  // The max rank of a tensor allowed in Fluid
+
+// For each flat index of the input, computes its slot in a layout where the
+// sort axis is innermost (trg_idx) and seeds med_ids with the axis position.
+__global__ void ComputeTargetIdx(const int64_t* in_dims, int dims_size,
+                                 int axis, int64_t n, int64_t* trg_idx,
+                                 int64_t* med_ids) {
+  int64_t index = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
+  if (index >= n) return;
+  int64_t shape_out_axis[kMaxRank - 1] = {0};  // coords outside the axis
+  int64_t dims_out_axis[kMaxRank - 1] = {0};   // matching dim sizes
+  int64_t tmp = index;
+  int64_t pos_in_axis = 0;
+  int64_t i = dims_size - 2;
+  int64_t dim_axis = 0;
+  for (int64_t j = dims_size - 1; j >= 0; --j) {
+    int64_t dim = in_dims[j];
+    if (j != axis) {
+      shape_out_axis[i] = tmp % dim;
+      dims_out_axis[i] = dim;
+      i--;
+    } else {
+      dim_axis = dim;
+      pos_in_axis = tmp % dim_axis;
+    }
+    tmp /= dim;
+  }
+  int64_t group = (dims_size > 1) ? shape_out_axis[0] : 0;
+  for (int64_t j = 0; j < dims_size - 2; ++j) {
+    group = group * dims_out_axis[j + 1] + shape_out_axis[j + 1];
+  }
+  // Slices are stored back to back, dim_axis elements each.
+  int64_t target_idx = group * dim_axis + pos_in_axis;
+  trg_idx[index] = target_idx;
+  med_ids[target_idx] = pos_in_axis;
+}
+
+// Scatters `in` into med_out so that every slice to sort is contiguous.
+template <typename T>
+__global__ void PermuteInData(const T* in, const int64_t* trg_idx, int64_t n,
+                              T* med_out) {
+  int64_t index = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
+  if (index < n) med_out[trg_idx[index]] = in[index];
+}
+
+// Sorts each contiguous slice of med_out in place, carrying med_ids along.
+template <typename T>
+__global__ void Sort(int64_t axis_dim, int64_t groups, T* med_out,
+                     int64_t* med_ids) {
+  int64_t index = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
+  if (index >= groups) return;
+  thrust::sort_by_key(thrust::device, med_out + index * axis_dim,
+                      med_out + axis_dim * (1 + index),
+                      med_ids + index * axis_dim);
+}
+
+// Gathers sorted values and indices back into the original tensor layout.
+template <typename T>
+__global__ void PermuteMediateData(const T* med_out, const int64_t* med_ids,
+                                   const int64_t* trg_idx, int64_t n, T* out,
+                                   int64_t* indices) {
+  int64_t index = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
+  if (index >= n) return;
+  out[index] = med_out[trg_idx[index]];
+  indices[index] = med_ids[trg_idx[index]];
+}
+
+// CUDA kernel of argsort: sorts Input(X) along Attr(axis) on the device.
+template <typename T>
+class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<Tensor>("X");
+    auto* output = ctx.Output<Tensor>("Out");
+    auto* indices = ctx.Output<Tensor>("Indices");
+    int axis = ctx.Attr<int>("axis");
+
+    auto in_dims = input->dims();
+    axis = (axis < 0) ? (in_dims.size() + axis) : axis;
+
+    const T* in_data = input->data<T>();
+    T* out_data = output->mutable_data<T>(ctx.GetPlace());
+    int64_t* ids_data = indices->mutable_data<int64_t>(ctx.GetPlace());
+
+    int64_t numel = input->numel();
+    int64_t groups = numel / in_dims[axis];
+    // Copy the dims to device memory so that ComputeTargetIdx can read them.
+    std::vector<int64_t> in_dims_vec = vectorize(in_dims);
+    thrust::device_vector<int64_t> in_dims_dev(in_dims_vec.begin(),
+                                               in_dims_vec.end());
+    int64_t* in_dims_data = thrust::raw_pointer_cast(in_dims_dev.data());
+    // Intermediate buffers holding the data and indices in sortable layout.
+    Tensor mediate_output, mediate_indices;
+    T* med_out_data = mediate_output.mutable_data<T>(in_dims, ctx.GetPlace());
+    int64_t* med_ids_data =
+        mediate_indices.mutable_data<int64_t>(in_dims, ctx.GetPlace());
+    // Position of every input element inside the intermediate buffers.
+    Tensor trg_idx_t;
+    int64_t* trg_idx = trg_idx_t.mutable_data<int64_t>(in_dims, ctx.GetPlace());
+
+    auto stream = ctx.cuda_device_context().stream();
+    const int num_threads = PADDLE_CUDA_NUM_THREADS;
+    // 1) Map every element to its slot in the axis-innermost layout.
+    ComputeTargetIdx<<<(numel - 1) / num_threads + 1, num_threads, 0, stream>>>(
+        in_dims_data, in_dims.size(), axis, numel, trg_idx, med_ids_data);
+    // 2) Scatter the input values into that layout.
+    PermuteInData<<<(numel - 1) / num_threads + 1, num_threads, 0, stream>>>(
+        in_data, trg_idx, numel, med_out_data);
+    // 3) Sort every slice; one thread handles one slice.
+    Sort<<<(groups - 1) / num_threads + 1, num_threads, 0, stream>>>(
+        in_dims[axis], groups, med_out_data, med_ids_data);
+    // 4) Gather the sorted values and indices back in the input layout.
+    PermuteMediateData<<<(numel - 1) / num_threads + 1, num_threads, 0,
+                         stream>>>(med_out_data, med_ids_data, trg_idx, numel,
+                                   out_data, ids_data);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+REGISTER_OP_CUDA_KERNEL(argsort, paddle::operators::ArgsortOpCUDAKernel<float>,
+                        paddle::operators::ArgsortOpCUDAKernel<double>);
diff --git a/paddle/fluid/operators/argsort_op.h b/paddle/fluid/operators/argsort_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..7e9112cfb7cbe5f783b04729fb4dff3676c922bc
--- /dev/null
+++ b/paddle/fluid/operators/argsort_op.h
@@ -0,0 +1,81 @@
+/* Copyright
(c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <algorithm>
+#include <utility>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+// CPU kernel of argsort: for every 1-D slice of Input(X) along Attr(axis),
+// writes the ascending-sorted values to Output(Out) and the positions those
+// values had inside the slice to Output(Indices).
+template <typename DeviceContext, typename T>
+class ArgsortKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<framework::Tensor>("X");
+    auto* output = ctx.Output<framework::Tensor>("Out");
+    auto* indices = ctx.Output<framework::Tensor>("Indices");
+    int axis = ctx.Attr<int>("axis");
+
+    // A negative axis counts backwards from the last dimension.
+    auto in_dims = input->dims();
+    axis = (axis < 0) ? (in_dims.size() + axis) : axis;
+
+    const T* in_data = input->data<T>();
+    T* out_data = output->mutable_data<T>(ctx.GetPlace());
+    int64_t* ids_data = indices->mutable_data<int64_t>(ctx.GetPlace());
+
+    // groups: number of independent slices; stride: distance in memory
+    // between neighbouring elements of one slice.
+    const int64_t axis_dim = in_dims[axis];
+    int64_t groups = input->numel() / axis_dim;
+    int64_t stride = (axis == in_dims.size() - 1)
+                         ? 1
+                         : framework::product(framework::slice_ddim(
+                               in_dims, axis + 1, in_dims.size()));
+
+    // Reused for every slice to avoid one heap allocation per group.
+    std::vector<int64_t> org_index_vec(axis_dim);
+    for (int64_t i = 0; i < groups; ++i) {
+      // i enumerates, in row-major order, the coordinates with the axis
+      // removed: i / stride covers the leading dims, i % stride the rest.
+      int64_t start_index = (i / stride) * axis_dim * stride + i % stride;
+      for (int64_t j = 0; j < axis_dim; ++j) {
+        org_index_vec[j] = start_index + j * stride;
+      }
+
+      // Order the flat offsets by the values they refer to. std::sort is
+      // not stable, so the order among equal values is unspecified.
+      std::sort(org_index_vec.begin(), org_index_vec.end(),
+                [in_data](const int64_t v1, const int64_t v2) {
+                  return in_data[v1] < in_data[v2];
+                });
+
+      // Report indices relative to the slice, i.e. in [0, axis_dim).
+      for (int64_t j = 0; j < axis_dim; ++j) {
+        int64_t index = start_index + j * stride;
+        out_data[index] = in_data[org_index_vec[j]];
+        ids_data[index] = (org_index_vec[j] - start_index) / stride;
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py
index ce5f08de623c8e4572599f8088ecae2e4821cce0..b6614ecf3bc16e73683f4991779769049c6800ed 100644
--- a/python/paddle/fluid/layers/tensor.py
+++ b/python/paddle/fluid/layers/tensor.py
@@ -33,6 +33,7 @@ __all__ = [
     'fill_constant',
     'argmin',
     'argmax',
+    'argsort',
     'ones',
     'zeros',
     'reverse',
@@ -444,6 +445,58 @@ def argmax(x, axis=0):
     return out
 
 
+def argsort(input, axis=-1, name=None):
+    """
+    Performs sorting on the input Variable along the given axis, and outputs
+    sorted data Variable and its corresponding index Variable with the same
+    shape as :attr:`input`.
+
+    .. code-block:: text
+
+        For example, the given axis is -1 and the input Variable
+
+            input = [[0.15849551, 0.45865775, 0.8563702 ],
+                     [0.12070083, 0.28766365, 0.18776911]],
+
+        after argsort, the sorted Variable becomes
+
+            out = [[0.15849551, 0.45865775, 0.8563702 ],
+                   [0.12070083, 0.18776911, 0.28766365]],
+
+        and the sorted indices along the given axis turn out to be
+
+            indices = [[0, 1, 2],
+                       [0, 2, 1]]
+
+    Args:
+        input(Variable): The input Variable for sorting.
+        axis(int): The axis along which to sort the input Variable. When
+                   :attr:`axis` < 0, the actual axis will be :attr:`axis` +
+                   rank(:attr:`input`). Default -1, the last dimension.
+        name(str|None): (optional) A name for this layer. If set None, the
+                   layer will be named automatically.
+
+    Returns:
+        tuple: A tuple of sorted data Variable and the sorted indices.
+
+    Examples:
+        .. code-block:: python
+
+            input = fluid.layers.data(name="input", shape=[2, 3])
+            out, indices = fluid.layers.argsort(input, axis=0)
+    """
+    helper = LayerHelper("argsort", **locals())
+    out = helper.create_tmp_variable(dtype=input.dtype, stop_gradient=True)
+    ids = helper.create_tmp_variable(VarDesc.VarType.INT64, stop_gradient=True)
+    helper.append_op(
+        type='argsort',
+        inputs={'X': input},
+        outputs={'Out': out,
+                 'Indices': ids},
+        attrs={'axis': axis})
+    return out, ids
+
+
 def ones(shape, dtype, force_cpu=False):
     """
     **ones**
diff --git a/python/paddle/fluid/tests/unittests/test_argsort_op.py b/python/paddle/fluid/tests/unittests/test_argsort_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..b29a102a3880406156481fdac54ca7043d3415db
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_argsort_op.py
@@ -0,0 +1,56 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class TestArgsortOp(OpTest):
+    def setUp(self):
+        self.init_axis()
+        x = np.random.random((2, 3, 4, 5, 10)).astype("float32")
+        # Feed the raw (possibly negative) axis to the op so that the kernels'
+        # own negative-axis handling is exercised; numpy accepts it as is.
+        self.indices = np.argsort(x, kind='quicksort', axis=self.axis)
+        self.out = np.sort(x, kind='quicksort', axis=self.axis)
+        self.op_type = "argsort"
+        self.inputs = {'X': x}
+        self.attrs = {'axis': self.axis}
+        self.outputs = {'Indices': self.indices, 'Out': self.out}
+
+    def init_axis(self):
+        self.axis = -1
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestArgsortOpAxis0(TestArgsortOp):
+    def init_axis(self):
+        self.axis = 0
+
+
+class TestArgsortOpAxis1(TestArgsortOp):
+    def init_axis(self):
+        self.axis = 1
+
+
+class TestArgsortOpAxisNeg2(TestArgsortOp):
+    def init_axis(self):
+        self.axis = -2
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py
index 9d4b2d4434f3ec9cb62acd8b0e08dfea16279320..842d34c07e94a79e3351347e2528ecc478cc56dc 100644
--- a/python/paddle/fluid/tests/unittests/test_layers.py
+++ b/python/paddle/fluid/tests/unittests/test_layers.py
@@ -419,6 +419,15 @@ class TestBook(unittest.TestCase):
             self.assertIsNotNone(iou)
         print(str(program))
 
+    def test_argsort(self):
+        program = Program()
+        with program_guard(program):
+            data = layers.data(name='x', shape=[2, 3, 3], dtype="float32")
+            out, ids = layers.argsort(input=data, axis=1)
+            self.assertIsNotNone(out)
+            self.assertIsNotNone(ids)
+        print(str(program))
+
 
 if __name__ == '__main__':
     unittest.main()