Unverified commit dc42e3c4, authored by: W wawltor, committed by: GitHub

Fix the argsort and sort APIs for Paddle API 2.0 (#25514)

Fix the argsort and sort ops for API 2.0, and update the API.
Parent 42189be6
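As a quick orientation (not part of the diff itself), here is a minimal sketch of how the 2.0-style calls touched by this commit are meant to be used, assuming the paddle.imperative dygraph helpers that the new tests below rely on; the array values are illustrative only:

    import numpy as np
    import paddle
    import paddle.imperative as imperative

    # Dygraph ("imperative") mode, as used by the new tests in this commit.
    paddle.enable_imperative()

    data = np.array([[3.0, 1.0, 2.0],
                     [9.0, 7.0, 8.0]]).astype(np.float32)
    x = imperative.to_variable(data)

    # paddle.argsort now takes the tensor as `x` and returns only the index tensor.
    ids = paddle.argsort(x, axis=-1, descending=False)

    # paddle.sort no longer has an `out=` parameter; it returns a
    # (sorted values, indices) pair.
    values, indices = paddle.sort(x, axis=-1)

    print(ids.numpy())     # index order of each row, e.g. [1 2 0] for [3. 1. 2.]
    print(values.numpy())  # each row sorted ascending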
@@ -17,6 +17,7 @@ from __future__ import print_function
import unittest
import paddle
import paddle.fluid as fluid
+import paddle.imperative as imperative
import paddle.fluid.layers as layers
import numpy as np
import six
@@ -321,58 +322,83 @@ class TestArgsortOpDescendingAxisNeg2GPU(TestArgsortOpAxisNeg2GPU):
        self.descending = True


-class TestSortOnCPU(TestArgsortOpCPU):
-    def init_place(self):
-        self.place = core.CPUPlace()
-
-    def test_out(self):
-        self.init_place()
-        with fluid.program_guard(fluid.Program()):
-            input = fluid.data(name="input", shape=[2, 3, 4], dtype="float32")
-            res = fluid.data(name="output", shape=[2, 3, 4], dtype="float32")
-            output = paddle.tensor.sort(input=input, out=res)
-            exe = fluid.Executor(self.place)
-            data = np.array(
-                [[[5, 8, 9, 5], [0, 0, 1, 7], [6, 9, 2, 4]],
-                 [[5, 2, 4, 2], [4, 7, 7, 9], [1, 7, 0, 6]]],
-                dtype='float32')
-            result = exe.run(feed={'input': data}, fetch_list=[res, output[0]])
-            self.assertEqual((result[0] == result[1]).all(), True)
-
-
-class TestSortOnGPU(TestSortOnCPU):
-    def init_place(self):
-        if core.is_compiled_with_cuda():
-            self.place = core.CUDAPlace(0)
-        else:
-            self.place = core.CPUPlace()
-
-
-class TestArgsortErrorOnCPU(unittest.TestCase):
-    def init_place(self):
-        self.place = core.CPUPlace()
-
-    def test_error(self):
-        self.init_place()
-        with fluid.program_guard(fluid.Program()):
-
-            def test_input_type():
-                x = [1]
-                output = fluid.layers.argsort(input=x)
-
-            self.assertRaises(TypeError, test_input_type)
-
-
-class TestArgsortErrorOnGPU(TestArgsortErrorOnCPU):
-    def init_place(self):
-        if core.is_compiled_with_cuda():
-            self.place = core.CUDAPlace(0)
-        else:
-            self.place = core.CPUPlace()
+class TestArgsortErrorOnCPU(unittest.TestCase):
+    def setUp(self):
+        self.place = core.CPUPlace()
+
+    def test_error(self):
+        def test_fluid_var_type():
+            with fluid.program_guard(fluid.Program()):
+                x = [1]
+                output = fluid.layers.argsort(input=x)
+
+        self.assertRaises(TypeError, test_fluid_var_type)
+
+        def test_paddle_var_type():
+            with fluid.program_guard(fluid.Program()):
+                x = [1]
+                output = paddle.argsort(input=x)
+
+        self.assertRaises(TypeError, test_paddle_var_type)
+
+
+class TestArgsortErrorOnGPU(TestArgsortErrorOnCPU):
+    def setUp(self):
+        if core.is_compiled_with_cuda():
+            self.place = core.CUDAPlace(0)
+        else:
+            self.place = core.CPUPlace()
+
+
+class TestArgsort(unittest.TestCase):
+    def setUp(self):
+        if core.is_compiled_with_cuda():
+            self.place = core.CUDAPlace(0)
+        else:
+            self.place = core.CPUPlace()
+        self.data = np.random.rand(2, 3, 4).astype("float32")
+
+    def test_api_0(self):
+        with fluid.program_guard(fluid.Program()):
+            input = fluid.data(name="input", shape=[2, 3, 4], dtype="float32")
+            output = paddle.argsort(x=input)
+            exe = fluid.Executor(self.place)
+            result, = exe.run(feed={'input': self.data}, fetch_list=[output])
+            np_result = np.argsort(self.data)
+            self.assertEqual((result == np_result).all(), True)
+
+    def test_api_1(self):
+        with fluid.program_guard(fluid.Program()):
+            input = fluid.data(name="input", shape=[2, 3, 4], dtype="float32")
+            output = paddle.argsort(x=input, axis=1)
+            exe = fluid.Executor(self.place)
+            result, = exe.run(feed={'input': self.data}, fetch_list=[output])
+            np_result = np.argsort(self.data, axis=1)
+            self.assertEqual((result == np_result).all(), True)
+
+
+class TestArgsortDygraph(unittest.TestCase):
+    def setUp(self):
+        self.input_data = np.random.rand(10, 10)
+        if core.is_compiled_with_cuda():
+            self.place = core.CUDAPlace(0)
+        else:
+            self.place = core.CPUPlace()
+
+    def test_api_0(self):
+        with imperative.guard(self.place):
+            var_x = imperative.to_variable(self.input_data)
+            out = paddle.argsort(var_x)
+            self.assertEqual((np.argsort(self.input_data) == out.numpy()).all(),
+                             True)
+
+    def test_api_1(self):
+        with imperative.guard(self.place):
+            var_x = imperative.to_variable(self.input_data)
+            out = paddle.argsort(var_x, axis=-1)
+            self.assertEqual(
+                (np.argsort(
+                    self.input_data, axis=-1) == out.numpy()).all(),
+                True)


if __name__ == "__main__":
    unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import paddle
import paddle.fluid as fluid
import paddle.imperative as imperative
import paddle.fluid.layers as layers
import numpy as np
import six
import paddle.fluid.core as core


class TestSortOnCPU(unittest.TestCase):
    def setUp(self):
        self.place = core.CPUPlace()

    def test_api_0(self):
        with fluid.program_guard(fluid.Program()):
            input = fluid.data(name="input", shape=[2, 3, 4], dtype="float32")
            output = paddle.sort(x=input)
            exe = fluid.Executor(self.place)
            data = np.array(
                [[[5, 8, 9, 5], [0, 0, 1, 7], [6, 9, 2, 4]],
                 [[5, 2, 4, 2], [4, 7, 7, 9], [1, 7, 0, 6]]],
                dtype='float32')
            result, = exe.run(feed={'input': data}, fetch_list=[output[0]])
            np_result = np.sort(result)
            self.assertEqual((result == np_result).all(), True)

    def test_api_1(self):
        with fluid.program_guard(fluid.Program()):
            input = fluid.data(name="input", shape=[2, 3, 4], dtype="float32")
            output = paddle.sort(x=input, axis=1)
            exe = fluid.Executor(self.place)
            data = np.array(
                [[[5, 8, 9, 5], [0, 0, 1, 7], [6, 9, 2, 4]],
                 [[5, 2, 4, 2], [4, 7, 7, 9], [1, 7, 0, 6]]],
                dtype='float32')
            result, = exe.run(feed={'input': data}, fetch_list=[output[0]])
            np_result = np.sort(result, axis=1)
            self.assertEqual((result == np_result).all(), True)


class TestSortOnGPU(TestSortOnCPU):
    def init_place(self):
        if core.is_compiled_with_cuda():
            self.place = core.CUDAPlace(0)
        else:
            self.place = core.CPUPlace()


class TestSortDygraph(unittest.TestCase):
    def setUp(self):
        self.input_data = np.random.rand(10, 10)
        if core.is_compiled_with_cuda():
            self.place = core.CUDAPlace(0)
        else:
            self.place = core.CPUPlace()

    def test_api_0(self):
        with imperative.guard(self.place):
            var_x = imperative.to_variable(self.input_data)
            out = paddle.sort(var_x)
            self.assertEqual((np.sort(self.input_data) == out[0].numpy()).all(),
                             True)

    def test_api_1(self):
        with imperative.guard(self.place):
            var_x = imperative.to_variable(self.input_data)
            out = paddle.sort(var_x, axis=-1)
            self.assertEqual(
                (np.sort(
                    self.input_data, axis=-1) == out[0].numpy()).all(),
                True)
@@ -19,7 +19,6 @@ from ..fluid import core, layers
# TODO: define searching & indexing functions of a tensor
from ..fluid.layers import argmin #DEFINE_ALIAS
-from ..fluid.layers import argsort #DEFINE_ALIAS
from ..fluid.layers import has_inf #DEFINE_ALIAS
from ..fluid.layers import has_nan #DEFINE_ALIAS
from ..fluid.layers import topk #DEFINE_ALIAS
@@ -42,6 +41,92 @@ __all__ = [
from paddle.common_ops_import import *

def argsort(x, axis=-1, descending=False, name=None):
    """
    :alias_main: paddle.argsort
    :alias: paddle.argsort,paddle.tensor.argsort,paddle.tensor.search.argsort

    This OP sorts the input along the given axis, and returns the index
    Tensor of the sorted values, with the same shape as ``x``.

    Args:
        x(Tensor): An input N-D Tensor with type float32, float64, int16,
            int32, int64, uint8.
        axis(int, optional): Axis to compute indices along. The effective range
            is [-R, R), where R is Rank(x). When axis < 0, it works the same way
            as axis + R. Default is -1.
        descending(bool, optional): Descending is a flag; if set to True, the
            algorithm will sort in descending order, otherwise in ascending
            order. Default is False.
        name(str, optional): The default value is None. Normally there is no
            need for user to set this property. For more information, please
            refer to :ref:`api_guide_Name`.

    Returns:
        Tensor: the sorted indices (with the same shape as ``x``
        and with data type int64).

    Examples:
        .. code-block:: python

            import paddle
            import paddle.imperative as imperative
            import numpy as np

            paddle.enable_imperative()
            input_array = np.array([[[5,8,9,5],
                                     [0,0,1,7],
                                     [6,9,2,4]],
                                    [[5,2,4,2],
                                     [4,7,7,9],
                                     [1,7,0,6]]]).astype(np.float32)
            x = imperative.to_variable(input_array)
            out1 = paddle.argsort(x=x, axis=-1)
            out2 = paddle.argsort(x=x, axis=0)
            out3 = paddle.argsort(x=x, axis=1)
            print(out1.numpy())
            #[[[0 3 1 2]
            #  [0 1 2 3]
            #  [2 3 0 1]]
            # [[1 3 2 0]
            #  [0 1 2 3]
            #  [2 0 3 1]]]
            print(out2.numpy())
            #[[[0 1 1 1]
            #  [0 0 0 0]
            #  [1 1 1 0]]
            # [[1 0 0 0]
            #  [1 1 1 1]
            #  [0 0 0 1]]]
            print(out3.numpy())
            #[[[1 1 1 2]
            #  [0 0 2 0]
            #  [2 2 0 1]]
            # [[2 0 2 0]
            #  [1 1 0 2]
            #  [0 2 1 1]]]
    """
    if in_dygraph_mode():
        _, ids = core.ops.argsort(x, 'axis', axis, 'descending', descending)
        return ids
    check_variable_and_dtype(
        x, 'x', ['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'],
        'argsort')
    helper = LayerHelper("argsort", **locals())
    out = helper.create_variable_for_type_inference(
        dtype=x.dtype, stop_gradient=True)
    ids = helper.create_variable_for_type_inference(
        VarDesc.VarType.INT64, stop_gradient=True)
    helper.append_op(
        type='argsort',
        inputs={'X': x},
        outputs={'Out': out,
                 'Indices': ids},
        attrs={'axis': axis,
               'descending': descending})
    return ids
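A small usage sketch for the argsort added above (not part of the diff; it assumes the imperative helpers shown in the docstring, and the sample array avoids ties so the comparison with NumPy is well defined):

    import numpy as np
    import paddle
    import paddle.imperative as imperative

    paddle.enable_imperative()
    arr = np.array([[5.0, 8.0, 9.0, 4.0],
                    [0.0, 3.0, 1.0, 7.0]]).astype(np.float32)
    x = imperative.to_variable(arr)

    # Ascending (default) and descending index orders along the last axis.
    asc_ids = paddle.argsort(x)
    desc_ids = paddle.argsort(x, descending=True)

    # For distinct values the ascending indices match NumPy's argsort,
    # which is exactly what the new TestArgsortDygraph cases assert.
    assert (asc_ids.numpy() == np.argsort(arr, axis=-1)).all()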

def argmax(input, axis=None, dtype=None, out=None, keepdims=False, name=None):
    """
    :alias_main: paddle.argmax
@@ -291,19 +376,16 @@ def nonzero(input, as_tuple=False):
        return tuple(list_out)

-def sort(input, axis=-1, descending=False, out=None, name=None):
+def sort(x, axis=-1, descending=False, name=None):
    """
    :alias_main: paddle.sort
    :alias: paddle.sort,paddle.tensor.sort,paddle.tensor.search.sort

    This OP sorts the input along the given axis, and returns sorted output
-    data Varibale and its corresponding index Variable with the same shape as
-    :attr:`input`.
-
-    **NOTICE**: The Variable in the output of this OP has gradient. You could\
-    set Variable :attr:`stop_gradient`.
+    data Tensor and its corresponding index Tensor with the same shape as ``x``.

    Args:
-        input(Variable): An input N-D Tensor with type float32, float64, int16,
+        x(Tensor): An input N-D Tensor with type float32, float64, int16,
            int32, int64, uint8.
        axis(int, optional): Axis to compute indices along. The effective range
            is [-R, R), where R is Rank(x). when axis<0, it works the same way
@@ -311,71 +393,70 @@ def sort(input, axis=-1, descending=False, out=None, name=None):
        descending(bool, optional) : Descending is a flag, if set to true,
            algorithm will sort by descending order, else sort by
            ascending order. Default is false.
-        out(Variable, optional): The default value is None. Optional output
-            which can be any created Variable that meets the requirements to
-            store the result of operation. if out is None, a new Varibale will
-            be create to store the result.
        name(str, optional): The default value is None. Normally there is no
            need for user to set this property. For more information, please
            refer to :ref:`api_guide_Name`.

    Returns:
-        tuple: A tuple of sorted data Variable(with the same shape and data
-        type as input) and the sorted indices(with the same shape as input's
+        tuple: A tuple of sorted data tensor(with the same shape and data
+        type as ``x``) and the sorted indices(with the same shape as ``x``
        and with data type int64).

    Examples:
        .. code-block:: python

            import paddle
-            import paddle.fluid as fluid
+            import paddle.imperative as imperative
            import numpy as np

-            in1 = np.array([[[5,8,9,5],
+            paddle.enable_imperative()
+            input_array = np.array([[[5,8,9,5],
                                     [0,0,1,7],
                                     [6,9,2,4]],
                                    [[5,2,4,2],
                                     [4,7,7,9],
                                     [1,7,0,6]]]).astype(np.float32)
-            with fluid.dygraph.guard():
-                x = fluid.dygraph.to_variable(in1)
-                out1 = paddle.sort(input=x, axis=-1)
-                out2 = paddle.sort(input=x, axis=0)
-                out3 = paddle.sort(input=x, axis=1)
+            x = imperative.to_variable(input_array)
+            out1 = paddle.sort(x=x, axis=-1)
+            out2 = paddle.sort(x=x, axis=0)
+            out3 = paddle.sort(x=x, axis=1)
            print(out1[0].numpy())
            #[[[5. 5. 8. 9.]
            #  [0. 0. 1. 7.]
            #  [2. 4. 6. 9.]]
            # [[2. 2. 4. 5.]
            #  [4. 7. 7. 9.]
            #  [0. 1. 6. 7.]]]
            print(out1[1].numpy())
            #[[[0 3 1 2]
            #  [0 1 2 3]
            #  [2 3 0 1]]
            # [[1 3 2 0]
            #  [0 1 2 3]
            #  [2 0 3 1]]]
            print(out2[0].numpy())
            #[[[5. 2. 4. 2.]
            #  [0. 0. 1. 7.]
            #  [1. 7. 0. 4.]]
            # [[5. 8. 9. 5.]
            #  [4. 7. 7. 9.]
            #  [6. 9. 2. 6.]]]
            print(out3[0].numpy())
            #[[[0. 0. 1. 4.]
            #  [5. 8. 2. 5.]
            #  [6. 9. 9. 7.]]
            # [[1. 2. 0. 2.]
            #  [4. 7. 4. 6.]
            #  [5. 7. 7. 9.]]]
    """
+    if in_dygraph_mode():
+        out, ids = core.ops.argsort(x, 'axis', axis, 'descending', descending)
+        return out, ids
    helper = LayerHelper("sort", **locals())
-    if out is None:
-        out = helper.create_variable_for_type_inference(
-            dtype=input.dtype, stop_gradient=False)
+    out = helper.create_variable_for_type_inference(
+        dtype=x.dtype, stop_gradient=False)
    ids = helper.create_variable_for_type_inference(
        VarDesc.VarType.INT64, stop_gradient=True)
    helper.append_op(
        type='argsort',
-        inputs={'X': input},
+        inputs={'X': x},
        outputs={'Out': out,
                 'Indices': ids},
        attrs={'axis': axis,
......
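Finally, a migration sketch for the sort() signature change above (not part of the diff; the old call in the comment is the removed 1.x-style form, and the values are illustrative):

    import numpy as np
    import paddle
    import paddle.imperative as imperative

    paddle.enable_imperative()
    x = imperative.to_variable(
        np.array([[4.0, 2.0, 8.0], [1.0, 9.0, 3.0]]).astype(np.float32))

    # Old (removed): pre-create a variable and pass paddle.sort(input=x, out=...)
    # New: the sorted values and their int64 indices come back as a tuple.
    sorted_vals, sorted_ids = paddle.sort(x, axis=1, descending=True)

    print(sorted_vals.numpy())  # each row in descending order
    print(sorted_ids.numpy())   # positions of those values in the original rows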