From 94e72ea6e7eba2f89533225f57626cfed93c0155 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Tue, 12 Jun 2018 06:31:01 -0700 Subject: [PATCH] Support more negative axes in argsort_op --- paddle/fluid/operators/argsort_op.cc | 20 +++++++++++-------- paddle/fluid/operators/argsort_op.cu | 2 +- paddle/fluid/operators/argsort_op.h | 2 +- .../fluid/tests/unittests/test_argsort_op.py | 7 +++++++ 4 files changed, 21 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/argsort_op.cc b/paddle/fluid/operators/argsort_op.cc index 2943d409a..8a44fd12c 100644 --- a/paddle/fluid/operators/argsort_op.cc +++ b/paddle/fluid/operators/argsort_op.cc @@ -37,10 +37,10 @@ class ArgsortOp : public framework::OperatorWithKernel { "Attr(axis) %d of ArgsortOp is out of bounds for Input(X) " "dimension %d.", axis, num_dims); - PADDLE_ENFORCE(axis >= 0 || axis == -1, - "Attr(axis) %d of ArgsortOp must be nonnegative or equal to " - "-1.", - axis); + PADDLE_ENFORCE(in_dims.size() + axis >= 0, + "Attr(axis) %d of ArgsortOp plus the number of Input(X)'s " + "dimensions %d must be nonnegative.", + axis, in_dims.size()); ctx->SetOutputDim("Out", in_dims); ctx->SetOutputDim("Indices", in_dims); @@ -53,9 +53,12 @@ class ArgsortOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", "(Tensor) The input of Argsort op."); - AddOutput("Out", "(Tensor) The sorted tensor of Argsort op."); + AddOutput("Out", + "(Tensor) The sorted tensor of Argsort op, with the same " + "shape as Input(X)."); AddOutput("Indices", - "(Tensor) The indices of a tensor giving the sorted order."); + "(Tensor) The indices of a tensor giving the sorted order, with " + "the same shape as Input(X)."); AddComment(R"DOC( Argsort operator @@ -66,8 +69,9 @@ Output(Indices) gives the sorted order along the given axis Attr(axis). )DOC"); AddAttr("axis", - "(int, default -1) The axis along which to sort the tensor, " - "default -1, the last dimension.") + "(int, default -1) The axis along which to sort the tensor. " + "When axis < 0, the actual axis will be the |axis|'th " + "counting backwards. Default -1, the last dimension.") .SetDefault(-1); } }; diff --git a/paddle/fluid/operators/argsort_op.cu b/paddle/fluid/operators/argsort_op.cu index eac18ea3a..55ad4ce34 100644 --- a/paddle/fluid/operators/argsort_op.cu +++ b/paddle/fluid/operators/argsort_op.cu @@ -103,7 +103,7 @@ class ArgsortOpCUDAKernel : public framework::OpKernel { int axis = ctx.Attr("axis"); auto in_dims = input->dims(); - axis = (axis == -1) ? (in_dims.size() - 1) : axis; + axis = (axis < 0) ? (in_dims.size() + axis) : axis; const T* in_data = input->data(); T* out_data = output->mutable_data(ctx.GetPlace()); diff --git a/paddle/fluid/operators/argsort_op.h b/paddle/fluid/operators/argsort_op.h index 51d2b89f9..e13745c49 100644 --- a/paddle/fluid/operators/argsort_op.h +++ b/paddle/fluid/operators/argsort_op.h @@ -31,7 +31,7 @@ class ArgsortKernel : public framework::OpKernel { int axis = static_cast(ctx.Attr("axis")); auto in_dims = input->dims(); - axis = (axis == -1) ? (in_dims.size() - 1) : axis; + axis = (axis < 0) ? (in_dims.size() + axis) : axis; const T* in_data = input->data(); T* out_data = output->mutable_data(ctx.GetPlace()); diff --git a/python/paddle/fluid/tests/unittests/test_argsort_op.py b/python/paddle/fluid/tests/unittests/test_argsort_op.py index 6995621ba..1d0aa82a6 100644 --- a/python/paddle/fluid/tests/unittests/test_argsort_op.py +++ b/python/paddle/fluid/tests/unittests/test_argsort_op.py @@ -21,6 +21,8 @@ class TestArgsortOp(OpTest): def setUp(self): self.init_axis() x = np.random.random((2, 3, 4, 5)).astype("float32") + if self.axis < 0: + self.axis = self.axis + len(x.shape) self.indices = np.argsort(x, kind='quicksort', axis=self.axis) self.out = np.sort(x, kind='quicksort', axis=self.axis) self.op_type = "argsort" @@ -45,5 +47,10 @@ class TestArgsortOpAxis1(TestArgsortOp): self.axis = 1 +class TestArgsortOpAxisNeg2(TestArgsortOp): + def init_axis(self): + self.axis = -2 + + if __name__ == "__main__": unittest.main() -- GitLab