Commit 195736cf (unverified)
Authored on Dec 21, 2022 by niuliling123; committed by GitHub on Dec 21, 2022
Add FLAGS_low_precision_op_list to get amp list of current module (#48843)
Parent: b51a752f
Showing 7 changed files with 104 additions and 0 deletions (+104 -0)
paddle/fluid/eager/amp_utils.h                                     +10  -0
paddle/fluid/imperative/amp_auto_cast.cc                           +11  -0
paddle/fluid/imperative/amp_auto_cast.h                             +7  -0
paddle/fluid/pybind/pybind.cc                                       +4  -0
paddle/phi/core/flags.cc                                           +14  -0
python/paddle/fluid/dygraph/amp/auto_cast.py                       +15  -0
python/paddle/fluid/tests/unittests/test_low_precision_list.py     +43  -0
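Taken together, the commit lets a user flip FLAGS_low_precision_op_list, run dygraph code under AMP, and then print which ops were actually routed through the low-precision path. Below is a minimal end-to-end sketch of that workflow based on the APIs touched in this commit (paddle.set_flags, paddle.amp.auto_cast, and the new low_precision_op_list helper); the layer and tensor shapes are illustrative, and fp16 kernels must be available on the device for anything to be recorded.

import paddle

# Turn on op recording before any AMP region runs (flag added by this commit).
paddle.set_flags({'FLAGS_low_precision_op_list': 1})

conv2d = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)  # illustrative layer
data = paddle.rand([10, 3, 32, 32])                   # illustrative input

# Ops executed inside auto_cast are counted per op type.
with paddle.amp.auto_cast():
    out = conv2d(data)

# Print the collected "op name | op count" report.
paddle.fluid.dygraph.amp.auto_cast.low_precision_op_list()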
paddle/fluid/eager/amp_utils.h

@@ -100,6 +100,7 @@ inline paddle::experimental::DataType GetAmpDestDtype(
  if (paddle::imperative::AmpOperators::Instance()
          .GetMutableAllowOps()
          ->count(op_name)) {
    paddle::imperative::AmpOperators::Instance().AddToAmpOpList(op_name);
    return paddle::experimental::DataType::FLOAT16;
  } else if (paddle::imperative::AmpOperators::Instance()
                 .GetMutableBlockOps()

@@ -117,6 +118,8 @@ inline paddle::experimental::DataType GetAmpDestDtype(
                 .GetMutableUnsupportedFp16Ops()
                 ->count(op_name)) {
    dst_type = paddle::experimental::DataType::FLOAT32;
  } else {
    paddle::imperative::AmpOperators::Instance().AddToAmpOpList(op_name);
  }
  return dst_type;
}

@@ -129,6 +132,8 @@ inline paddle::experimental::DataType GetAmpDestDtype(
                 .GetMutableBlockOps()
                 ->count(op_name)) {
    dst_type = paddle::experimental::DataType::FLOAT32;
  } else {
    paddle::imperative::AmpOperators::Instance().AddToAmpOpList(op_name);
  }
  return dst_type;
}

@@ -137,6 +142,7 @@ inline paddle::experimental::DataType GetAmpDestDtype(
  if (paddle::imperative::AmpOperators::Instance()
          .GetMutableAllowOps()
          ->count(op_name)) {
    paddle::imperative::AmpOperators::Instance().AddToAmpOpList(op_name);
    return paddle::experimental::DataType::BFLOAT16;
  } else if (paddle::imperative::AmpOperators::Instance()
                 .GetMutableBlockOps()

@@ -152,6 +158,8 @@ inline paddle::experimental::DataType GetAmpDestDtype(
                 .GetMutableUnsupportedBf16Ops()
                 ->count(op_name)) {
    dst_type = paddle::experimental::DataType::FLOAT32;
  } else {
    paddle::imperative::AmpOperators::Instance().AddToAmpOpList(op_name);
  }
  return dst_type;
}

@@ -164,6 +172,8 @@ inline paddle::experimental::DataType GetAmpDestDtype(
                 .GetMutableBlockOps()
                 ->count(op_name)) {
    dst_type = paddle::experimental::DataType::FLOAT32;
  } else {
    paddle::imperative::AmpOperators::Instance().AddToAmpOpList(op_name);
  }
  return dst_type;
}
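All six hunks instrument the same dtype-selection routine, GetAmpDestDtype: ops on the allow list get the low-precision dtype and are recorded, ops on the block or unsupported lists fall back to float32, and the remaining ops keep the current destination dtype and are also recorded. A rough Python sketch of that decision order follows; the helper names (in_allow_list, in_block_list, in_unsupported_list, record_op) are placeholders for the C++ calls above, not real Paddle APIs, and the real function repeats this pattern per AMP level.

def get_amp_dest_dtype_sketch(op_name, amp_dtype, dst_type,
                              in_allow_list, in_block_list,
                              in_unsupported_list, record_op):
    # Allow-listed ops run at the low-precision dtype and are counted.
    if in_allow_list(op_name):
        record_op(op_name)   # counted only if FLAGS_low_precision_op_list is set
        return amp_dtype     # FLOAT16 or BFLOAT16
    # Block-listed or unsupported ops are forced back to full precision.
    if in_block_list(op_name) or in_unsupported_list(op_name):
        dst_type = 'float32'
    else:
        record_op(op_name)   # op stays on the current (possibly low-precision) path
    return dst_type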
paddle/fluid/imperative/amp_auto_cast.cc

@@ -22,6 +22,7 @@
 #include "paddle/fluid/imperative/type_defs.h"
 #include "paddle/fluid/imperative/var_helper.h"

DECLARE_bool(low_precision_op_list);

namespace paddle {
namespace imperative {

@@ -193,6 +194,16 @@ AmpOperators::GetMutableUnsupportedBf16Ops() {
  return unsupported_bf16_ops_;
}

void AmpOperators::AddToAmpOpList(const std::string& op_name) {
  if (FLAGS_low_precision_op_list) {
    current_amp_ops_[op_name] += 1;
  }
}

std::map<const std::string, int> AmpOperators::GetAmpOpList() {
  return current_amp_ops_;
}

std::ostream& operator<<(std::ostream& os, AmpOperators& ops) {
  os << "allow ops: ";
  auto allow_ops = ops.GetMutableAllowOps();
paddle/fluid/imperative/amp_auto_cast.h

@@ -60,6 +60,10 @@ class AmpOperators {
  std::shared_ptr<std::unordered_set<std::string>>
  GetMutableUnsupportedBf16Ops();

  void AddToAmpOpList(const std::string& op_name);

  std::map<const std::string, int> GetAmpOpList();

 private:
  AmpOperators();  // forbid calling default constructor

@@ -76,6 +80,9 @@ class AmpOperators {
  // The set of ops that has no bf16 CUDA kennel.
  std::shared_ptr<std::unordered_set<std::string>> unsupported_bf16_ops_;

  // The amp op list of current module.
  std::map<const std::string, int> current_amp_ops_;
};

std::ostream& operator<<(std::ostream& os, AmpOperators& ops);
paddle/fluid/pybind/pybind.cc

@@ -2545,6 +2545,10 @@ All parameter, weight, gradient are variables in Paddle.
  m.def("update_autotune_status",
        [] { return phi::autotune::AutoTuneStatus::Instance().Update(); });

  m.def("get_low_precision_op_list", [] {
    return paddle::imperative::AmpOperators::Instance().GetAmpOpList();
  });

  m.def("autotune_status", [] {
    py::dict res;
    phi::autotune::AutoTuneCache::Instance().UpdateStatus();
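The new binding exposes the C++ counter map to Python, where it behaves like a dict of op name to count. A hedged sketch of reading it directly (only meaningful after FLAGS_low_precision_op_list has been enabled and some AMP code has run):

import paddle

op_list = paddle.fluid.core.get_low_precision_op_list()
for name in op_list:
    # Each entry maps an op type to how many times it took the low-precision path.
    print(name, op_list[name])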
paddle/phi/core/flags.cc

@@ -52,6 +52,20 @@ PADDLE_DEFINE_EXPORTED_int32(paddle_num_threads,
                             1,
                             "Number of threads for each paddle instance.");

/**
 * Low Precision Op related FLAG
 * Name: FLAGS_low_precision_op_list
 * Since Version: 0.13.0
 * Value Range: bool, default=false
 * Example:
 * Note: Used to debug. Get the low precision op list of current module.
 */
PADDLE_DEFINE_EXPORTED_bool(low_precision_op_list,
                            false,
                            "Checking whether get the low precision op list of "
                            "current module. It will be "
                            "rerun the low precision list after module.");

/**
 * Operator related FLAG
 * Name: FLAGS_check_nan_inf
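Because the flag is exported, it can be enabled without recompiling. A short sketch of the two usual mechanisms for exported Paddle flags, neither specific to this commit:

import paddle

# Set the flag from Python before running any AMP region
# (this is what the new unit test below does).
paddle.set_flags({'FLAGS_low_precision_op_list': 1})

# Read it back to confirm (paddle.get_flags is the standard Paddle 2.x API).
print(paddle.get_flags(['FLAGS_low_precision_op_list']))

# Alternatively (assumed, as for other exported flags), set the environment
# variable before the process starts:
#   export FLAGS_low_precision_op_list=1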
python/paddle/fluid/dygraph/amp/auto_cast.py

@@ -110,6 +110,21 @@ PURE_BF16_BLACK_LIST = set()
_g_amp_state_ = None


def low_precision_op_list():
    op_list = paddle.fluid.core.get_low_precision_op_list()
    op_count = 0
    print('<---------------- low precision op list ------------------->')
    print('<---- op name ------|------- op count---------------------->')
    for x in op_list:
        print(' %-18s| %4d' % (x, op_list[x]))
        op_count += 1
    print(
        '<------------- low precision op num:{:5d} ----------------->'.format(
            op_count
        )
    )


def amp_state():
    global _g_amp_state_
    return _g_amp_state_
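Note that the diff adds no code that resets current_amp_ops_, so the counts appear to accumulate for the lifetime of the process: calling the helper after several auto_cast regions should report combined totals rather than those of the last region only. A small sketch, with an illustrative layer and shapes:

import paddle

paddle.set_flags({'FLAGS_low_precision_op_list': 1})
linear = paddle.nn.Linear(4, 4)

for _ in range(3):  # three separate AMP regions
    with paddle.amp.auto_cast():
        out = linear(paddle.rand([2, 4]))

# The printed counts cover all three iterations, not just the last block.
paddle.fluid.dygraph.amp.auto_cast.low_precision_op_list()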
python/paddle/fluid/tests/unittests/test_low_precision_list.py (new file, mode 0 → 100644)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import paddle


class TestAMPList(unittest.TestCase):
    def test_main(self):
        conv2d = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
        data = paddle.rand([10, 3, 32, 32])
        paddle.set_flags({'FLAGS_low_precision_op_list': 1})
        a = paddle.rand([2, 3])
        b = paddle.rand([2, 3])

        # amp list conv2d, cast
        with paddle.amp.auto_cast():
            conv = conv2d(data)
            c = a + b
        paddle.fluid.dygraph.amp.auto_cast.low_precision_op_list()
        op_list = paddle.fluid.core.get_low_precision_op_list()
        print(conv.dtype)
        if conv.dtype == paddle.float16:
            self.assertTrue('elementwise_add' in op_list)
            self.assertTrue('conv2d' in op_list)
            self.assertTrue(2 == len(op_list))
        else:
            self.assertTrue(0 == len(op_list))


if __name__ == "__main__":
    unittest.main()