Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
395520f1
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
395520f1
编写于
1月 11, 2023
作者:
N
niuliling123
提交者:
GitHub
1月 11, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update the style of print for low precision op list (#49648)
上级
18a7e13f
变更
5
隐藏空白更改
内联
并排
Showing
5 changed files
with
76 additions
and
26 deletions
+76
-26
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+11
-1
paddle/phi/core/kernel_factory.cc
paddle/phi/core/kernel_factory.cc
+15
-8
paddle/phi/core/kernel_factory.h
paddle/phi/core/kernel_factory.h
+15
-2
python/paddle/amp/auto_cast.py
python/paddle/amp/auto_cast.py
+17
-9
python/paddle/fluid/tests/unittests/test_low_precision_list.py
...n/paddle/fluid/tests/unittests/test_low_precision_list.py
+18
-6
未找到文件。
paddle/fluid/pybind/pybind.cc
浏览文件 @
395520f1
...
...
@@ -2595,7 +2595,17 @@ All parameter, weight, gradient are variables in Paddle.
[]
{
return
phi
::
autotune
::
AutoTuneStatus
::
Instance
().
Update
();
});
m
.
def
(
"get_low_precision_op_list"
,
[]
{
return
phi
::
KernelFactory
::
Instance
().
GetLowPrecisionKernelList
();
py
::
dict
op_list
;
auto
list_op
=
phi
::
KernelFactory
::
Instance
().
GetLowPrecisionKernelList
();
for
(
auto
iter
=
list_op
.
begin
();
iter
!=
list_op
.
end
();
iter
++
)
{
auto
op_name
=
(
iter
->
first
).
c_str
();
auto
counts
=
iter
->
second
;
op_list
[
op_name
]
=
std
::
to_string
(
counts
.
fp16_called_
)
+
","
+
std
::
to_string
(
counts
.
bf16_called_
)
+
","
+
std
::
to_string
(
counts
.
fp32_called_
)
+
","
+
std
::
to_string
(
counts
.
other_called_
);
}
return
op_list
;
});
m
.
def
(
"autotune_status"
,
[]
{
...
...
paddle/phi/core/kernel_factory.cc
浏览文件 @
395520f1
...
...
@@ -115,18 +115,25 @@ void KernelFactory::AddToLowPrecisionKernelList(
if
(
op_name
.
find
(
"_grad"
)
!=
std
::
string
::
npos
)
{
return
;
// only record forward api
}
bool
is_low_precision
=
(
kernel_key_type
==
paddle
::
experimental
::
DataType
::
FLOAT16
||
kernel_key_type
==
paddle
::
experimental
::
DataType
::
BFLOAT16
);
bool
need_record
=
FLAGS_low_precision_op_list
==
1
?
is_low_precision
:
true
;
if
(
need_record
)
{
low_precision_kernels_
[
op_name
]
+=
1
;
if
(
low_precision_kernels_
.
find
(
op_name
)
==
low_precision_kernels_
.
end
())
{
auto
count
=
OpCount
();
low_precision_kernels_
[
op_name
]
=
count
;
}
if
(
kernel_key_type
==
paddle
::
experimental
::
DataType
::
FLOAT16
)
{
low_precision_kernels_
[
op_name
].
fp16_called_
+=
1
;
}
else
if
(
kernel_key_type
==
paddle
::
experimental
::
DataType
::
BFLOAT16
)
{
low_precision_kernels_
[
op_name
].
bf16_called_
+=
1
;
}
else
if
(
kernel_key_type
==
paddle
::
experimental
::
DataType
::
FLOAT32
)
{
low_precision_kernels_
[
op_name
].
fp32_called_
+=
1
;
}
else
{
low_precision_kernels_
[
op_name
].
other_called_
+=
1
;
}
}
}
std
::
map
<
const
std
::
string
,
int
>
KernelFactory
::
GetLowPrecisionKernelList
()
{
// Returns the recorded per-op call counters.
// The return type is a std::map by value (not a reference), so the caller
// receives a copy of low_precision_kernels_, keyed by op name with an OpCount
// record as the mapped value.
std
::
map
<
const
std
::
string
,
OpCount
>
KernelFactory
::
GetLowPrecisionKernelList
()
{
return
low_precision_kernels_
;
}
...
...
paddle/phi/core/kernel_factory.h
浏览文件 @
395520f1
...
...
@@ -34,6 +34,19 @@ namespace phi {
using
DataType
=
paddle
::
experimental
::
DataType
;
/**
 * Call counters for a single op, split into four buckets: FP16, BF16, FP32
 * and "other".
 *
 * Every counter starts at zero. The zero state is expressed with default
 * member initializers, so no constructor path can leave a counter
 * uninitialized; `OpCount()` and `OpCount c;` both yield an all-zero record.
 */
struct OpCount {
  OpCount() = default;

  int fp16_called_{0};   // number of calls counted in the FP16 bucket
  int bf16_called_{0};   // number of calls counted in the BF16 bucket
  int fp32_called_{0};   // number of calls counted in the FP32 bucket
  int other_called_{0};  // number of calls counted in the "other" bucket
};
/**
* [ Naming considerations ]
*
...
...
@@ -309,7 +322,7 @@ class KernelFactory {
const
std
::
string
&
name
,
const
paddle
::
experimental
::
DataType
&
kernel_key_type
);
std
::
map
<
const
std
::
string
,
i
nt
>
GetLowPrecisionKernelList
();
std
::
map
<
const
std
::
string
,
OpCou
nt
>
GetLowPrecisionKernelList
();
private:
KernelFactory
()
=
default
;
...
...
@@ -317,7 +330,7 @@ class KernelFactory {
KernelNameMap
kernels_
;
// Get the low precision kernel list of current module.
std
::
map
<
const
std
::
string
,
i
nt
>
low_precision_kernels_
;
std
::
map
<
const
std
::
string
,
OpCou
nt
>
low_precision_kernels_
;
};
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
KernelKey
&
kernel_key
)
{
...
...
python/paddle/amp/auto_cast.py
浏览文件 @
395520f1
...
...
@@ -97,21 +97,29 @@ _g_amp_state_ = None
def
low_precision_op_list
():
if
os
.
getenv
(
"FLAGS_low_precision_op_list"
)
is
not
None
:
level
=
int
(
os
.
getenv
(
"FLAGS_low_precision_op_list"
))
if
level
==
0
:
return
if
level
==
1
:
print
(
'<{:-^60}>'
.
format
(
" low precision op list "
))
else
:
print
(
'<{:-^60}>'
.
format
(
" op list "
))
print
(
'<{:-^120}>'
.
format
(
" op list "
))
op_list
=
paddle
.
fluid
.
core
.
get_low_precision_op_list
()
op_count
=
0
print
(
'<{:-^40}'
.
format
(
" op_name "
),
'|'
,
'{:-^17}>'
.
format
(
" op count "
)
'<{:-^40}'
.
format
(
" Op Name "
),
'|'
,
'{:-^17}'
.
format
(
"FP16 Calls"
),
'|'
,
'{:-^17}'
.
format
(
"BF16 Calls"
),
'|'
,
'{:-^17}'
.
format
(
'FP32 Calls'
),
'|'
,
'{:-^17}>'
.
format
(
'Other Calls'
),
)
for
x
in
op_list
:
print
(
' %-40s| %-15d'
%
(
x
,
op_list
[
x
]))
# fp16, bf16, fp32, other
called
=
op_list
[
x
].
split
(
","
)
print
(
' %-40s| %-17s| %-17s| %-17s| %-17s'
%
(
x
,
called
[
0
],
called
[
1
],
called
[
2
],
called
[
3
])
)
op_count
+=
1
print
(
'<{:-^
6
0}>'
.
format
(
" op count: "
+
str
(
op_count
)
+
" "
))
print
(
'<{:-^
12
0}>'
.
format
(
" op count: "
+
str
(
op_count
)
+
" "
))
def
amp_state
():
...
...
python/paddle/fluid/tests/unittests/test_low_precision_list.py
浏览文件 @
395520f1
...
...
@@ -30,12 +30,24 @@ class TestAMPList(unittest.TestCase):
c
=
a
+
b
paddle
.
amp
.
low_precision_op_list
()
op_list
=
paddle
.
fluid
.
core
.
get_low_precision_op_list
()
if
conv
.
dtype
==
paddle
.
float16
:
self
.
assertTrue
(
'elementwise_add'
in
op_list
)
self
.
assertTrue
(
'conv2d'
in
op_list
)
self
.
assertTrue
(
2
==
len
(
op_list
))
else
:
self
.
assertTrue
(
0
==
len
(
op_list
))
self
.
assertTrue
(
'elementwise_add'
in
op_list
)
self
.
assertTrue
(
'conv2d'
in
op_list
)
conv2d_called
=
op_list
[
'conv2d'
].
split
(
','
)
add_called
=
op_list
[
'elementwise_add'
].
split
(
','
)
add_num
=
0
conv_num
=
0
for
i
in
range
(
4
):
add_num
+=
int
(
add_called
[
i
])
conv_num
+=
int
(
add_called
[
i
])
self
.
assertTrue
(
conv_num
==
1
)
self
.
assertTrue
(
add_num
==
1
)
if
conv
.
dtype
==
"float16"
:
self
.
assertTrue
(
int
(
conv2d_called
[
0
])
==
1
)
self
.
assertTrue
(
int
(
add_called
[
0
])
==
1
)
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录