From 6cff0accda945b042aeb484047f47cc9c855fd6b Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Sun, 13 Oct 2019 10:58:57 +0800 Subject: [PATCH] [cherry-pick]fix reshape input(x) error check on float16. test=release/1.6 (#20529) (#20551) --- python/paddle/fluid/layers/nn.py | 10 ++++++++-- .../fluid/tests/unittests/test_reshape_op.py | 14 ++++++++++++-- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 3387a66bff6..095425b02f0 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -8585,9 +8585,15 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None): "The type of 'x' in reshape must be Variable, but received %s." % (type(x))) - if convert_dtype(x.dtype) not in ['float32', 'float64', 'int32', 'int64']: + if convert_dtype(x.dtype) in ['float16']: + warnings.warn( + "The data type of 'x' in reshape only support float16 in GPU now.") + + if convert_dtype(x.dtype) not in [ + 'float16', 'float32', 'float64', 'int32', 'int64' + ]: raise TypeError( - "The data type of 'x' in reshape must be float32, float64, int32 or int64, " + "The data type of 'x' in reshape must be float16, float32, float64, int32 or int64, " "but received %s." % (convert_dtype(x.dtype))) if not isinstance(shape, (list, tuple, Variable)): diff --git a/python/paddle/fluid/tests/unittests/test_reshape_op.py b/python/paddle/fluid/tests/unittests/test_reshape_op.py index ea43b6b603b..931d6caac55 100644 --- a/python/paddle/fluid/tests/unittests/test_reshape_op.py +++ b/python/paddle/fluid/tests/unittests/test_reshape_op.py @@ -238,17 +238,27 @@ class TestReshapeOpError(OpTest): self.assertRaises(TypeError, test_x_type) - # The x dtype of reshape_op must be float32, float64, int32 or int64. + # The x dtype of reshape_op must be float16, float32, float64, int32 or int64. def test_x_dtype(): x2 = fluid.layers.data( name="x2", shape=[2, 25], append_batch_size=False, - dtype="float16") + dtype="bool") fluid.layers.reshape(x2, shape=[2, 5, 5]) self.assertRaises(TypeError, test_x_dtype) + def test_x_dtype_float16(): + x_float16 = fluid.layers.data( + name="x_float16", + shape=[2, 25], + append_batch_size=False, + dtype="float16") + fluid.layers.reshape(x_float16, shape=[2, 5, 5]) + + test_x_dtype_float16() + x3 = fluid.layers.data( name="x3", shape=[2, 25], -- GitLab