未验证 提交 40a0a46b 编写于 作者: O OccupyMars2025 提交者: GitHub

[ Hackathon 3rd No.2 ] add paddle.iinfo (#45321)

上级 a642365e
......@@ -302,7 +302,7 @@ class GradNodeBase {
// Gradient Hooks
// Customer may register a list of hooks which will be called in order during
// backward
// Each entry consists one pair of
// Each entry consists of one pair of
// <hook_id, <out_rank, std::shared_ptr<TensorHook>>>
std::map<int64_t,
std::tuple<
......
......@@ -21,6 +21,7 @@ limitations under the License. */
#include <map>
#include <memory>
#include <mutex> // NOLINT // for call_once
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
......@@ -346,6 +347,52 @@ bool IsCompiledWithDIST() {
#endif
}
struct iinfo {
int64_t min, max;
int bits;
std::string dtype;
explicit iinfo(const framework::proto::VarType::Type &type) {
switch (type) {
case framework::proto::VarType::INT16:
min = std::numeric_limits<int16_t>::min();
max = std::numeric_limits<int16_t>::max();
bits = 16;
dtype = "int16";
break;
case framework::proto::VarType::INT32:
min = std::numeric_limits<int32_t>::min();
max = std::numeric_limits<int32_t>::max();
bits = 32;
dtype = "int32";
break;
case framework::proto::VarType::INT64:
min = std::numeric_limits<int64_t>::min();
max = std::numeric_limits<int64_t>::max();
bits = 64;
dtype = "int64";
break;
case framework::proto::VarType::INT8:
min = std::numeric_limits<int8_t>::min();
max = std::numeric_limits<int8_t>::max();
bits = 8;
dtype = "int8";
break;
case framework::proto::VarType::UINT8:
min = std::numeric_limits<uint8_t>::min();
max = std::numeric_limits<uint8_t>::max();
bits = 8;
dtype = "uint8";
break;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"the argument of paddle.iinfo can only be paddle.int8, "
"paddle.int16, paddle.int32, paddle.int64, or paddle.uint8"));
break;
}
}
};
static PyObject *GetPythonAttribute(PyObject *obj, const char *attr_name) {
// NOTE(zjl): PyObject_GetAttrString would return nullptr when attr_name
// is not inside obj, but it would also set the error flag of Python.
......@@ -555,6 +602,21 @@ PYBIND11_MODULE(core_noavx, m) {
BindException(&m);
py::class_<iinfo>(m, "iinfo")
.def(py::init<const framework::proto::VarType::Type &>())
.def_readonly("min", &iinfo::min)
.def_readonly("max", &iinfo::max)
.def_readonly("bits", &iinfo::bits)
.def_readonly("dtype", &iinfo::dtype)
.def("__repr__", [](const iinfo &a) {
std::ostringstream oss;
oss << "paddle.iinfo(min=" << a.min;
oss << ", max=" << a.max;
oss << ", bits=" << a.bits;
oss << ", dtype=" << a.dtype << ")";
return oss.str();
});
m.def("set_num_threads", &platform::SetNumThreads);
m.def("disable_signal_handler", &DisableSignalHandler);
......
......@@ -38,6 +38,7 @@ from .framework import in_dynamic_mode # noqa: F401
from .fluid.dataset import * # noqa: F401
from .fluid.lazy_init import LazyGuard # noqa: F401
from .framework.dtype import iinfo # noqa: F401
from .framework.dtype import dtype as dtype # noqa: F401
from .framework.dtype import uint8 # noqa: F401
from .framework.dtype import int8 # noqa: F401
......@@ -386,6 +387,7 @@ if is_compiled_with_cinn():
disable_static()
__all__ = [ # noqa
'iinfo',
'dtype',
'uint8',
'int8',
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import unittest
import numpy as np
class TestIInfoAndFInfoAPI(unittest.TestCase):
def test_invalid_input(self):
for dtype in [
paddle.float16, paddle.float32, paddle.float64, paddle.bfloat16,
paddle.complex64, paddle.complex128, paddle.bool
]:
with self.assertRaises(ValueError):
_ = paddle.iinfo(dtype)
def test_iinfo(self):
for paddle_dtype, np_dtype in [(paddle.int64, np.int64),
(paddle.int32, np.int32),
(paddle.int16, np.int16),
(paddle.int8, np.int8),
(paddle.uint8, np.uint8)]:
xinfo = paddle.iinfo(paddle_dtype)
xninfo = np.iinfo(np_dtype)
self.assertEqual(xinfo.bits, xninfo.bits)
self.assertEqual(xinfo.max, xninfo.max)
self.assertEqual(xinfo.min, xninfo.min)
self.assertEqual(xinfo.dtype, xninfo.dtype)
if __name__ == '__main__':
unittest.main()
......@@ -13,6 +13,7 @@
# limitations under the License.
from ..fluid.core import VarDesc
from ..fluid.core import iinfo as core_iinfo
dtype = VarDesc.VarType
dtype.__qualname__ = "dtype"
......@@ -34,4 +35,37 @@ complex128 = VarDesc.VarType.COMPLEX128
bool = VarDesc.VarType.BOOL
__all__ = []
def iinfo(dtype):
"""
paddle.iinfo is a function that returns an object that represents the numerical properties of
an integer paddle.dtype.
This is similar to `numpy.iinfo <https://numpy.org/doc/stable/reference/generated/numpy.iinfo.html#numpy-iinfo>`_.
Args:
dtype(paddle.dtype): One of paddle.uint8, paddle.int8, paddle.int16, paddle.int32, and paddle.int64.
Returns:
An iinfo object, which has the following 4 attributes:
- min: int, The smallest representable integer number.
- max: int, The largest representable integer number.
- bits: int, The number of bits occupied by the type.
- dtype: str, The string name of the argument dtype.
Examples:
.. code-block:: python
import paddle
iinfo_uint8 = paddle.iinfo(paddle.uint8)
print(iinfo_uint8)
# paddle.iinfo(min=0, max=255, bits=8, dtype=uint8)
print(iinfo_uint8.min) # 0
print(iinfo_uint8.max) # 255
print(iinfo_uint8.bits) # 8
print(iinfo_uint8.dtype) # uint8
"""
return core_iinfo(dtype)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册