From f567214b6ad2c405fb1ab2a8db21525d20c393ff Mon Sep 17 00:00:00 2001 From: hong19860320 <9973393+hong19860320@users.noreply.github.com> Date: Sun, 13 Oct 2019 16:10:44 +0800 Subject: [PATCH] Fix the error message of zeros op (#20476) * fix the error message of zeros op test=develop * Fix unittest of zeros op test=develop --- python/paddle/fluid/layers/tensor.py | 7 ++++ .../fluid/tests/unittests/test_zeros_op.py | 38 +++++++++++++++++++ 2 files changed, 45 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/test_zeros_op.py diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 504e7a3e41d..b42dddce55e 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -919,6 +919,13 @@ def zeros(shape, dtype, force_cpu=False): import paddle.fluid as fluid data = fluid.layers.zeros(shape=[3, 2], dtype='float32') # [[0., 0.], [0., 0.], [0., 0.]] """ + if convert_dtype(dtype) not in [ + 'bool', 'float16', 'float32', 'float64', 'int32', 'int64' + ]: + raise TypeError( + "The create data type in zeros must be one of bool, float16, float32," + " float64, int32 or int64, but received %s." % convert_dtype( + (dtype))) return fill_constant(value=0.0, **locals()) diff --git a/python/paddle/fluid/tests/unittests/test_zeros_op.py b/python/paddle/fluid/tests/unittests/test_zeros_op.py new file mode 100644 index 00000000000..62ed6b56c6d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_zeros_op.py @@ -0,0 +1,38 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest + +import paddle.fluid.core as core +from paddle.fluid.op import Operator +import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard + + +class TestZerosOpError(OpTest): + def test_errors(self): + with program_guard(Program(), Program()): + # The input dtype of zeros_op must be bool, float16, float32, float64, int32, int64. + x1 = fluid.layers.data(name='x1', shape=[4], dtype="int8") + self.assertRaises(TypeError, fluid.layers.zeros, x1) + x2 = fluid.layers.data(name='x2', shape=[4], dtype="uint8") + self.assertRaises(TypeError, fluid.layers.zeros, x2) + + +if __name__ == "__main__": + unittest.main() -- GitLab