diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 0d7ceb22afb4dfd005ebbbc039bd02148d6bd362..01e07d2e715232a5ef41f9c8acf33834cb93bf64 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -902,6 +902,13 @@ def zeros(shape, dtype, force_cpu=False): import paddle.fluid as fluid data = fluid.layers.zeros(shape=[3, 2], dtype='float32') # [[0., 0.], [0., 0.], [0., 0.]] """ + if convert_dtype(dtype) not in [ + 'bool', 'float16', 'float32', 'float64', 'int32', 'int64' + ]: + raise TypeError( + "The create data type in zeros must be one of bool, float16, float32," + " float64, int32 or int64, but received %s." % convert_dtype( + (dtype))) return fill_constant(value=0.0, **locals()) diff --git a/python/paddle/fluid/tests/unittests/test_zeros_op.py b/python/paddle/fluid/tests/unittests/test_zeros_op.py new file mode 100644 index 0000000000000000000000000000000000000000..62ed6b56c6d6ca2c91da04856572b49832421595 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_zeros_op.py @@ -0,0 +1,38 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest + +import paddle.fluid.core as core +from paddle.fluid.op import Operator +import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard + + +class TestZerosOpError(OpTest): + def test_errors(self): + with program_guard(Program(), Program()): + # The input dtype of zeros_op must be bool, float16, float32, float64, int32, int64. + x1 = fluid.layers.data(name='x1', shape=[4], dtype="int8") + self.assertRaises(TypeError, fluid.layers.zeros, x1) + x2 = fluid.layers.data(name='x2', shape=[4], dtype="uint8") + self.assertRaises(TypeError, fluid.layers.zeros, x2) + + +if __name__ == "__main__": + unittest.main()