diff --git a/paddle/fluid/operators/argsort_op.cc b/paddle/fluid/operators/argsort_op.cc index 2943d409a2e42a01f3d6bbb28c8c89c2409c45f5..8a44fd12ce0eb61c87e6f5b244d8f19c88db49c9 100644 --- a/paddle/fluid/operators/argsort_op.cc +++ b/paddle/fluid/operators/argsort_op.cc @@ -37,10 +37,10 @@ class ArgsortOp : public framework::OperatorWithKernel { "Attr(axis) %d of ArgsortOp is out of bounds for Input(X) " "dimension %d.", axis, num_dims); - PADDLE_ENFORCE(axis >= 0 || axis == -1, - "Attr(axis) %d of ArgsortOp must be nonnegative or equal to " - "-1.", - axis); + PADDLE_ENFORCE(in_dims.size() + axis >= 0, + "Attr(axis) %d of ArgsortOp plus the number of Input(X)'s " + "dimensions %d must be nonnegative.", + axis, in_dims.size()); ctx->SetOutputDim("Out", in_dims); ctx->SetOutputDim("Indices", in_dims); @@ -53,9 +53,12 @@ class ArgsortOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", "(Tensor) The input of Argsort op."); - AddOutput("Out", "(Tensor) The sorted tensor of Argsort op."); + AddOutput("Out", + "(Tensor) The sorted tensor of Argsort op, with the same " + "shape as Input(X)."); AddOutput("Indices", - "(Tensor) The indices of a tensor giving the sorted order."); + "(Tensor) The indices of a tensor giving the sorted order, with " + "the same shape as Input(X)."); AddComment(R"DOC( Argsort operator @@ -66,8 +69,9 @@ Output(Indices) gives the sorted order along the given axis Attr(axis). )DOC"); AddAttr("axis", - "(int, default -1) The axis along which to sort the tensor, " - "default -1, the last dimension.") + "(int, default -1) The axis along which to sort the tensor. " + "When axis < 0, the actual axis will be the |axis|'th " + "counting backwards. Default -1, the last dimension.") .SetDefault(-1); } }; diff --git a/paddle/fluid/operators/argsort_op.cu b/paddle/fluid/operators/argsort_op.cu index eac18ea3a0350697c8e1a96b71cc4a9068b26be8..55ad4ce340d0ff73323aa9951e9c8f52326aae4a 100644 --- a/paddle/fluid/operators/argsort_op.cu +++ b/paddle/fluid/operators/argsort_op.cu @@ -103,7 +103,7 @@ class ArgsortOpCUDAKernel : public framework::OpKernel { int axis = ctx.Attr("axis"); auto in_dims = input->dims(); - axis = (axis == -1) ? (in_dims.size() - 1) : axis; + axis = (axis < 0) ? (in_dims.size() + axis) : axis; const T* in_data = input->data(); T* out_data = output->mutable_data(ctx.GetPlace()); diff --git a/paddle/fluid/operators/argsort_op.h b/paddle/fluid/operators/argsort_op.h index 51d2b89f94d5b52bbfef2108127f90475915b835..e13745c4941cd25d3242a1c8f2860e836d80f73c 100644 --- a/paddle/fluid/operators/argsort_op.h +++ b/paddle/fluid/operators/argsort_op.h @@ -31,7 +31,7 @@ class ArgsortKernel : public framework::OpKernel { int axis = static_cast(ctx.Attr("axis")); auto in_dims = input->dims(); - axis = (axis == -1) ? (in_dims.size() - 1) : axis; + axis = (axis < 0) ? (in_dims.size() + axis) : axis; const T* in_data = input->data(); T* out_data = output->mutable_data(ctx.GetPlace()); diff --git a/python/paddle/fluid/tests/unittests/test_argsort_op.py b/python/paddle/fluid/tests/unittests/test_argsort_op.py index 6995621ba8ca64ddb47dae7de07253af8997934e..1d0aa82a6b398592dc5b905be56591a1094b9e41 100644 --- a/python/paddle/fluid/tests/unittests/test_argsort_op.py +++ b/python/paddle/fluid/tests/unittests/test_argsort_op.py @@ -21,6 +21,8 @@ class TestArgsortOp(OpTest): def setUp(self): self.init_axis() x = np.random.random((2, 3, 4, 5)).astype("float32") + if self.axis < 0: + self.axis = self.axis + len(x.shape) self.indices = np.argsort(x, kind='quicksort', axis=self.axis) self.out = np.sort(x, kind='quicksort', axis=self.axis) self.op_type = "argsort" @@ -45,5 +47,10 @@ class TestArgsortOpAxis1(TestArgsortOp): self.axis = 1 +class TestArgsortOpAxisNeg2(TestArgsortOp): + def init_axis(self): + self.axis = -2 + + if __name__ == "__main__": unittest.main()