提交 6cff0acc 编写于 作者: L liym27 提交者: Aurelius84

[cherry-pick]fix reshape input(x) error check on float16. test=release/1.6 (#20529) (#20551)

上级 4d36c1c7
......@@ -8585,9 +8585,15 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
"The type of 'x' in reshape must be Variable, but received %s." %
(type(x)))
if convert_dtype(x.dtype) not in ['float32', 'float64', 'int32', 'int64']:
if convert_dtype(x.dtype) in ['float16']:
warnings.warn(
"The data type of 'x' in reshape only support float16 in GPU now.")
if convert_dtype(x.dtype) not in [
'float16', 'float32', 'float64', 'int32', 'int64'
]:
raise TypeError(
"The data type of 'x' in reshape must be float32, float64, int32 or int64, "
"The data type of 'x' in reshape must be float16, float32, float64, int32 or int64, "
"but received %s." % (convert_dtype(x.dtype)))
if not isinstance(shape, (list, tuple, Variable)):
......
......@@ -238,17 +238,27 @@ class TestReshapeOpError(OpTest):
self.assertRaises(TypeError, test_x_type)
# The x dtype of reshape_op must be float32, float64, int32 or int64.
# The x dtype of reshape_op must be float16, float32, float64, int32 or int64.
def test_x_dtype():
x2 = fluid.layers.data(
name="x2",
shape=[2, 25],
append_batch_size=False,
dtype="float16")
dtype="bool")
fluid.layers.reshape(x2, shape=[2, 5, 5])
self.assertRaises(TypeError, test_x_dtype)
def test_x_dtype_float16():
x_float16 = fluid.layers.data(
name="x_float16",
shape=[2, 25],
append_batch_size=False,
dtype="float16")
fluid.layers.reshape(x_float16, shape=[2, 5, 5])
test_x_dtype_float16()
x3 = fluid.layers.data(
name="x3",
shape=[2, 25],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册