diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index f0c955bb562b10afc726d1fa4060d81c05622451..bdbd4cf723bc207a2105ab4cb40e5ada6d71530b 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -8524,9 +8524,15 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None): "The type of 'x' in reshape must be Variable, but received %s." % (type(x))) - if convert_dtype(x.dtype) not in ['float32', 'float64', 'int32', 'int64']: + if convert_dtype(x.dtype) in ['float16']: + warnings.warn( + "The data type of 'x' in reshape only support float16 in GPU now.") + + if convert_dtype(x.dtype) not in [ + 'float16', 'float32', 'float64', 'int32', 'int64' + ]: raise TypeError( - "The data type of 'x' in reshape must be float32, float64, int32 or int64, " + "The data type of 'x' in reshape must be float16, float32, float64, int32 or int64, " "but received %s." % (convert_dtype(x.dtype))) if not isinstance(shape, (list, tuple, Variable)): diff --git a/python/paddle/fluid/tests/unittests/test_reshape_op.py b/python/paddle/fluid/tests/unittests/test_reshape_op.py index ea43b6b603bae54a7525a60ddc468f1d5f951bfe..931d6caac55e141cb72ab4910c93f84d4041328f 100644 --- a/python/paddle/fluid/tests/unittests/test_reshape_op.py +++ b/python/paddle/fluid/tests/unittests/test_reshape_op.py @@ -238,17 +238,27 @@ class TestReshapeOpError(OpTest): self.assertRaises(TypeError, test_x_type) - # The x dtype of reshape_op must be float32, float64, int32 or int64. + # The x dtype of reshape_op must be float16, float32, float64, int32 or int64. def test_x_dtype(): x2 = fluid.layers.data( name="x2", shape=[2, 25], append_batch_size=False, - dtype="float16") + dtype="bool") fluid.layers.reshape(x2, shape=[2, 5, 5]) self.assertRaises(TypeError, test_x_dtype) + def test_x_dtype_float16(): + x_float16 = fluid.layers.data( + name="x_float16", + shape=[2, 25], + append_batch_size=False, + dtype="float16") + fluid.layers.reshape(x_float16, shape=[2, 5, 5]) + + test_x_dtype_float16() + x3 = fluid.layers.data( name="x3", shape=[2, 25],