Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
40a0a46b
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
40a0a46b
编写于
9月 08, 2022
作者:
O
OccupyMars2025
提交者:
GitHub
9月 08, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[ Hackathon 3rd No.2 ] add paddle.iinfo (#45321)
上级
a642365e
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
145 addition
and
2 deletion
+145
-2
paddle/fluid/eager/grad_node_info.h
paddle/fluid/eager/grad_node_info.h
+1
-1
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+62
-0
python/paddle/__init__.py
python/paddle/__init__.py
+2
-0
python/paddle/fluid/tests/unittests/test_iinfo_and_finfo.py
python/paddle/fluid/tests/unittests/test_iinfo_and_finfo.py
+45
-0
python/paddle/framework/dtype.py
python/paddle/framework/dtype.py
+35
-1
未找到文件。
paddle/fluid/eager/grad_node_info.h
浏览文件 @
40a0a46b
...
@@ -302,7 +302,7 @@ class GradNodeBase {
...
@@ -302,7 +302,7 @@ class GradNodeBase {
// Gradient Hooks
// Gradient Hooks
// Customer may register a list of hooks which will be called in order during
// Customer may register a list of hooks which will be called in order during
// backward
// backward
// Each entry consists one pair of
// Each entry consists o
f o
ne pair of
// <hook_id, <out_rank, std::shared_ptr<TensorHook>>>
// <hook_id, <out_rank, std::shared_ptr<TensorHook>>>
std
::
map
<
int64_t
,
std
::
map
<
int64_t
,
std
::
tuple
<
std
::
tuple
<
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
40a0a46b
...
@@ -21,6 +21,7 @@ limitations under the License. */
...
@@ -21,6 +21,7 @@ limitations under the License. */
#include <map>
#include <map>
#include <memory>
#include <memory>
#include <mutex> // NOLINT // for call_once
#include <mutex> // NOLINT // for call_once
#include <sstream>
#include <string>
#include <string>
#include <tuple>
#include <tuple>
#include <type_traits>
#include <type_traits>
...
@@ -346,6 +347,52 @@ bool IsCompiledWithDIST() {
...
@@ -346,6 +347,52 @@ bool IsCompiledWithDIST() {
#endif
#endif
}
}
struct
iinfo
{
int64_t
min
,
max
;
int
bits
;
std
::
string
dtype
;
explicit
iinfo
(
const
framework
::
proto
::
VarType
::
Type
&
type
)
{
switch
(
type
)
{
case
framework
::
proto
::
VarType
::
INT16
:
min
=
std
::
numeric_limits
<
int16_t
>::
min
();
max
=
std
::
numeric_limits
<
int16_t
>::
max
();
bits
=
16
;
dtype
=
"int16"
;
break
;
case
framework
::
proto
::
VarType
::
INT32
:
min
=
std
::
numeric_limits
<
int32_t
>::
min
();
max
=
std
::
numeric_limits
<
int32_t
>::
max
();
bits
=
32
;
dtype
=
"int32"
;
break
;
case
framework
::
proto
::
VarType
::
INT64
:
min
=
std
::
numeric_limits
<
int64_t
>::
min
();
max
=
std
::
numeric_limits
<
int64_t
>::
max
();
bits
=
64
;
dtype
=
"int64"
;
break
;
case
framework
::
proto
::
VarType
::
INT8
:
min
=
std
::
numeric_limits
<
int8_t
>::
min
();
max
=
std
::
numeric_limits
<
int8_t
>::
max
();
bits
=
8
;
dtype
=
"int8"
;
break
;
case
framework
::
proto
::
VarType
::
UINT8
:
min
=
std
::
numeric_limits
<
uint8_t
>::
min
();
max
=
std
::
numeric_limits
<
uint8_t
>::
max
();
bits
=
8
;
dtype
=
"uint8"
;
break
;
default:
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"the argument of paddle.iinfo can only be paddle.int8, "
"paddle.int16, paddle.int32, paddle.int64, or paddle.uint8"
));
break
;
}
}
};
static
PyObject
*
GetPythonAttribute
(
PyObject
*
obj
,
const
char
*
attr_name
)
{
static
PyObject
*
GetPythonAttribute
(
PyObject
*
obj
,
const
char
*
attr_name
)
{
// NOTE(zjl): PyObject_GetAttrString would return nullptr when attr_name
// NOTE(zjl): PyObject_GetAttrString would return nullptr when attr_name
// is not inside obj, but it would also set the error flag of Python.
// is not inside obj, but it would also set the error flag of Python.
...
@@ -555,6 +602,21 @@ PYBIND11_MODULE(core_noavx, m) {
...
@@ -555,6 +602,21 @@ PYBIND11_MODULE(core_noavx, m) {
BindException
(
&
m
);
BindException
(
&
m
);
py
::
class_
<
iinfo
>
(
m
,
"iinfo"
)
.
def
(
py
::
init
<
const
framework
::
proto
::
VarType
::
Type
&>
())
.
def_readonly
(
"min"
,
&
iinfo
::
min
)
.
def_readonly
(
"max"
,
&
iinfo
::
max
)
.
def_readonly
(
"bits"
,
&
iinfo
::
bits
)
.
def_readonly
(
"dtype"
,
&
iinfo
::
dtype
)
.
def
(
"__repr__"
,
[](
const
iinfo
&
a
)
{
std
::
ostringstream
oss
;
oss
<<
"paddle.iinfo(min="
<<
a
.
min
;
oss
<<
", max="
<<
a
.
max
;
oss
<<
", bits="
<<
a
.
bits
;
oss
<<
", dtype="
<<
a
.
dtype
<<
")"
;
return
oss
.
str
();
});
m
.
def
(
"set_num_threads"
,
&
platform
::
SetNumThreads
);
m
.
def
(
"set_num_threads"
,
&
platform
::
SetNumThreads
);
m
.
def
(
"disable_signal_handler"
,
&
DisableSignalHandler
);
m
.
def
(
"disable_signal_handler"
,
&
DisableSignalHandler
);
...
...
python/paddle/__init__.py
浏览文件 @
40a0a46b
...
@@ -38,6 +38,7 @@ from .framework import in_dynamic_mode # noqa: F401
...
@@ -38,6 +38,7 @@ from .framework import in_dynamic_mode # noqa: F401
from
.fluid.dataset
import
*
# noqa: F401
from
.fluid.dataset
import
*
# noqa: F401
from
.fluid.lazy_init
import
LazyGuard
# noqa: F401
from
.fluid.lazy_init
import
LazyGuard
# noqa: F401
from
.framework.dtype
import
iinfo
# noqa: F401
from
.framework.dtype
import
dtype
as
dtype
# noqa: F401
from
.framework.dtype
import
dtype
as
dtype
# noqa: F401
from
.framework.dtype
import
uint8
# noqa: F401
from
.framework.dtype
import
uint8
# noqa: F401
from
.framework.dtype
import
int8
# noqa: F401
from
.framework.dtype
import
int8
# noqa: F401
...
@@ -386,6 +387,7 @@ if is_compiled_with_cinn():
...
@@ -386,6 +387,7 @@ if is_compiled_with_cinn():
disable_static
()
disable_static
()
__all__
=
[
# noqa
__all__
=
[
# noqa
'iinfo'
,
'dtype'
,
'dtype'
,
'uint8'
,
'uint8'
,
'int8'
,
'int8'
,
...
...
python/paddle/fluid/tests/unittests/test_iinfo_and_finfo.py
0 → 100644
浏览文件 @
40a0a46b
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
import
unittest
import
numpy
as
np
class
TestIInfoAndFInfoAPI
(
unittest
.
TestCase
):
def
test_invalid_input
(
self
):
for
dtype
in
[
paddle
.
float16
,
paddle
.
float32
,
paddle
.
float64
,
paddle
.
bfloat16
,
paddle
.
complex64
,
paddle
.
complex128
,
paddle
.
bool
]:
with
self
.
assertRaises
(
ValueError
):
_
=
paddle
.
iinfo
(
dtype
)
def
test_iinfo
(
self
):
for
paddle_dtype
,
np_dtype
in
[(
paddle
.
int64
,
np
.
int64
),
(
paddle
.
int32
,
np
.
int32
),
(
paddle
.
int16
,
np
.
int16
),
(
paddle
.
int8
,
np
.
int8
),
(
paddle
.
uint8
,
np
.
uint8
)]:
xinfo
=
paddle
.
iinfo
(
paddle_dtype
)
xninfo
=
np
.
iinfo
(
np_dtype
)
self
.
assertEqual
(
xinfo
.
bits
,
xninfo
.
bits
)
self
.
assertEqual
(
xinfo
.
max
,
xninfo
.
max
)
self
.
assertEqual
(
xinfo
.
min
,
xninfo
.
min
)
self
.
assertEqual
(
xinfo
.
dtype
,
xninfo
.
dtype
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/framework/dtype.py
浏览文件 @
40a0a46b
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# limitations under the License.
# limitations under the License.
from
..fluid.core
import
VarDesc
from
..fluid.core
import
VarDesc
from
..fluid.core
import
iinfo
as
core_iinfo
dtype
=
VarDesc
.
VarType
dtype
=
VarDesc
.
VarType
dtype
.
__qualname__
=
"dtype"
dtype
.
__qualname__
=
"dtype"
...
@@ -34,4 +35,37 @@ complex128 = VarDesc.VarType.COMPLEX128
...
@@ -34,4 +35,37 @@ complex128 = VarDesc.VarType.COMPLEX128
bool
=
VarDesc
.
VarType
.
BOOL
bool
=
VarDesc
.
VarType
.
BOOL
__all__
=
[]
def
iinfo
(
dtype
):
"""
paddle.iinfo is a function that returns an object that represents the numerical properties of
an integer paddle.dtype.
This is similar to `numpy.iinfo <https://numpy.org/doc/stable/reference/generated/numpy.iinfo.html#numpy-iinfo>`_.
Args:
dtype(paddle.dtype): One of paddle.uint8, paddle.int8, paddle.int16, paddle.int32, and paddle.int64.
Returns:
An iinfo object, which has the following 4 attributes:
- min: int, The smallest representable integer number.
- max: int, The largest representable integer number.
- bits: int, The number of bits occupied by the type.
- dtype: str, The string name of the argument dtype.
Examples:
.. code-block:: python
import paddle
iinfo_uint8 = paddle.iinfo(paddle.uint8)
print(iinfo_uint8)
# paddle.iinfo(min=0, max=255, bits=8, dtype=uint8)
print(iinfo_uint8.min) # 0
print(iinfo_uint8.max) # 255
print(iinfo_uint8.bits) # 8
print(iinfo_uint8.dtype) # uint8
"""
return
core_iinfo
(
dtype
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录