diff --git a/paddle/fluid/operators/argsort_op.cc b/paddle/fluid/operators/argsort_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..aead4e2e00f75710d13cad52d0cd31e37e16eb0c --- /dev/null +++ b/paddle/fluid/operators/argsort_op.cc @@ -0,0 +1,83 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/argsort_op.h" + +namespace paddle { +namespace operators { + +class ArgsortOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of ArgsortOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of ArgsortOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Indices"), + "Output(Indices) of ArgsortOp should not be null."); + + auto in_dims = ctx->GetInputDim("X"); + int axis = static_cast(ctx->Attrs().Get("axis")); + + auto num_dims = in_dims.size(); + PADDLE_ENFORCE(axis < num_dims, + "Attr(axis) %d of ArgsortOp is out of bounds for Input(X) " + "dimension %d.", + axis, num_dims); + PADDLE_ENFORCE(axis >= 0 || axis == -1, + "Attr(axis) %d of ArgsortOp must be nonnegative or equal to " + "-1.", + axis); + + ctx->SetOutputDim("Out", in_dims); + ctx->SetOutputDim("Indices", in_dims); + ctx->ShareLoD("X", "Out"); + ctx->ShareLoD("X", "Indices"); + } +}; + +class ArgsortOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor) The input of Argsort op."); + AddOutput("Out", "(Tensor) The sorted tensor of Argsort op."); + AddOutput("Indices", + "(Tensor) The indices of a tensor giving the sorted order."); + AddComment(R"DOC( +Argsort operator + +Performs sorting on the input tensor along the given axis and outputs two +tensors, Output(Out) and Output(Indices). They reserve the same shape +with Input(X), and Output(Out) represents the sorted tensor while +Output(Indices) gives the sorted order along the given axis Attr(axis). + + )DOC"); + AddAttr("axis", + "(int, default -1) The axis along which to sort the tensor, " + "default -1, the last dimension.") + .SetDefault(-1); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(argsort, ops::ArgsortOp, ops::ArgsortOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL(argsort, + ops::ArgsortKernel, + ops::ArgsortKernel); diff --git a/paddle/fluid/operators/argsort_op.h b/paddle/fluid/operators/argsort_op.h new file mode 100644 index 0000000000000000000000000000000000000000..a9fe22c4ce1649d19d380bd4ac7f7949c26d3865 --- /dev/null +++ b/paddle/fluid/operators/argsort_op.h @@ -0,0 +1,86 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +using EigenMatrix = framework::EigenMatrix; + +template +class ArgsortKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("X"); + auto* output = ctx.Output("Out"); + auto* indices = ctx.Output("Indices"); + int axis = static_cast(ctx.Attr("axis")); + + auto in_dims = input->dims(); + axis = (axis == -1) ? (in_dims.size() - 1) : axis; + + const T* in_data = input->data(); + T* out_data = output->mutable_data(ctx.GetPlace()); + int64_t* idx_data = indices->mutable_data(ctx.GetPlace()); + + int64_t part_dims_prod = input->numel() / in_dims[axis]; + for (int64_t i = 0; i < part_dims_prod; ++i) { + int64_t idx = i; + std::vector idx_vec(in_dims.size(), 0); + for (int64_t dim = in_dims.size() - 1; dim >= 0; --dim) { + if (dim != axis) { + idx_vec[dim] = idx % in_dims[dim]; + idx /= in_dims[dim]; + } + } + std::vector> in_vec; + std::vector org_index_vec(in_dims[axis], 0); + for (int64_t j = 0; j < in_dims[axis]; ++j) { + idx_vec[axis] = j; + int64_t index = idx_vec[0]; + for (int64_t dim = 0; dim < in_dims.size() - 1; ++dim) { + index = index * in_dims[dim + 1] + idx_vec[dim + 1]; + } + in_vec.push_back(std::pair(in_data[index], j)); + org_index_vec[j] = index; + } + + std::sort( + in_vec.begin(), in_vec.end(), + [](const std::pair& v1, const std::pair& v2) { + return v1.first < v2.first; + }); + + for (size_t j = 0; j < org_index_vec.size(); ++j) { + int64_t index = org_index_vec[j]; + out_data[index] = in_vec[j].first; + idx_data[index] = in_vec[j].second; + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_argsort_op.py b/python/paddle/fluid/tests/unittests/test_argsort_op.py new file mode 100644 index 0000000000000000000000000000000000000000..6995621ba8ca64ddb47dae7de07253af8997934e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_argsort_op.py @@ -0,0 +1,49 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_test import OpTest + + +class TestArgsortOp(OpTest): + def setUp(self): + self.init_axis() + x = np.random.random((2, 3, 4, 5)).astype("float32") + self.indices = np.argsort(x, kind='quicksort', axis=self.axis) + self.out = np.sort(x, kind='quicksort', axis=self.axis) + self.op_type = "argsort" + self.inputs = {'X': x} + self.attrs = {'axis': self.axis} + self.outputs = {'Indices': self.indices, 'Out': self.out} + + def init_axis(self): + self.axis = -1 + + def test_check_output(self): + self.check_output() + + +class TestArgsortOpAxis0(TestArgsortOp): + def init_axis(self): + self.axis = 0 + + +class TestArgsortOpAxis1(TestArgsortOp): + def init_axis(self): + self.axis = 1 + + +if __name__ == "__main__": + unittest.main()