diff --git a/paddle/fluid/eager/grad_node_info.h b/paddle/fluid/eager/grad_node_info.h index a65a044895a4ff83b72560dc30efe93d5cec062d..ebe1f6cccf93db4ac4489f08be37512562d1f998 100644 --- a/paddle/fluid/eager/grad_node_info.h +++ b/paddle/fluid/eager/grad_node_info.h @@ -302,7 +302,7 @@ class GradNodeBase { // Gradient Hooks // Customer may register a list of hooks which will be called in order during // backward - // Each entry consists one pair of + // Each entry consists of one pair of // >> std::map #include #include // NOLINT // for call_once +#include #include #include #include @@ -346,6 +347,52 @@ bool IsCompiledWithDIST() { #endif } +struct iinfo { + int64_t min, max; + int bits; + std::string dtype; + + explicit iinfo(const framework::proto::VarType::Type &type) { + switch (type) { + case framework::proto::VarType::INT16: + min = std::numeric_limits::min(); + max = std::numeric_limits::max(); + bits = 16; + dtype = "int16"; + break; + case framework::proto::VarType::INT32: + min = std::numeric_limits::min(); + max = std::numeric_limits::max(); + bits = 32; + dtype = "int32"; + break; + case framework::proto::VarType::INT64: + min = std::numeric_limits::min(); + max = std::numeric_limits::max(); + bits = 64; + dtype = "int64"; + break; + case framework::proto::VarType::INT8: + min = std::numeric_limits::min(); + max = std::numeric_limits::max(); + bits = 8; + dtype = "int8"; + break; + case framework::proto::VarType::UINT8: + min = std::numeric_limits::min(); + max = std::numeric_limits::max(); + bits = 8; + dtype = "uint8"; + break; + default: + PADDLE_THROW(platform::errors::InvalidArgument( + "the argument of paddle.iinfo can only be paddle.int8, " + "paddle.int16, paddle.int32, paddle.int64, or paddle.uint8")); + break; + } + } +}; + static PyObject *GetPythonAttribute(PyObject *obj, const char *attr_name) { // NOTE(zjl): PyObject_GetAttrString would return nullptr when attr_name // is not inside obj, but it would also set the error flag of Python. @@ -555,6 +602,21 @@ PYBIND11_MODULE(core_noavx, m) { BindException(&m); + py::class_(m, "iinfo") + .def(py::init()) + .def_readonly("min", &iinfo::min) + .def_readonly("max", &iinfo::max) + .def_readonly("bits", &iinfo::bits) + .def_readonly("dtype", &iinfo::dtype) + .def("__repr__", [](const iinfo &a) { + std::ostringstream oss; + oss << "paddle.iinfo(min=" << a.min; + oss << ", max=" << a.max; + oss << ", bits=" << a.bits; + oss << ", dtype=" << a.dtype << ")"; + return oss.str(); + }); + m.def("set_num_threads", &platform::SetNumThreads); m.def("disable_signal_handler", &DisableSignalHandler); diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index b39f4161eee97801c9a4335eb91bc78d04195572..dc55260f2ce757e369d0d0bbf5ba21ca74fe3c16 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -38,6 +38,7 @@ from .framework import in_dynamic_mode # noqa: F401 from .fluid.dataset import * # noqa: F401 from .fluid.lazy_init import LazyGuard # noqa: F401 +from .framework.dtype import iinfo # noqa: F401 from .framework.dtype import dtype as dtype # noqa: F401 from .framework.dtype import uint8 # noqa: F401 from .framework.dtype import int8 # noqa: F401 @@ -386,6 +387,7 @@ if is_compiled_with_cinn(): disable_static() __all__ = [ # noqa + 'iinfo', 'dtype', 'uint8', 'int8', diff --git a/python/paddle/fluid/tests/unittests/test_iinfo_and_finfo.py b/python/paddle/fluid/tests/unittests/test_iinfo_and_finfo.py new file mode 100644 index 0000000000000000000000000000000000000000..9debbccdb3d7edd7fd7b8d4a434cd991888c08fd --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_iinfo_and_finfo.py @@ -0,0 +1,45 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import unittest +import numpy as np + + +class TestIInfoAndFInfoAPI(unittest.TestCase): + + def test_invalid_input(self): + for dtype in [ + paddle.float16, paddle.float32, paddle.float64, paddle.bfloat16, + paddle.complex64, paddle.complex128, paddle.bool + ]: + with self.assertRaises(ValueError): + _ = paddle.iinfo(dtype) + + def test_iinfo(self): + for paddle_dtype, np_dtype in [(paddle.int64, np.int64), + (paddle.int32, np.int32), + (paddle.int16, np.int16), + (paddle.int8, np.int8), + (paddle.uint8, np.uint8)]: + xinfo = paddle.iinfo(paddle_dtype) + xninfo = np.iinfo(np_dtype) + self.assertEqual(xinfo.bits, xninfo.bits) + self.assertEqual(xinfo.max, xninfo.max) + self.assertEqual(xinfo.min, xninfo.min) + self.assertEqual(xinfo.dtype, xninfo.dtype) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/framework/dtype.py b/python/paddle/framework/dtype.py index 56a95f48b5f9b8aeeb9c70f739c37650c0d04843..6abc8e6e1aa9a7af641bb3a125e27b7a4943d688 100644 --- a/python/paddle/framework/dtype.py +++ b/python/paddle/framework/dtype.py @@ -13,6 +13,7 @@ # limitations under the License. from ..fluid.core import VarDesc +from ..fluid.core import iinfo as core_iinfo dtype = VarDesc.VarType dtype.__qualname__ = "dtype" @@ -34,4 +35,37 @@ complex128 = VarDesc.VarType.COMPLEX128 bool = VarDesc.VarType.BOOL -__all__ = [] + +def iinfo(dtype): + """ + + paddle.iinfo is a function that returns an object that represents the numerical properties of + an integer paddle.dtype. + This is similar to `numpy.iinfo `_. + + Args: + dtype(paddle.dtype): One of paddle.uint8, paddle.int8, paddle.int16, paddle.int32, and paddle.int64. + + Returns: + An iinfo object, which has the following 4 attributes: + + - min: int, The smallest representable integer number. + - max: int, The largest representable integer number. + - bits: int, The number of bits occupied by the type. + - dtype: str, The string name of the argument dtype. + + Examples: + .. code-block:: python + + import paddle + + iinfo_uint8 = paddle.iinfo(paddle.uint8) + print(iinfo_uint8) + # paddle.iinfo(min=0, max=255, bits=8, dtype=uint8) + print(iinfo_uint8.min) # 0 + print(iinfo_uint8.max) # 255 + print(iinfo_uint8.bits) # 8 + print(iinfo_uint8.dtype) # uint8 + + """ + return core_iinfo(dtype)