From 4cbed3a3f6cc834191cfc2fba2347a12da33dc8d Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Wed, 8 Apr 2020 23:02:34 +0800 Subject: [PATCH] API (layers.data/fluid.data) error message enhancement (#23427) * Api (fluid.data/layers.data) error message enhancement. test=develop --- python/paddle/fluid/data.py | 9 ++- python/paddle/fluid/layers/io.py | 7 ++- .../paddle/fluid/tests/unittests/test_data.py | 57 +++++++++++++++++++ 3 files changed, 70 insertions(+), 3 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_data.py diff --git a/python/paddle/fluid/data.py b/python/paddle/fluid/data.py index 14333cae1ec..e0888c2c078 100644 --- a/python/paddle/fluid/data.py +++ b/python/paddle/fluid/data.py @@ -15,8 +15,9 @@ import numpy as np import six -from . import core -from .layer_helper import LayerHelper +from paddle.fluid import core +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.data_feeder import check_dtype, check_type __all__ = ['data'] @@ -99,6 +100,10 @@ def data(name, shape, dtype='float32', lod_level=0): """ helper = LayerHelper('data', **locals()) + + check_type(name, 'name', (six.binary_type, six.text_type), 'data') + check_type(shape, 'shape', (list, tuple), 'data') + shape = list(shape) for i in six.moves.range(len(shape)): if shape[i] is None: diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index dd86417e6cb..1b9d079448b 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -31,6 +31,7 @@ from ..layer_helper import LayerHelper from ..unique_name import generate as unique_name from ..transpiler.distribute_transpiler import DistributedMode import logging +from ..data_feeder import check_dtype, check_type __all__ = [ 'data', 'read_file', 'double_buffer', 'py_reader', @@ -73,7 +74,7 @@ def data(name, Args: name(str): The name/alias of the variable, see :ref:`api_guide_Name` for more details. - shape(list): Tuple declaring the shape. If :code:`append_batch_size` is + shape(list|tuple): Tuple declaring the shape. If :code:`append_batch_size` is True and there is no -1 inside :code:`shape`, it should be considered as the shape of the each sample. Otherwise, it should be considered as the shape of the batched data. @@ -107,6 +108,10 @@ def data(name, data = fluid.layers.data(name='x', shape=[784], dtype='float32') """ helper = LayerHelper('data', **locals()) + + check_type(name, 'name', (six.binary_type, six.text_type), 'data') + check_type(shape, 'shape', (list, tuple), 'data') + shape = list(shape) for i in six.moves.range(len(shape)): if shape[i] is None: diff --git a/python/paddle/fluid/tests/unittests/test_data.py b/python/paddle/fluid/tests/unittests/test_data.py new file mode 100644 index 00000000000..22dc72048e4 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_data.py @@ -0,0 +1,57 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest + +import paddle.fluid as fluid +import paddle.fluid.layers as layers +from paddle.fluid import Program, program_guard + + +class TestApiDataError(unittest.TestCase): + def test_fluid_data(self): + with program_guard(Program(), Program()): + + # 1. The type of 'name' in fluid.data must be str. + def test_name_type(): + fluid.data(name=1, shape=[2, 25], dtype="bool") + + self.assertRaises(TypeError, test_name_type) + + # 2. The type of 'shape' in fluid.data must be list or tuple. + def test_shape_type(): + fluid.data(name='data1', shape=2, dtype="bool") + + self.assertRaises(TypeError, test_shape_type) + + def test_layers_data(self): + with program_guard(Program(), Program()): + + # 1. The type of 'name' in layers.data must be str. + def test_name_type(): + layers.data(name=1, shape=[2, 25], dtype="bool") + + self.assertRaises(TypeError, test_name_type) + + # 2. The type of 'shape' in layers.data must be list or tuple. + def test_shape_type(): + layers.data(name='data1', shape=2, dtype="bool") + + self.assertRaises(TypeError, test_shape_type) + + +if __name__ == "__main__": + unittest.main() -- GitLab