提交 94e72ea6 编写于 作者: Y Yibing Liu

Support more negative axes in argsort_op

上级 42645ff7
...@@ -37,10 +37,10 @@ class ArgsortOp : public framework::OperatorWithKernel { ...@@ -37,10 +37,10 @@ class ArgsortOp : public framework::OperatorWithKernel {
"Attr(axis) %d of ArgsortOp is out of bounds for Input(X) " "Attr(axis) %d of ArgsortOp is out of bounds for Input(X) "
"dimension %d.", "dimension %d.",
axis, num_dims); axis, num_dims);
PADDLE_ENFORCE(axis >= 0 || axis == -1, PADDLE_ENFORCE(in_dims.size() + axis >= 0,
"Attr(axis) %d of ArgsortOp must be nonnegative or equal to " "Attr(axis) %d of ArgsortOp plus the number of Input(X)'s "
"-1.", "dimensions %d must be nonnegative.",
axis); axis, in_dims.size());
ctx->SetOutputDim("Out", in_dims); ctx->SetOutputDim("Out", in_dims);
ctx->SetOutputDim("Indices", in_dims); ctx->SetOutputDim("Indices", in_dims);
...@@ -53,9 +53,12 @@ class ArgsortOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -53,9 +53,12 @@ class ArgsortOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("X", "(Tensor) The input of Argsort op."); AddInput("X", "(Tensor) The input of Argsort op.");
AddOutput("Out", "(Tensor) The sorted tensor of Argsort op."); AddOutput("Out",
"(Tensor) The sorted tensor of Argsort op, with the same "
"shape as Input(X).");
AddOutput("Indices", AddOutput("Indices",
"(Tensor) The indices of a tensor giving the sorted order."); "(Tensor) The indices of a tensor giving the sorted order, with "
"the same shape as Input(X).");
AddComment(R"DOC( AddComment(R"DOC(
Argsort operator Argsort operator
...@@ -66,8 +69,9 @@ Output(Indices) gives the sorted order along the given axis Attr(axis). ...@@ -66,8 +69,9 @@ Output(Indices) gives the sorted order along the given axis Attr(axis).
)DOC"); )DOC");
AddAttr<int>("axis", AddAttr<int>("axis",
"(int, default -1) The axis along which to sort the tensor, " "(int, default -1) The axis along which to sort the tensor. "
"default -1, the last dimension.") "When axis < 0, the actual axis will be the |axis|'th "
"counting backwards. Default -1, the last dimension.")
.SetDefault(-1); .SetDefault(-1);
} }
}; };
......
...@@ -103,7 +103,7 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> { ...@@ -103,7 +103,7 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
auto in_dims = input->dims(); auto in_dims = input->dims();
axis = (axis == -1) ? (in_dims.size() - 1) : axis; axis = (axis < 0) ? (in_dims.size() + axis) : axis;
const T* in_data = input->data<T>(); const T* in_data = input->data<T>();
T* out_data = output->mutable_data<T>(ctx.GetPlace()); T* out_data = output->mutable_data<T>(ctx.GetPlace());
......
...@@ -31,7 +31,7 @@ class ArgsortKernel : public framework::OpKernel<T> { ...@@ -31,7 +31,7 @@ class ArgsortKernel : public framework::OpKernel<T> {
int axis = static_cast<int>(ctx.Attr<int>("axis")); int axis = static_cast<int>(ctx.Attr<int>("axis"));
auto in_dims = input->dims(); auto in_dims = input->dims();
axis = (axis == -1) ? (in_dims.size() - 1) : axis; axis = (axis < 0) ? (in_dims.size() + axis) : axis;
const T* in_data = input->data<T>(); const T* in_data = input->data<T>();
T* out_data = output->mutable_data<T>(ctx.GetPlace()); T* out_data = output->mutable_data<T>(ctx.GetPlace());
......
...@@ -21,6 +21,8 @@ class TestArgsortOp(OpTest): ...@@ -21,6 +21,8 @@ class TestArgsortOp(OpTest):
def setUp(self): def setUp(self):
self.init_axis() self.init_axis()
x = np.random.random((2, 3, 4, 5)).astype("float32") x = np.random.random((2, 3, 4, 5)).astype("float32")
if self.axis < 0:
self.axis = self.axis + len(x.shape)
self.indices = np.argsort(x, kind='quicksort', axis=self.axis) self.indices = np.argsort(x, kind='quicksort', axis=self.axis)
self.out = np.sort(x, kind='quicksort', axis=self.axis) self.out = np.sort(x, kind='quicksort', axis=self.axis)
self.op_type = "argsort" self.op_type = "argsort"
...@@ -45,5 +47,10 @@ class TestArgsortOpAxis1(TestArgsortOp): ...@@ -45,5 +47,10 @@ class TestArgsortOpAxis1(TestArgsortOp):
self.axis = 1 self.axis = 1
class TestArgsortOpAxisNeg2(TestArgsortOp):
def init_axis(self):
self.axis = -2
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册