From 9a76f3f9166bf9aa465398ff24b0bb9a6ccc21a3 Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Thu, 3 Oct 2019 16:58:04 +0800 Subject: [PATCH] Fill constant error message fix (#20075) * fix the constant error message test=develop * fix typo test=develop * fix typo test=develop * fix code style test=develop * fix comment and bugs test=develop * fix the bug test=develop * fix and add unittest test=develop * fix the typo test=develop * add support for the fill_constant op test=develop * add test for ci coverage test=develop --- paddle/fluid/operators/fill_constant_op.cc | 1 + python/paddle/fluid/data_feeder.py | 16 +++++++- python/paddle/fluid/layers/tensor.py | 14 +++++++ .../tests/unittests/test_fill_constant_op.py | 38 +++++++++++++++++++ 4 files changed, 67 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc index cf2f4776cf2..b11a89c7385 100644 --- a/paddle/fluid/operators/fill_constant_op.cc +++ b/paddle/fluid/operators/fill_constant_op.cc @@ -87,4 +87,5 @@ REGISTER_OP_CPU_KERNEL(fill_constant, ops::FillConstantKernel, ops::FillConstantKernel, ops::FillConstantKernel, ops::FillConstantKernel, + ops::FillConstantKernel, ops::FillConstantKernel); diff --git a/python/paddle/fluid/data_feeder.py b/python/paddle/fluid/data_feeder.py index 3f9c69f120e..b377b32636c 100644 --- a/python/paddle/fluid/data_feeder.py +++ b/python/paddle/fluid/data_feeder.py @@ -27,7 +27,19 @@ __all__ = ['DataFeeder'] def convert_dtype(dtype): - if dtype == core.VarDesc.VarType.FP32: + if isinstance(dtype, str): + if dtype in [ + 'float32', 'int64', 'float64', 'float16', 'int32', 'uint8', + 'bool' + ]: + return dtype + else: + raise ValueError( + "dtype must be any of [bool, int32, float32, int64, " + "float64, uint8]") + elif dtype == core.VarDesc.VarType.BOOL: + return 'bool' + elif dtype == core.VarDesc.VarType.FP32: return 'float32' elif dtype == core.VarDesc.VarType.INT64: return 'int64' @@ -40,7 +52,7 @@ def convert_dtype(dtype): elif dtype == core.VarDesc.VarType.UINT8: return 'uint8' else: - raise ValueError("dtype must be any of [int32, float32, int64, " + raise ValueError("dtype must be any of [bool,int32, float32, int64, " "float64, uint8]") diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index b0838227f0d..3a5a4321657 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -21,6 +21,7 @@ from ..framework import Variable from ..initializer import Constant, force_init_on_cpu from ..core import VarDesc from .layer_function_generator import templatedoc +from ..data_feeder import convert_dtype import numpy __all__ = [ @@ -397,8 +398,21 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None): """ helper = LayerHelper("fill_constant", **locals()) + if convert_dtype(dtype) not in [ + 'bool', 'float16', 'float32', 'float64', 'int32', 'int64' + ]: + raise TypeError( + "The create data type in fill_constant must be one of 'bool', float16, float32," + "float64, int32 or int64, but received %s." % convert_dtype( + (dtype))) if out is None: out = helper.create_variable_for_type_inference(dtype=dtype) + else: + if not (convert_dtype(dtype) == convert_dtype(out.dtype)): + raise TypeError( + "The create data type in op must be same with out type" + "but received %s and out dtype %s." % (convert_dtype( + (dtype), convert_dtype(out.dtype)))) helper.append_op( type='fill_constant', inputs={}, diff --git a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py index e22bd09ed06..9401007643e 100644 --- a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py +++ b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py @@ -20,6 +20,8 @@ from op_test import OpTest import paddle.fluid.core as core from paddle.fluid.op import Operator +import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard class TestFillConstantOp1(OpTest): @@ -104,5 +106,41 @@ class TestFillConstantOpWithSelectedRows(OpTest): self.check_with_place(place) +class TestFillConstantOpError(OpTest): + def test_errors(self): + with program_guard(Program(), Program()): + #for ci coverage + x1 = fluid.layers.data(name='x1', shape=[1], dtype="int16") + self.assertRaises( + ValueError, + fluid.layers.fill_constant, + shape=[1], + value=5, + dtype='uint4') + self.assertRaises( + ValueError, + fluid.layers.fill_constant, + shape=[1], + value=5, + dtype='int16', + out=x1) + # The input dtype of fill_constant must be one of bool, float16, + #float32, float64, int32 or int64 + x2 = fluid.layers.data(name='x2', shape=[1], dtype="int32") + self.assertRaises( + TypeError, + fluid.layers.fill_constant, + shape=[1], + value=5, + dtype='uint8') + self.assertRaises( + TypeError, + fluid.layers.fill_constant, + shape=[1], + value=5, + dtype='float64', + out=x2) + + if __name__ == "__main__": unittest.main() -- GitLab