未验证 提交 f84f43ca 编写于 作者: M mhy 提交者: GitHub

【PaddlePaddle Hackathon 第四期】No1:为 Paddle 新增 finfo API (#50987)

上级 11a6149b
......@@ -71,6 +71,8 @@ limitations under the License. */
#include "paddle/fluid/imperative/amp_auto_cast.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/memory/allocation/allocator_strategy.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/prim/utils/utils.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/memory/allocation/cuda_ipc_allocator.h"
......@@ -450,6 +452,73 @@ struct iinfo {
}
};
// Numerical properties of a floating-point framework dtype, mirroring
// numpy.finfo. Exposed to Python as `paddle.finfo` via the pybind11 binding
// elsewhere in this file.
struct finfo {
  int64_t bits;            // total number of bits occupied by the type
  double eps;              // smallest x such that 1.0 + x != 1.0
  double min;              // lowest representable value (numeric_limits::lowest)
  double max;              // largest representable value
  double tiny;             // alias of smallest_normal (numpy compatibility)
  double smallest_normal;  // smallest positive normal value (numeric_limits::min)
  double resolution;       // approximate decimal resolution: 10 ** -digits10
  std::string dtype;       // canonical dtype name, e.g. "float32"

  // Fills the limits table for `type`. Complex dtypes report the limits of
  // their underlying real component (complex64 -> float, complex128 ->
  // double), matching numpy.finfo. Throws InvalidArgument for any
  // non-floating-point dtype.
  explicit finfo(const framework::proto::VarType::Type &type) {
    switch (type) {
      case framework::proto::VarType::FP16:
        Fill<paddle::platform::float16>(16, "float16");
        break;
      case framework::proto::VarType::FP32:
      case framework::proto::VarType::COMPLEX64:
        Fill<float>(32, "float32");
        break;
      case framework::proto::VarType::FP64:
      case framework::proto::VarType::COMPLEX128:
        Fill<double>(64, "float64");
        break;
      case framework::proto::VarType::BF16:
        Fill<paddle::platform::bfloat16>(16, "bfloat16");
        break;
      default:
        // NOTE: the literals below were previously missing a separator,
        // producing "paddle.bfloat16paddle.complex64" in the message.
        PADDLE_THROW(platform::errors::InvalidArgument(
            "the argument of paddle.finfo can only be paddle.float32, "
            "paddle.float64, paddle.float16, paddle.bfloat16, "
            "paddle.complex64, or paddle.complex128"));
        break;
    }
  }

