未验证 提交 c5c273c1 编写于 作者: H Huihuang Zheng 提交者: GitHub

[Dy2stat] Fix Using Tuple for Transpose in Dy2stat (#28574)

PaddleSeg uses tuple as parameter of transpose in dygraph code:
https://github.com/PaddlePaddle/PaddleSeg/blob/release/v0.7.0/dygraph/paddleseg/models/danet.py#L152

However, in dy2stat, static code doesn't support the perm as a tuple. This PR fixed it.
上级 2b1e7e5b
...@@ -5459,7 +5459,7 @@ def transpose(x, perm, name=None): ...@@ -5459,7 +5459,7 @@ def transpose(x, perm, name=None):
Args: Args:
x (Variable): The input Tensor. It is a N-D Tensor of data types float32, float64, int32. x (Variable): The input Tensor. It is a N-D Tensor of data types float32, float64, int32.
perm (list): Permute the input according to the data of perm. perm (list|tuple): Permute the input according to the data of perm.
name (str): The name of this layer. It is optional. name (str): The name of this layer. It is optional.
Returns: Returns:
...@@ -5492,14 +5492,12 @@ def transpose(x, perm, name=None): ...@@ -5492,14 +5492,12 @@ def transpose(x, perm, name=None):
.. code-block:: python .. code-block:: python
# use append_batch_size=False to avoid prepending extra import paddle
# batch size in shape
import paddle.fluid as fluid x = paddle.randn([2, 3, 4])
x = fluid.layers.data(name='x', shape=[2, 3, 4], x_transposed = paddle.transpose(x, perm=[1, 0, 2])
dtype='float32', append_batch_size=False) print(x_transposed.shape)
x_transposed = fluid.layers.transpose(x, perm=[1, 0, 2]) # [3L, 2L, 4L]
print x_transposed.shape
#(3L, 2L, 4L)
""" """
if in_dygraph_mode(): if in_dygraph_mode():
...@@ -5509,8 +5507,9 @@ def transpose(x, perm, name=None): ...@@ -5509,8 +5507,9 @@ def transpose(x, perm, name=None):
check_variable_and_dtype( check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'],
'transpose') 'transpose')
check_type(perm, 'perm', list, 'transpose') check_type(perm, 'perm', (list, tuple), 'transpose')
if isinstance(perm, tuple):
perm = list(perm)
if len(perm) != len(x.shape): if len(perm) != len(x.shape):
raise ValueError( raise ValueError(
"Input(perm) is the permutation of dimensions of Input(x), " "Input(perm) is the permutation of dimensions of Input(x), "
......
...@@ -21,6 +21,7 @@ import paddle ...@@ -21,6 +21,7 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import Program, program_guard from paddle.fluid import Program, program_guard
paddle.enable_static()
class TestTransposeOp(OpTest): class TestTransposeOp(OpTest):
def setUp(self): def setUp(self):
...@@ -113,6 +114,7 @@ class TestCase9(TestTransposeOp): ...@@ -113,6 +114,7 @@ class TestCase9(TestTransposeOp):
class TestTransposeOpError(unittest.TestCase): class TestTransposeOpError(unittest.TestCase):
def test_errors(self): def test_errors(self):
paddle.enable_static()
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[10, 5, 3], dtype='float64') x = fluid.layers.data(name='x', shape=[10, 5, 3], dtype='float64')
...@@ -149,6 +151,39 @@ class TestTransposeOpError(unittest.TestCase): ...@@ -149,6 +151,39 @@ class TestTransposeOpError(unittest.TestCase):
self.assertRaises(ValueError, test_each_elem_value_check) self.assertRaises(ValueError, test_each_elem_value_check)
class TestTransposeApi(unittest.TestCase):
def test_static_out(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data(name='x', shape=[2, 3, 4], dtype='float32')
x_trans1 = paddle.transpose(x, perm=[1, 0, 2])
x_trans2 = paddle.transpose(x, perm=(2, 1, 0))
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
x_np = np.random.random([2, 3, 4]).astype("float32")
result1, result2 = exe.run(feed={"x": x_np}, fetch_list=[x_trans1, x_trans2])
expected_result1 = np.transpose(x_np, [1, 0, 2])
expected_result2 = np.transpose(x_np, (2, 1, 0))
np.testing.assert_array_equal(result1, expected_result1)
np.testing.assert_array_equal(result2, expected_result2)
def test_dygraph_out(self):
# This is an old test before 2.0 API so we need to disable static
# to trigger dygraph
paddle.disable_static()
x = paddle.randn([2, 3, 4])
x_trans1 = paddle.transpose(x, perm=[1, 0, 2])
x_trans2 = paddle.transpose(x, perm=(2, 1, 0))
x_np = x.numpy()
expected_result1 = np.transpose(x_np, [1, 0, 2])
expected_result2 = np.transpose(x_np, (2, 1, 0))
np.testing.assert_array_equal(x_trans1.numpy(), expected_result1)
np.testing.assert_array_equal(x_trans2.numpy(), expected_result2)
# This is an old test before 2.0 API so we enable static again after
# dygraph test
paddle.enable_static()
class TestTAPI(unittest.TestCase): class TestTAPI(unittest.TestCase):
def test_out(self): def test_out(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册