提交 94e72ea6 编写于 作者: Y Yibing Liu

Support more negative axes in argsort_op

上级 42645ff7
......@@ -37,10 +37,10 @@ class ArgsortOp : public framework::OperatorWithKernel {
"Attr(axis) %d of ArgsortOp is out of bounds for Input(X) "
"dimension %d.",
axis, num_dims);
PADDLE_ENFORCE(axis >= 0 || axis == -1,
"Attr(axis) %d of ArgsortOp must be nonnegative or equal to "
"-1.",
axis);
PADDLE_ENFORCE(in_dims.size() + axis >= 0,
"Attr(axis) %d of ArgsortOp plus the number of Input(X)'s "
"dimensions %d must be nonnegative.",
axis, in_dims.size());
ctx->SetOutputDim("Out", in_dims);
ctx->SetOutputDim("Indices", in_dims);
......@@ -53,9 +53,12 @@ class ArgsortOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) The input of Argsort op.");
AddOutput("Out", "(Tensor) The sorted tensor of Argsort op.");
AddOutput("Out",
"(Tensor) The sorted tensor of Argsort op, with the same "
"shape as Input(X).");
AddOutput("Indices",
"(Tensor) The indices of a tensor giving the sorted order.");
"(Tensor) The indices of a tensor giving the sorted order, with "
"the same shape as Input(X).");
AddComment(R"DOC(
Argsort operator
......@@ -66,8 +69,9 @@ Output(Indices) gives the sorted order along the given axis Attr(axis).
)DOC");
AddAttr<int>("axis",
"(int, default -1) The axis along which to sort the tensor, "
"default -1, the last dimension.")
"(int, default -1) The axis along which to sort the tensor. "
"When axis < 0, the actual axis will be the |axis|'th "
"counting backwards. Default -1, the last dimension.")
.SetDefault(-1);
}
};
......
......@@ -103,7 +103,7 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
int axis = ctx.Attr<int>("axis");
auto in_dims = input->dims();
axis = (axis == -1) ? (in_dims.size() - 1) : axis;
axis = (axis < 0) ? (in_dims.size() + axis) : axis;
const T* in_data = input->data<T>();
T* out_data = output->mutable_data<T>(ctx.GetPlace());
......
......@@ -31,7 +31,7 @@ class ArgsortKernel : public framework::OpKernel<T> {
int axis = static_cast<int>(ctx.Attr<int>("axis"));
auto in_dims = input->dims();
axis = (axis == -1) ? (in_dims.size() - 1) : axis;
axis = (axis < 0) ? (in_dims.size() + axis) : axis;
const T* in_data = input->data<T>();
T* out_data = output->mutable_data<T>(ctx.GetPlace());
......
......@@ -21,6 +21,8 @@ class TestArgsortOp(OpTest):
def setUp(self):
self.init_axis()
x = np.random.random((2, 3, 4, 5)).astype("float32")
if self.axis < 0:
self.axis = self.axis + len(x.shape)
self.indices = np.argsort(x, kind='quicksort', axis=self.axis)
self.out = np.sort(x, kind='quicksort', axis=self.axis)
self.op_type = "argsort"
......@@ -45,5 +47,10 @@ class TestArgsortOpAxis1(TestArgsortOp):
self.axis = 1
class TestArgsortOpAxisNeg2(TestArgsortOp):
def init_axis(self):
self.axis = -2
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册