diff --git a/python/paddle/fluid/tests/unittests/test_tile_op.py b/python/paddle/fluid/tests/unittests/test_tile_op.py index 52562a6eb6f99596935069f83f793802bae6b4b8..b95e8c4af7fbc04cca01675872dca74fbd12266c 100644 --- a/python/paddle/fluid/tests/unittests/test_tile_op.py +++ b/python/paddle/fluid/tests/unittests/test_tile_op.py @@ -366,6 +366,20 @@ class TestTileAPI_ZeroDim(unittest.TestCase): paddle.enable_static() +class Testfp16TileOp(unittest.TestCase): + def testfp16(self): + input_x = (np.random.random([1, 2, 3])).astype('float16') + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(name="x", shape=[1, 2, 3], dtype='float16') + repeat_times = [2, 2] + out = paddle.tile(x, repeat_times=repeat_times) + if paddle.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + exe = paddle.static.Executor(place) + exe.run(paddle.static.default_startup_program()) + out = exe.run(feed={'x': input_x}, fetch_list=[out]) + + if __name__ == "__main__": paddle.enable_static() unittest.main() diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 2b4caeff7e56c2615c2b3e03ac07cdceca43b003..91fa35c05981443cdfaa79ce0ecdcaaa3a2f136e 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -3084,7 +3084,7 @@ def tile(x, repeat_times, name=None): Both the number of dimensions of ``x`` and the number of elements in ``repeat_times`` should be less than or equal to 6. Args: - x (Tensor): The input tensor, its data type should be bool, float32, float64, int32 or int64. + x (Tensor): The input tensor, its data type should be bool, float16, float32, float64, int32 or int64. repeat_times (list|tuple|Tensor): The number of repeating times. If repeat_times is a list or tuple, all its elements should be integers or 1-D Tensors with the data type int32. If repeat_times is a Tensor, it should be an 1-D Tensor with the data type int32. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. @@ -3145,7 +3145,10 @@ def tile(x, repeat_times, name=None): ), 'Elements in repeat_times must be 1-D Tensors or integers.' check_variable_and_dtype( - x, 'x', ['bool', 'float32', 'float64', 'int32', 'int64'], 'tile' + x, + 'x', + ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], + 'tile', ) if convert_dtype(x.dtype) == 'bool' and not x.stop_gradient: raise ValueError(