提交 01209b51 编写于 作者: Z zhupengyang 提交者: hong19860320

add input type and dtype check for cast_op (#20070)

* add input type and dtype check for cast_op
test=develop

* fix annotation
test=develop

* support more data type

test=develop

* fix bug for fill_constant's error type

test=develop

* improve converage

test=develop

* improve converage

test=develop
上级 67fcb0c9
......@@ -29,31 +29,32 @@ __all__ = ['DataFeeder']
def convert_dtype(dtype):
if isinstance(dtype, str):
if dtype in [
'float32', 'int64', 'float64', 'float16', 'int32', 'uint8',
'bool'
'bool', 'float16', 'float32', 'float64', 'int8', 'int16',
'int32', 'int64', 'uint8'
]:
return dtype
else:
raise ValueError(
"dtype must be any of [bool, int32, float32, int64, "
"float64, uint8]")
elif dtype == core.VarDesc.VarType.BOOL:
if dtype == core.VarDesc.VarType.BOOL:
return 'bool'
elif dtype == core.VarDesc.VarType.FP16:
return 'float16'
elif dtype == core.VarDesc.VarType.FP32:
return 'float32'
elif dtype == core.VarDesc.VarType.INT64:
return 'int64'
elif dtype == core.VarDesc.VarType.FP64:
return 'float64'
elif dtype == core.VarDesc.VarType.FP16:
return 'float16'
elif dtype == core.VarDesc.VarType.INT8:
return 'int8'
elif dtype == core.VarDesc.VarType.INT16:
return 'int16'
elif dtype == core.VarDesc.VarType.INT32:
return 'int32'
elif dtype == core.VarDesc.VarType.INT64:
return 'int64'
elif dtype == core.VarDesc.VarType.UINT8:
return 'uint8'
else:
raise ValueError("dtype must be any of [bool,int32, float32, int64, "
"float64, uint8]")
raise ValueError(
"dtype must be any of [bool, float16, float32, float64, int8, int16, "
"int32, int64, uint8]")
class DataToLoDTensorConverter(object):
......
......@@ -192,6 +192,17 @@ def cast(x, dtype):
# [ 0 4]] int32
"""
helper = LayerHelper('cast', **locals())
if not isinstance(x, Variable):
raise TypeError(
"The type of 'x' in cast must be Variable, but received %s" %
(type(x)))
if convert_dtype(x.dtype) not in [
'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'uint8'
]:
raise TypeError(
"The data type of 'x' in cast must be one of [bool, float16, float32, float64, int32, int64, uint8], but received %s."
% (convert_dtype(x.dtype)))
out = helper.create_variable_for_type_inference(dtype=dtype)
helper.append_op(
type='cast',
......
......@@ -18,6 +18,8 @@ import op_test
import unittest
import numpy as np
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
class TestCastOp1(op_test.OpTest):
......@@ -69,5 +71,19 @@ class TestCastOp3(op_test.OpTest):
self.check_output(atol=1e-3)
class TestCastOpError(op_test.OpTest):
def test_errors(self):
with program_guard(Program(), Program()):
# The input type of cast_op must be Variable.
x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
self.assertRaises(TypeError, fluid.layers.cast, x1, 'int32')
# The input dtype of cast_op must be bool, float16, float32, float64, int32, int64, uint8.
x2 = fluid.layers.data(name='x2', shape=[4], dtype='int8')
self.assertRaises(TypeError, fluid.layers.cast, x2, 'int32')
x3 = fluid.layers.data(name='x3', shape=[4], dtype='int16')
self.assertRaises(TypeError, fluid.layers.cast, x3, 'int32')
if __name__ == '__main__':
unittest.main()
......@@ -230,7 +230,7 @@ class TestFillConstantOpError(OpTest):
value=5,
dtype='uint4')
self.assertRaises(
ValueError,
TypeError,
fluid.layers.fill_constant,
shape=[1],
value=5,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册