提交 01209b51 编写于 作者: Z zhupengyang 提交者: hong19860320

add input type and dtype check for cast_op (#20070)

* add input type and dtype check for cast_op
test=develop

* fix annotation
test=develop

* support more data type

test=develop

* fix bug for fill_constant's error type

test=develop

* improve converage

test=develop

* improve converage

test=develop
上级 67fcb0c9
...@@ -29,31 +29,32 @@ __all__ = ['DataFeeder'] ...@@ -29,31 +29,32 @@ __all__ = ['DataFeeder']
def convert_dtype(dtype): def convert_dtype(dtype):
if isinstance(dtype, str): if isinstance(dtype, str):
if dtype in [ if dtype in [
'float32', 'int64', 'float64', 'float16', 'int32', 'uint8', 'bool', 'float16', 'float32', 'float64', 'int8', 'int16',
'bool' 'int32', 'int64', 'uint8'
]: ]:
return dtype return dtype
else:
raise ValueError(
"dtype must be any of [bool, int32, float32, int64, "
"float64, uint8]")
elif dtype == core.VarDesc.VarType.BOOL:
return 'bool'
elif dtype == core.VarDesc.VarType.FP32:
return 'float32'
elif dtype == core.VarDesc.VarType.INT64:
return 'int64'
elif dtype == core.VarDesc.VarType.FP64:
return 'float64'
elif dtype == core.VarDesc.VarType.FP16:
return 'float16'
elif dtype == core.VarDesc.VarType.INT32:
return 'int32'
elif dtype == core.VarDesc.VarType.UINT8:
return 'uint8'
else: else:
raise ValueError("dtype must be any of [bool,int32, float32, int64, " if dtype == core.VarDesc.VarType.BOOL:
"float64, uint8]") return 'bool'
elif dtype == core.VarDesc.VarType.FP16:
return 'float16'
elif dtype == core.VarDesc.VarType.FP32:
return 'float32'
elif dtype == core.VarDesc.VarType.FP64:
return 'float64'
elif dtype == core.VarDesc.VarType.INT8:
return 'int8'
elif dtype == core.VarDesc.VarType.INT16:
return 'int16'
elif dtype == core.VarDesc.VarType.INT32:
return 'int32'
elif dtype == core.VarDesc.VarType.INT64:
return 'int64'
elif dtype == core.VarDesc.VarType.UINT8:
return 'uint8'
raise ValueError(
"dtype must be any of [bool, float16, float32, float64, int8, int16, "
"int32, int64, uint8]")
class DataToLoDTensorConverter(object): class DataToLoDTensorConverter(object):
......
...@@ -192,6 +192,17 @@ def cast(x, dtype): ...@@ -192,6 +192,17 @@ def cast(x, dtype):
# [ 0 4]] int32 # [ 0 4]] int32
""" """
helper = LayerHelper('cast', **locals()) helper = LayerHelper('cast', **locals())
if not isinstance(x, Variable):
raise TypeError(
"The type of 'x' in cast must be Variable, but received %s" %
(type(x)))
if convert_dtype(x.dtype) not in [
'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'uint8'
]:
raise TypeError(
"The data type of 'x' in cast must be one of [bool, float16, float32, float64, int32, int64, uint8], but received %s."
% (convert_dtype(x.dtype)))
out = helper.create_variable_for_type_inference(dtype=dtype) out = helper.create_variable_for_type_inference(dtype=dtype)
helper.append_op( helper.append_op(
type='cast', type='cast',
......
...@@ -18,6 +18,8 @@ import op_test ...@@ -18,6 +18,8 @@ import op_test
import unittest import unittest
import numpy as np import numpy as np
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
class TestCastOp1(op_test.OpTest): class TestCastOp1(op_test.OpTest):
...@@ -69,5 +71,19 @@ class TestCastOp3(op_test.OpTest): ...@@ -69,5 +71,19 @@ class TestCastOp3(op_test.OpTest):
self.check_output(atol=1e-3) self.check_output(atol=1e-3)
class TestCastOpError(op_test.OpTest):
def test_errors(self):
with program_guard(Program(), Program()):
# The input type of cast_op must be Variable.
x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
self.assertRaises(TypeError, fluid.layers.cast, x1, 'int32')
# The input dtype of cast_op must be bool, float16, float32, float64, int32, int64, uint8.
x2 = fluid.layers.data(name='x2', shape=[4], dtype='int8')
self.assertRaises(TypeError, fluid.layers.cast, x2, 'int32')
x3 = fluid.layers.data(name='x3', shape=[4], dtype='int16')
self.assertRaises(TypeError, fluid.layers.cast, x3, 'int32')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -230,7 +230,7 @@ class TestFillConstantOpError(OpTest): ...@@ -230,7 +230,7 @@ class TestFillConstantOpError(OpTest):
value=5, value=5,
dtype='uint4') dtype='uint4')
self.assertRaises( self.assertRaises(
ValueError, TypeError,
fluid.layers.fill_constant, fluid.layers.fill_constant,
shape=[1], shape=[1],
value=5, value=5,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册