 private:
  // Populates every field from std::numeric_limits<T>; factored out of the
  // switch above to avoid four copy-pasted blocks.
  template <typename T>
  void Fill(int64_t type_bits, const char *name) {
    eps = std::numeric_limits<T>::epsilon();
    min = std::numeric_limits<T>::lowest();
    max = std::numeric_limits<T>::max();
    smallest_normal = std::numeric_limits<T>::min();
    tiny = smallest_normal;
    resolution = std::pow(10, -std::numeric_limits<T>::digits10);
    bits = type_bits;
    dtype = name;
  }
};
static PyObject *GetPythonAttribute(PyObject *obj, const char *attr_name) {
// NOTE(zjl): PyObject_GetAttrString would return nullptr when attr_name
// is not inside obj, but it would also set the error flag of Python.
......@@ -671,6 +740,29 @@ PYBIND11_MODULE(libpaddle, m) {
return oss.str();
});
// Expose `finfo` to Python (paddle.fluid.core.finfo). It is constructed from
// a framework dtype enum; all numeric-limit fields are read-only attributes
// that mirror numpy.finfo. The __repr__ lambda renders every field so
// `print(paddle.finfo(dtype))` is informative.
py::class_<finfo>(m, "finfo")
.def(py::init<const framework::proto::VarType::Type &>())
.def_readonly("min", &finfo::min)
.def_readonly("max", &finfo::max)
.def_readonly("bits", &finfo::bits)
.def_readonly("eps", &finfo::eps)
.def_readonly("resolution", &finfo::resolution)
.def_readonly("smallest_normal", &finfo::smallest_normal)
// `tiny` is kept as an alias of smallest_normal for numpy compatibility.
.def_readonly("tiny", &finfo::tiny)
.def_readonly("dtype", &finfo::dtype)
.def("__repr__", [](const finfo &a) {
std::ostringstream oss;
oss << "paddle.finfo(min=" << a.min;
oss << ", max=" << a.max;
oss << ", eps=" << a.eps;
oss << ", resolution=" << a.resolution;
oss << ", smallest_normal=" << a.smallest_normal;
oss << ", tiny=" << a.tiny;
oss << ", bits=" << a.bits;
oss << ", dtype=" << a.dtype << ")";
return oss.str();
});
m.def("__set_bwd_prim_enabled",
&paddle::prim::PrimCommonUtils::SetBwdPrimEnabled);
m.def("_is_bwd_prim_enabled",
......
......@@ -373,7 +373,7 @@ struct numeric_limits<phi::dtype::bfloat16> {
static const bool tinyness_before = false;
HOSTDEVICE static phi::dtype::bfloat16(min)() {
return phi::dtype::raw_uint16_to_bfloat16(0x007f);
return phi::dtype::raw_uint16_to_bfloat16(0x0080);
}
HOSTDEVICE static phi::dtype::bfloat16 lowest() {
return phi::dtype::raw_uint16_to_bfloat16(0xff7f);
......@@ -382,7 +382,7 @@ struct numeric_limits<phi::dtype::bfloat16> {
return phi::dtype::raw_uint16_to_bfloat16(0x7f7f);
}
HOSTDEVICE static phi::dtype::bfloat16 epsilon() {
return phi::dtype::raw_uint16_to_bfloat16(0x3400);
return phi::dtype::raw_uint16_to_bfloat16(0x3C00);
}
HOSTDEVICE static phi::dtype::bfloat16 round_error() {
return phi::dtype::bfloat16(0.5);
......
......@@ -1064,7 +1064,7 @@ struct numeric_limits<phi::dtype::float16> {
return phi::dtype::raw_uint16_to_float16(0x7bff);
}
HOSTDEVICE static phi::dtype::float16 epsilon() {
return phi::dtype::raw_uint16_to_float16(0x0800);
return phi::dtype::raw_uint16_to_float16(0x1400);
}
HOSTDEVICE static phi::dtype::float16 round_error() {
return phi::dtype::float16(0.5);
......
......@@ -41,6 +41,7 @@ from .fluid.dataset import * # noqa: F401, F403
from .fluid.lazy_init import LazyGuard # noqa: F401
from .framework.dtype import iinfo # noqa: F401
from .framework.dtype import finfo # noqa: F401
from .framework.dtype import dtype as dtype # noqa: F401
from .framework.dtype import uint8 # noqa: F401
from .framework.dtype import int8 # noqa: F401
......@@ -400,6 +401,7 @@ if is_compiled_with_cinn():
disable_static()
__all__ = [ # noqa
'iinfo',
'finfo',
'dtype',
'uint8',
'int8',
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import unittest
from distutils.version import StrictVersion
import numpy as np
......@@ -48,6 +49,64 @@ class TestIInfoAndFInfoAPI(unittest.TestCase):
self.assertEqual(xinfo.min, xninfo.min)
self.assertEqual(xinfo.dtype, xninfo.dtype)
def test_finfo(self):
    # Real float dtypes: every attribute must agree with numpy's finfo.
    real_cases = [
        (paddle.float32, np.float32),
        (paddle.float64, np.float64),
    ]
    for paddle_dtype, np_dtype in real_cases:
        info = paddle.finfo(paddle_dtype)
        ninfo = np.finfo(np_dtype)
        self.assertEqual(info.dtype, ninfo.dtype)
        self.assertEqual(info.bits, ninfo.bits)
        self.assertAlmostEqual(info.max, ninfo.max)
        self.assertAlmostEqual(info.min, ninfo.min)
        self.assertAlmostEqual(info.eps, ninfo.eps)
        self.assertAlmostEqual(info.tiny, ninfo.tiny)
        self.assertAlmostEqual(info.resolution, ninfo.resolution)
        # numpy only exposes smallest_normal from 1.22.0 onwards.
        if StrictVersion(np.__version__) >= StrictVersion('1.22.0'):
            self.assertAlmostEqual(
                info.smallest_normal, ninfo.smallest_normal
            )

    # Complex dtypes report the limits of their real component; compare with
    # a tight tolerance because float64-range magnitudes are involved.
    complex_cases = [
        (paddle.complex64, np.complex64),
        (paddle.complex128, np.complex128),
    ]
    for paddle_dtype, np_dtype in complex_cases:
        info = paddle.finfo(paddle_dtype)
        ninfo = np.finfo(np_dtype)
        self.assertEqual(info.dtype, ninfo.dtype)
        self.assertEqual(info.bits, ninfo.bits)
        self.assertAlmostEqual(info.max, ninfo.max, places=16)
        self.assertAlmostEqual(info.min, ninfo.min, places=16)
        self.assertAlmostEqual(info.eps, ninfo.eps, places=16)
        self.assertAlmostEqual(info.tiny, ninfo.tiny, places=16)
        self.assertAlmostEqual(info.resolution, ninfo.resolution)
        if StrictVersion(np.__version__) >= StrictVersion('1.22.0'):
            self.assertAlmostEqual(
                info.smallest_normal, ninfo.smallest_normal, places=16
            )

    # float16: checked against hard-coded IEEE half-precision constants.
    info = paddle.finfo(paddle.float16)
    self.assertEqual(info.dtype, "float16")
    self.assertEqual(info.bits, 16)
    self.assertAlmostEqual(info.max, 65504.0)
    self.assertAlmostEqual(info.min, -65504.0)
    self.assertAlmostEqual(info.eps, 0.0009765625)
    self.assertAlmostEqual(info.tiny, 6.103515625e-05)
    self.assertAlmostEqual(info.resolution, 0.001)
    self.assertAlmostEqual(info.smallest_normal, 6.103515625e-05)

    # bfloat16: checked against hard-coded constants (numpy has no bfloat16).
    info = paddle.finfo(paddle.bfloat16)
    self.assertEqual(info.dtype, "bfloat16")
    self.assertEqual(info.bits, 16)
    self.assertAlmostEqual(info.max, 3.3895313892515355e38)
    self.assertAlmostEqual(info.min, -3.3895313892515355e38)
    self.assertAlmostEqual(info.eps, 0.0078125)
    self.assertAlmostEqual(info.tiny, 1.1754943508222875e-38)
    self.assertAlmostEqual(info.resolution, 0.01)
    self.assertAlmostEqual(info.smallest_normal, 1.1754943508222875e-38)
if __name__ == '__main__':
unittest.main()
......@@ -13,6 +13,7 @@
# limitations under the License.
from ..fluid.core import VarDesc
from ..fluid.core import finfo as core_finfo
from ..fluid.core import iinfo as core_iinfo
dtype = VarDesc.VarType
......@@ -69,3 +70,45 @@ def iinfo(dtype):
"""
return core_iinfo(dtype)
def finfo(dtype):
    """
    Return an object describing the numerical properties of a floating point
    ``paddle.dtype``, analogous to
    `numpy.finfo <https://numpy.org/doc/stable/reference/generated/numpy.finfo.html#numpy-finfo>`_.

    Args:
        dtype(paddle.dtype): One of ``paddle.float16``, ``paddle.float32``,
            ``paddle.float64``, ``paddle.bfloat16``, ``paddle.complex64``,
            and ``paddle.complex128``.

    Returns:
        An ``finfo`` object with the following 8 attributes:

        - min(double): The smallest representable number (typically `-max`).
        - max(double): The largest representable number.
        - eps(double): The smallest representable number such that `1.0 + eps ≠ 1.0`.
        - resolution(double): The approximate decimal resolution of this type, i.e., `10**-precision`.
        - smallest_normal(double): The smallest positive normal number.
        - tiny(double): The smallest positive normal number. Equivalent to smallest_normal.
        - bits(int): The number of bits occupied by the type.
        - dtype(str): The string name of the argument dtype.

    Examples:
        .. code-block:: python

            import paddle

            finfo_float32 = paddle.finfo(paddle.float32)
            print(finfo_float32.min)             # -3.40282e+38
            print(finfo_float32.max)             # 3.40282e+38
            print(finfo_float32.eps)             # 1.19209e-07
            print(finfo_float32.resolution)      # 1e-06
            print(finfo_float32.smallest_normal) # 1.17549e-38
            print(finfo_float32.tiny)            # 1.17549e-38
            print(finfo_float32.bits)            # 32
            print(finfo_float32.dtype)           # float32
    """
    # Thin wrapper over the C++ `finfo` struct bound in pybind.
    return core_finfo(dtype)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册