From 7ffbf7e337065a4907dd7459b1356db2ce5edf29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=98=A5=E4=B9=94?= <83450930+Liyulingyue@users.noreply.github.com> Date: Mon, 27 Feb 2023 15:41:06 +0800 Subject: [PATCH] suppot fp16 in flatten (#50906) --- .../fluid/tests/unittests/test_flatten_op.py | 28 +++++++++++++++++++ python/paddle/tensor/manipulation.py | 2 +- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_flatten_op.py b/python/paddle/fluid/tests/unittests/test_flatten_op.py index 7753f2d90e..f1f049c588 100644 --- a/python/paddle/fluid/tests/unittests/test_flatten_op.py +++ b/python/paddle/fluid/tests/unittests/test_flatten_op.py @@ -17,6 +17,8 @@ import unittest import numpy as np from op_test import OpTest +import paddle + class TestFlattenOp(OpTest): def setUp(self): @@ -64,5 +66,31 @@ class TestFlattenOpSixDims(TestFlattenOp): self.new_shape = (36, 16) +class TestFlattenOpFP16(unittest.TestCase): + def test_fp16_with_gpu(self): + if paddle.fluid.core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + input = np.random.random([12, 14]).astype("float16") + x = paddle.static.data( + name="x", shape=[12, 14], dtype="float16" + ) + + y = paddle.flatten(x) + + exe = paddle.static.Executor(place) + res = exe.run( + paddle.static.default_main_program(), + feed={ + "x": input, + }, + fetch_list=[y], + ) + + assert np.array_equal(res[0].shape, [12 * 14]) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 8f257725e8..7d82a83e2a 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1510,7 +1510,7 @@ def flatten(x, start_axis=0, stop_axis=-1, name=None): Out.shape = (3 * 100 * 100 * 4) Args: - x (Tensor): A tensor of number of dimentions >= axis. A tensor with data type float32, + x (Tensor): A tensor of number of dimentions >= axis. A tensor with data type float16, float32, float64, int8, int32, int64, uint8. start_axis (int): the start axis to flatten stop_axis (int): the stop axis to flatten -- GitLab