diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 760f5ce58bf268919667db1c0623ce2096d2bdae..3ac43df872e377e96f49d6852744febde219d69d 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -5459,7 +5459,7 @@ def transpose(x, perm, name=None): Args: x (Variable): The input Tensor. It is a N-D Tensor of data types float32, float64, int32. - perm (list): Permute the input according to the data of perm. + perm (list|tuple): Permute the input according to the data of perm. name (str): The name of this layer. It is optional. Returns: @@ -5492,14 +5492,12 @@ def transpose(x, perm, name=None): .. code-block:: python - # use append_batch_size=False to avoid prepending extra - # batch size in shape - import paddle.fluid as fluid - x = fluid.layers.data(name='x', shape=[2, 3, 4], - dtype='float32', append_batch_size=False) - x_transposed = fluid.layers.transpose(x, perm=[1, 0, 2]) - print x_transposed.shape - #(3L, 2L, 4L) + import paddle + + x = paddle.randn([2, 3, 4]) + x_transposed = paddle.transpose(x, perm=[1, 0, 2]) + print(x_transposed.shape) + # [3L, 2L, 4L] """ if in_dygraph_mode(): @@ -5509,8 +5507,9 @@ def transpose(x, perm, name=None): check_variable_and_dtype( x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], 'transpose') - check_type(perm, 'perm', list, 'transpose') - + check_type(perm, 'perm', (list, tuple), 'transpose') + if isinstance(perm, tuple): + perm = list(perm) if len(perm) != len(x.shape): raise ValueError( "Input(perm) is the permutation of dimensions of Input(x), " diff --git a/python/paddle/fluid/tests/unittests/test_transpose_op.py b/python/paddle/fluid/tests/unittests/test_transpose_op.py index 56333211469db5705d04cc5ca253bf01679190a5..f72df8cbe4640941d014b310325a8bb56d8af65f 100644 --- a/python/paddle/fluid/tests/unittests/test_transpose_op.py +++ b/python/paddle/fluid/tests/unittests/test_transpose_op.py @@ -21,6 +21,7 @@ import paddle import paddle.fluid as fluid from paddle.fluid import Program, program_guard +paddle.enable_static() class TestTransposeOp(OpTest): def setUp(self): @@ -113,6 +114,7 @@ class TestCase9(TestTransposeOp): class TestTransposeOpError(unittest.TestCase): def test_errors(self): + paddle.enable_static() with program_guard(Program(), Program()): x = fluid.layers.data(name='x', shape=[10, 5, 3], dtype='float64') @@ -149,6 +151,39 @@ class TestTransposeOpError(unittest.TestCase): self.assertRaises(ValueError, test_each_elem_value_check) +class TestTransposeApi(unittest.TestCase): + def test_static_out(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(name='x', shape=[2, 3, 4], dtype='float32') + x_trans1 = paddle.transpose(x, perm=[1, 0, 2]) + x_trans2 = paddle.transpose(x, perm=(2, 1, 0)) + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + x_np = np.random.random([2, 3, 4]).astype("float32") + result1, result2 = exe.run(feed={"x": x_np}, fetch_list=[x_trans1, x_trans2]) + expected_result1 = np.transpose(x_np, [1, 0, 2]) + expected_result2 = np.transpose(x_np, (2, 1, 0)) + + np.testing.assert_array_equal(result1, expected_result1) + np.testing.assert_array_equal(result2, expected_result2) + + def test_dygraph_out(self): + # This is an old test before 2.0 API so we need to disable static + # to trigger dygraph + paddle.disable_static() + x = paddle.randn([2, 3, 4]) + x_trans1 = paddle.transpose(x, perm=[1, 0, 2]) + x_trans2 = paddle.transpose(x, perm=(2, 1, 0)) + x_np = x.numpy() + expected_result1 = np.transpose(x_np, [1, 0, 2]) + expected_result2 = np.transpose(x_np, (2, 1, 0)) + + np.testing.assert_array_equal(x_trans1.numpy(), expected_result1) + np.testing.assert_array_equal(x_trans2.numpy(), expected_result2) + # This is an old test before 2.0 API so we enable static again after + # dygraph test + paddle.enable_static() class TestTAPI(unittest.TestCase): def test_out(self):