PaddlePaddle / Paddle

Commit 2bf61284 (unverified)
Authored on May 08, 2023 by Zhang Ting; committed via GitHub on May 08, 2023

[AMP] fix static promote (#53439)

Parent: 3fd2e765
Showing 6 changed files with 87 additions and 52 deletions (+87, -52).
python/paddle/static/amp/fp16_lists.py              +34  -34
python/paddle/static/amp/fp16_utils.py              +16   -4
test/amp/test_amp_list.py                            +4   -3
test/amp/test_amp_promote.py                         +0   -1
test/amp/test_model_cast_to_bf16.py                  +5   -3
test/contrib/test_image_classification_fp16.py      +28   -7
python/paddle/static/amp/fp16_lists.py
@@ -23,11 +23,15 @@ _logger = get_logger(
 )

 # lookup_table fp16 is slower than fp32, though fp16 is supported.
-_extra_unsupported_list = {
+_extra_black_list = {
     'lookup_table',
     'lookup_table_v2',
     'scatter',
     'scatter_grad',
+    'linear_interp_v2',
+    'nearest_interp_v2',
+    'bilinear_interp_v2',
+    'bicubic_interp_v2',
+    'trilinear_interp_v2',
 }
@@ -118,8 +122,7 @@ def _get_sys_unsupported_list(dtype):
 def _get_unsupported_list(dtype):
     # The set of ops that don't support fp16 calculation
     _, _sys_unsupported_list = _get_sys_unsupported_list(dtype)
-    unsupported_list = _extra_unsupported_list | _sys_unsupported_list
-    return unsupported_list
+    return _sys_unsupported_list


 # The three sets listed below are changed dynamiclly. They don't contain all
@@ -145,6 +148,32 @@ def _get_white_list(dtype):
     return white_list_for_dtype


+# The set of ops that support fp16 calculation and are considered numerically-
+# dangerous and whose effects may also be observed in downstream ops.
+black_list = {
+    'exp',
+    'square',
+    'log',
+    'mean',
+    'sum',
+    'cos_sim',
+    'softmax',
+    'softmax_with_cross_entropy',
+    'sigmoid_cross_entropy_with_logits',
+    'c_softmax_with_cross_entropy',
+    'cross_entropy',
+    'cross_entropy2',
+    # default fp32 can avoid return inf when the sum value large than 65504
+    'reduce_sum',
+}
+
+
+def _get_black_list():
+    _black_list = copy.copy(black_list)
+    _black_list = _black_list | _extra_black_list
+    return _black_list
+
+
 class AutoMixedPrecisionLists:
     """
     AutoMixedPrecisionLists is a class for black/white list. It can update
@@ -170,7 +199,7 @@ class AutoMixedPrecisionLists:
         self._custom_white_list = custom_white_list
         self._custom_black_list = custom_black_list
         self.white_list = copy.copy(_get_white_list(self.amp_dtype))
-        self.black_list = copy.copy(black_list)
+        self.black_list = copy.copy(_get_black_list())
         self.gray_list = copy.copy(gray_list)
         self.unsupported_list = copy.copy(_get_unsupported_list(self.amp_dtype))
         self.black_varnames = copy.copy(custom_black_varnames)
@@ -196,8 +225,6 @@ class AutoMixedPrecisionLists:
         elif op_name in self.gray_list:
             self.gray_list.remove(op_name)
             self.white_list.add(op_name)
-            if op_name in _extra_unsupported_list:
-                self.unsupported_list.remove(op_name)
         if self._custom_black_list:
             for op_name in self._custom_black_list:
                 if op_name in self.white_list:
@@ -217,33 +244,6 @@ class AutoMixedPrecisionLists:
         )

-# The set of ops that support fp16 calculation and are considered numerically-
-# dangerous and whose effects may also be observed in downstream ops.
-black_list = {
-    'exp',
-    'square',
-    'log',
-    'mean',
-    'sum',
-    'cos_sim',
-    'softmax',
-    'softmax_with_cross_entropy',
-    'sigmoid_cross_entropy_with_logits',
-    'c_softmax_with_cross_entropy',
-    'cross_entropy',
-    'cross_entropy2',
-    # fp16 is slower than fp32, though fp16 is supported.
-    'lookup_table',
-    'lookup_table_v2',
-    'linear_interp_v2',
-    'nearest_interp_v2',
-    'bilinear_interp_v2',
-    'bicubic_interp_v2',
-    'trilinear_interp_v2',
-    # default fp32 can avoid return inf when the sum value large than 65504
-    'reduce_sum',
-}
-
 # This set contains two types of ops. All ops supported fp16 calculation. One
 # of two types is considered numerically-safe, but may be made unsafe by an
 # upstream blacklist op. Another type do not have numerically-significant
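Note on the fp16_lists.py changes above: ops that merely run slower in fp16 (lookup_table*, the *_interp_v2 ops) are no longer forced into the unsupported list; they now live in _extra_black_list and are folded into the black list through _get_black_list(), while _get_unsupported_list() returns only the system-derived set. A minimal sketch of the resulting behaviour, assuming a Paddle build whose device has float16 kernels for these ops:

    import paddle

    # AutoMixedPrecisionLists now builds black_list from
    # _get_black_list() == black_list | _extra_black_list.
    amp_lists = paddle.static.amp.AutoMixedPrecisionLists()
    assert 'lookup_table' in amp_lists.black_list
    # unsupported_list is purely system-derived after this change, so
    # lookup_table is no longer forced into it.
    assert 'lookup_table' not in amp_lists.unsupported_list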
python/paddle/static/amp/fp16_utils.py
@@ -425,15 +425,22 @@ def set_var_dst_dtype(
 def set_param_dtype(program, dtype, amp_lists, use_fp16_guard, level):
-    if level == "O1":
-        return
+    keep_fp32_var_names = set()
+    if level == "O1":
+        return keep_fp32_var_names
     all_parameters = []
     for block in program.blocks:
         all_parameters.extend(block.all_parameters())
         ops = block.ops
         for op in ops:
-            if op_need_keep_fp32(op, amp_lists, use_fp16_guard):
+            # Currently, lookup_table is in black_list and unsupport_list, it's weight will be
+            # set to fp32 in setp 1 of cast_model_tp_fp16. But the weight may be used as matmul's
+            # input in transformer, so the weight is also in to_fp16_var_names.
+            # TODO(zhangting2020): consider fix auto_parallel_fp16 and remove lookup_table
+            # from black_list and unsupport_list.
+            if op in ['lookup_table', 'lookup_table_v2']:
+                continue
+            if _need_keep_fp32(op, amp_lists.unsupported_list, use_fp16_guard):
                 for in_name in op.input_names:
                     keep_fp32_var_names = keep_fp32_var_names.union(
                         op.input(in_name)
@@ -451,6 +458,7 @@ def set_param_dtype(program, dtype, amp_lists, use_fp16_guard, level):
             if param.name not in keep_fp32_var_names:
                 _logger.debug(f"-- set param {param.name} to {dtype} --.")
                 param.desc.set_dtype(dtype)
+    return keep_fp32_var_names


 def op_need_keep_fp32(op, amp_lists, use_fp16_guard):
@@ -607,15 +615,17 @@ def cast_model_to_fp16(
     keep_fp32_ops = set()
     keep_fp16_ops = set()
     to_fp16_var_names = set()
+    keep_fp32_var_names = set()
     # step 1: set params dtype.
-    set_param_dtype(
+    fp32_var_names = set_param_dtype(
         program,
         dtype=dest_type,
         amp_lists=amp_lists,
         use_fp16_guard=use_fp16_guard,
         level=level,
     )
+    keep_fp32_var_names = keep_fp32_var_names.union(fp32_var_names)

     def need_process(op):
         need_process = True
@@ -719,6 +729,8 @@ def cast_model_to_fp16(
             idx += num_cast_ops + 1

     _logger.debug("---- after cast model to fp16 ----")
     _logger.debug(program)

+    to_fp16_var_names.difference_update(keep_fp32_var_names)
+
     return to_fp16_var_names
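Note on the fp16_utils.py changes above: set_param_dtype() now returns the names of the parameters it kept in fp32, and cast_model_to_fp16() removes those names from the set it reports as cast to fp16. A plain-Python sketch of that bookkeeping (the variable values are made up for illustration; only the set operations mirror the patch):

    keep_fp32_var_names = set()

    # set_param_dtype() reports the parameters it left in fp32.
    fp32_var_names = {"embedding_0.w_0"}      # hypothetical parameter name
    keep_fp32_var_names = keep_fp32_var_names.union(fp32_var_names)

    # cast_model_to_fp16() collected these candidates earlier ...
    to_fp16_var_names = {"fc_0.w_0", "embedding_0.w_0"}   # hypothetical names

    # ... and now drops anything kept in fp32 before returning.
    to_fp16_var_names.difference_update(keep_fp32_var_names)
    assert to_fp16_var_names == {"fc_0.w_0"}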
test/amp/test_amp_list.py
@@ -50,9 +50,10 @@ class TestAMPList(unittest.TestCase):
-        self.check_if_op_not_in_list(
-            self.custom_white_list, amp_list.black_list
-        )
         self.check_if_op_not_in_list(
             self.custom_white_list, amp_list.unsupported_list
         )
+        if paddle.amp.is_float16_supported():
+            self.check_if_op_not_in_list(
+                self.custom_white_list, amp_list.black_list
+            )

     def test_eager(self):
         if not paddle.amp.is_float16_supported():
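Note on the test_amp_list.py change above: the black-list assertion is now guarded by paddle.amp.is_float16_supported(), so it only runs on devices that actually have float16 kernels. A sketch of what the updated test checks (the concrete op names below are illustrative, not necessarily the ones the test class uses):

    import paddle

    custom_white_list = {'lookup_table', 'lookup_table_v2'}
    amp_list = paddle.static.amp.AutoMixedPrecisionLists(
        custom_white_list=custom_white_list
    )
    for op_name in custom_white_list:
        # Custom-whitelisted ops must not remain in the unsupported list ...
        assert op_name not in amp_list.unsupported_list
        # ... and, where fp16 is actually available, not in the black list either.
        if paddle.amp.is_float16_supported():
            assert op_name not in amp_list.black_list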
test/amp/test_amp_promote.py
@@ -48,7 +48,6 @@ class TestAMPPromote(AmpTestBase):
         max_iters = 2
         x_fp32 = np.random.random(size=[1, 1, 6, 6]).astype("float32")
-        print(main_program)
         losses_o1 = self.run_program(
             main_program,
             startup_program,
test/amp/test_model_cast_to_bf16.py
@@ -265,18 +265,20 @@ class TestProgramBF16(AmpTestBase):
         amp.debugging.collect_operator_stats(main_program)
         op_stats_list = amp.debugging._get_op_stats_list(main_program)
+        expected_fp32_calls = {"lookup_table_v2": 1}
         expected_bf16_calls = {
             "matmul_v2": 1,
             "elementwise_add": 1,
             "dropout": 1,
             "lookup_table_v2": 0,
-            "squared_l2_norm": 2,
-            "adamw": 2,
+            "squared_l2_norm": 3,
+            "adamw": 3,
         }
         self._check_optimizer(
             main_program,
             expected_bf16_calls["matmul_v2"]
-            + expected_bf16_calls["elementwise_add"],
+            + expected_bf16_calls["elementwise_add"]
+            + expected_fp32_calls["lookup_table_v2"],
         )
         self._check_op_calls(op_stats_list[0], expected_bf16_calls)
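Note on the test_model_cast_to_bf16.py change above: with the fix, the lookup_table_v2 embedding is kept in fp32 (expected_fp32_calls), so the optimizer now updates three parameters instead of two; that is why the squared_l2_norm and adamw counts move from 2 to 3 and why the _check_optimizer argument gains the fp32 term. The arithmetic, spelled out with the values from the dicts above:

    expected_fp32_calls = {"lookup_table_v2": 1}
    expected_bf16_calls = {"matmul_v2": 1, "elementwise_add": 1}

    expected_optimizer_params = (
        expected_bf16_calls["matmul_v2"]
        + expected_bf16_calls["elementwise_add"]
        + expected_fp32_calls["lookup_table_v2"]  # the embedding now trained in fp32
    )
    assert expected_optimizer_params == 3   # hence squared_l2_norm / adamw == 3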
test/contrib/test_image_classification_fp16.py
@@ -318,7 +318,10 @@ class TestImageClassification(unittest.TestCase):
             copy.copy(paddle.static.amp.fp16_lists.white_list)
             | paddle.static.amp.fp16_lists._only_supported_fp16_list
         )
-        black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
+        black_list = copy.copy(
+            paddle.static.amp.fp16_lists.black_list
+            | paddle.static.amp.fp16_lists._extra_black_list
+        )
         gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
         amp_lists = paddle.static.amp.AutoMixedPrecisionLists()
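The same one-line replacement repeats in the six hunks that follow: wherever this test previously compared against the module-level black_list alone, it now compares against the union with _extra_black_list, matching the list reorganization in fp16_lists.py. A minimal sketch of the reference set the assertions build (uses only the module attributes visible in this diff):

    import copy

    import paddle

    expected_black_list = copy.copy(
        paddle.static.amp.fp16_lists.black_list
        | paddle.static.amp.fp16_lists._extra_black_list
    )
    # Now also contains 'lookup_table', 'lookup_table_v2' and the *_interp_v2
    # ops that this commit moved out of the old module-level black_list.
    print(sorted(expected_black_list))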
@@ -331,7 +334,10 @@ class TestImageClassification(unittest.TestCase):
             copy.copy(paddle.static.amp.fp16_lists.white_list)
             | paddle.static.amp.fp16_lists._only_supported_fp16_list
         )
-        black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
+        black_list = copy.copy(
+            paddle.static.amp.fp16_lists.black_list
+            | paddle.static.amp.fp16_lists._extra_black_list
+        )
         gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)

         # 1. w={'exp}, b=None
@@ -348,7 +354,10 @@ class TestImageClassification(unittest.TestCase):
             copy.copy(paddle.static.amp.fp16_lists.white_list)
             | paddle.static.amp.fp16_lists._only_supported_fp16_list
         )
-        black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
+        black_list = copy.copy(
+            paddle.static.amp.fp16_lists.black_list
+            | paddle.static.amp.fp16_lists._extra_black_list
+        )
         gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)

         # 2. w={'tanh'}, b=None
@@ -365,7 +374,10 @@ class TestImageClassification(unittest.TestCase):
             copy.copy(paddle.static.amp.fp16_lists.white_list)
             | paddle.static.amp.fp16_lists._only_supported_fp16_list
         )
-        black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
+        black_list = copy.copy(
+            paddle.static.amp.fp16_lists.black_list
+            | paddle.static.amp.fp16_lists._extra_black_list
+        )
         gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)

         # 3. w={'lstm'}, b=None
@@ -381,7 +393,10 @@ class TestImageClassification(unittest.TestCase):
             copy.copy(paddle.static.amp.fp16_lists.white_list)
             | paddle.static.amp.fp16_lists._only_supported_fp16_list
         )
-        black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
+        black_list = copy.copy(
+            paddle.static.amp.fp16_lists.black_list
+            | paddle.static.amp.fp16_lists._extra_black_list
+        )
         gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)

         # 4. w=None, b={'conv2d'}
@@ -400,7 +415,10 @@ class TestImageClassification(unittest.TestCase):
             copy.copy(paddle.static.amp.fp16_lists.white_list)
             | paddle.static.amp.fp16_lists._only_supported_fp16_list
         )
-        black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
+        black_list = copy.copy(
+            paddle.static.amp.fp16_lists.black_list
+            | paddle.static.amp.fp16_lists._extra_black_list
+        )
         gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)

         # 5. w=None, b={'tanh'}
@@ -419,7 +437,10 @@ class TestImageClassification(unittest.TestCase):
             copy.copy(paddle.static.amp.fp16_lists.white_list)
             | paddle.static.amp.fp16_lists._only_supported_fp16_list
         )
-        black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
+        black_list = copy.copy(
+            paddle.static.amp.fp16_lists.black_list
+            | paddle.static.amp.fp16_lists._extra_black_list
+        )
         gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)

         # 6. w=None, b={'lstm'}