未验证 提交 9a76f3f9 编写于 作者: W wangchaochaohu 提交者: GitHub

Fill constant error message fix (#20075)

* fix the constant error message test=develop

* fix typo test=develop

* fix typo test=develop

* fix code style test=develop

* fix comment and bugs test=develop

* fix the bug test=develop

* fix and add unittest test=develop

* fix the typo test=develop

* add support for the fill_constant op test=develop

* add test for ci coverage test=develop
上级 e8673668
......@@ -87,4 +87,5 @@ REGISTER_OP_CPU_KERNEL(fill_constant, ops::FillConstantKernel<float>,
ops::FillConstantKernel<double>,
ops::FillConstantKernel<int64_t>,
ops::FillConstantKernel<int>,
ops::FillConstantKernel<bool>,
ops::FillConstantKernel<paddle::platform::float16>);
......@@ -27,7 +27,19 @@ __all__ = ['DataFeeder']
def convert_dtype(dtype):
if dtype == core.VarDesc.VarType.FP32:
if isinstance(dtype, str):
if dtype in [
'float32', 'int64', 'float64', 'float16', 'int32', 'uint8',
'bool'
]:
return dtype
else:
raise ValueError(
"dtype must be any of [bool, int32, float32, int64, "
"float64, uint8]")
elif dtype == core.VarDesc.VarType.BOOL:
return 'bool'
elif dtype == core.VarDesc.VarType.FP32:
return 'float32'
elif dtype == core.VarDesc.VarType.INT64:
return 'int64'
......@@ -40,7 +52,7 @@ def convert_dtype(dtype):
elif dtype == core.VarDesc.VarType.UINT8:
return 'uint8'
else:
raise ValueError("dtype must be any of [int32, float32, int64, "
raise ValueError("dtype must be any of [bool,int32, float32, int64, "
"float64, uint8]")
......
......@@ -21,6 +21,7 @@ from ..framework import Variable
from ..initializer import Constant, force_init_on_cpu
from ..core import VarDesc
from .layer_function_generator import templatedoc
from ..data_feeder import convert_dtype
import numpy
__all__ = [
......@@ -397,8 +398,21 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
"""
helper = LayerHelper("fill_constant", **locals())
if convert_dtype(dtype) not in [
'bool', 'float16', 'float32', 'float64', 'int32', 'int64'
]:
raise TypeError(
"The create data type in fill_constant must be one of 'bool', float16, float32,"
"float64, int32 or int64, but received %s." % convert_dtype(
(dtype)))
if out is None:
out = helper.create_variable_for_type_inference(dtype=dtype)
else:
if not (convert_dtype(dtype) == convert_dtype(out.dtype)):
raise TypeError(
"The create data type in op must be same with out type"
"but received %s and out dtype %s." % (convert_dtype(
(dtype), convert_dtype(out.dtype))))
helper.append_op(
type='fill_constant',
inputs={},
......
......@@ -20,6 +20,8 @@ from op_test import OpTest
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
class TestFillConstantOp1(OpTest):
......@@ -104,5 +106,41 @@ class TestFillConstantOpWithSelectedRows(OpTest):
self.check_with_place(place)
class TestFillConstantOpError(OpTest):
def test_errors(self):
with program_guard(Program(), Program()):
#for ci coverage
x1 = fluid.layers.data(name='x1', shape=[1], dtype="int16")
self.assertRaises(
ValueError,
fluid.layers.fill_constant,
shape=[1],
value=5,
dtype='uint4')
self.assertRaises(
ValueError,
fluid.layers.fill_constant,
shape=[1],
value=5,
dtype='int16',
out=x1)
# The input dtype of fill_constant must be one of bool, float16,
#float32, float64, int32 or int64
x2 = fluid.layers.data(name='x2', shape=[1], dtype="int32")
self.assertRaises(
TypeError,
fluid.layers.fill_constant,
shape=[1],
value=5,
dtype='uint8')
self.assertRaises(
TypeError,
fluid.layers.fill_constant,
shape=[1],
value=5,
dtype='float64',
out=x2)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册