Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
41e90283
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
41e90283
编写于
4月 24, 2023
作者:
Z
Zhang Ting
提交者:
GitHub
4月 24, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[AMP]expand blacklists for amp training (#50940)
上级
5e1ee106
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
175 addition
and
122 deletion
+175
-122
python/paddle/amp/__init__.py
python/paddle/amp/__init__.py
+2
-4
python/paddle/amp/amp_lists.py
python/paddle/amp/amp_lists.py
+110
-0
python/paddle/amp/auto_cast.py
python/paddle/amp/auto_cast.py
+12
-98
python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py
...e/fleet/test_imperative_auto_mixed_precision_for_eager.py
+4
-4
test/amp/test_amp_list.py
test/amp/test_amp_list.py
+47
-16
未找到文件。
python/paddle/amp/__init__.py
浏览文件 @
41e90283
...
...
@@ -16,10 +16,8 @@ from .auto_cast import auto_cast # noqa: F401
from
.auto_cast
import
decorate
# noqa: F401
from
.auto_cast
import
amp_guard
# noqa: F401
from
.auto_cast
import
amp_decorate
# noqa: F401
from
.auto_cast
import
FP16_WHITE_LIST
# noqa: F401
from
.auto_cast
import
FP16_BLACK_LIST
# noqa: F401
from
.auto_cast
import
PURE_FP16_WHITE_LIST
# noqa: F401
from
.auto_cast
import
PURE_FP16_BLACK_LIST
# noqa: F401
from
.amp_lists
import
white_list
# noqa: F401
from
.amp_lists
import
black_list
# noqa: F401
from
.
import
grad_scaler
# noqa: F401
from
.grad_scaler
import
GradScaler
# noqa: F401
...
...
python/paddle/amp/amp_lists.py
0 → 100644
浏览文件 @
41e90283
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The set of ops that support fp16 calculation and are considered numerically-
# safe and performance-critical. These ops are always converted to fp16.
FP16_WHITE_LIST
=
{
'conv2d'
,
'matmul'
,
'matmul_v2'
,
'max_pool2d_with_index'
,
'mul'
,
'fake_quantize_dequantize_abs_max'
,
'fake_quantize_dequantize_moving_average_abs_max'
,
}
# The set of ops that support fp16 calculation and are considered numerically-
# dangerous and whose effects may also be observed in downstream ops.
FP16_BLACK_LIST
=
{
'tan'
,
'acos'
,
'asin'
,
'sinh'
,
'cosh'
,
'atanh'
,
'tanh_shrink'
,
'cos_sim'
,
'erfinv'
,
'exp'
,
'expm1'
,
'log'
,
'log10'
,
'log2'
,
'reciprocal'
,
'rsqrt'
,
'pow'
,
'square'
,
'reduce_sum'
,
'mean'
,
'reduce_mean'
,
'reduce_prod'
,
'cumprod'
,
'cumsum'
,
'dist'
,
'pnorm'
,
'frobenius_norm'
,
'renorm'
,
'group_norm'
,
'layer_norm'
,
'softmax'
,
'softmin'
,
'softplus'
,
'log_softmax'
,
'softmax_with_cross_entropy'
,
'sigmoid_cross_entropy_with_logits'
,
'c_softmax_with_cross_entropy'
,
'cross_entropy'
,
'cross_entropy2'
,
'nll_loss'
,
'huber_loss'
,
'triplet_margin_loss'
,
'log_loss'
,
'hsigmoid_loss'
,
'margin_cross_entropy'
,
}
# FP16 performance of grad op is worse than that of FP32. Use FP32 by default.
FP16_EXTRA_BLACK_LIST
=
{
'linear_interp_v2'
,
'nearest_interp_v2'
,
'bilinear_interp_v2'
,
'bicubic_interp_v2'
,
'trilinear_interp_v2'
,
'lookup_table'
,
'lookup_table_v2'
,
'scatter'
,
'depthwise_conv2d'
,
}
BF16_WHITE_LIST
=
{
'conv2d'
,
'matmul_v2'
}
BF16_BLACK_LIST
=
set
()
def
white_list
():
white_list
=
{
"float16"
:
{
"O1"
:
FP16_WHITE_LIST
,
"O2"
:
FP16_WHITE_LIST
},
"bfloat16"
:
{
"O1"
:
BF16_WHITE_LIST
,
"O2"
:
BF16_WHITE_LIST
},
}
return
white_list
def
black_list
():
black_list
=
{
"float16"
:
{
"O1"
:
FP16_BLACK_LIST
|
FP16_EXTRA_BLACK_LIST
,
"O2"
:
FP16_EXTRA_BLACK_LIST
,
},
"bfloat16"
:
{
"O1"
:
BF16_BLACK_LIST
,
"O2"
:
set
()},
}
return
black_list
python/paddle/amp/auto_cast.py
浏览文件 @
41e90283
...
...
@@ -20,45 +20,7 @@ from paddle.fluid import core
from
paddle.fluid.framework
import
_dygraph_tracer
,
dygraph_only
from
paddle.fluid.wrapped_decorator
import
signature_safe_contextmanager
AMP_LEVEL
=
core
.
AmpLevel
# The set of ops that support fp16 calculation and are considered numerically-
# safe and performance-critical. These ops are always converted to fp16.
FP16_WHITE_LIST
=
{
'conv2d'
,
'matmul'
,
'matmul_v2'
,
'max_pool2d_with_index'
,
'mul'
,
'fake_quantize_dequantize_abs_max'
,
'fake_quantize_dequantize_moving_average_abs_max'
,
}
# The set of ops that support fp16 calculation and are considered numerically-
# dangerous and whose effects may also be observed in downstream ops.
FP16_BLACK_LIST
=
{
'exp'
,
'square'
,
'log'
,
'mean'
,
'sum'
,
'cos_sim'
,
'softmax'
,
'softmax_with_cross_entropy'
,
'sigmoid_cross_entropy_with_logits'
,
'c_softmax_with_cross_entropy'
,
'cross_entropy'
,
'cross_entropy2'
,
# default fp32 can avoid return inf when the sum value large than 65504
'reduce_sum'
,
# FP16 performance of grad op is worse than that of FP32. Use FP32 by default.
'linear_interp_v2'
,
'nearest_interp_v2'
,
'bilinear_interp_v2'
,
'bicubic_interp_v2'
,
'trilinear_interp_v2'
,
}
from
.amp_lists
import
black_list
,
white_list
AMP_RELATED_FLAGS
=
[
'FLAGS_cudnn_exhaustive_search'
,
...
...
@@ -72,27 +34,7 @@ AMP_RELATED_FLAGS_SETTING = {
'FLAGS_cudnn_batchnorm_spatial_persistent'
:
1
,
}
PURE_FP16_WHITE_LIST
=
copy
.
copy
(
FP16_WHITE_LIST
)
PURE_FP16_BLACK_LIST
=
{
'lookup_table'
,
'lookup_table_v2'
,
'scatter'
,
'scatter_grad'
,
# FP16 performance of grad op is worse than that of FP32. Use FP32 by default.
'linear_interp_v2'
,
'nearest_interp_v2'
,
'bilinear_interp_v2'
,
'bicubic_interp_v2'
,
'trilinear_interp_v2'
,
}
BF16_WHITE_LIST
=
{
'conv2d'
,
'matmul_v2'
}
BF16_BLACK_LIST
=
set
()
PURE_BF16_WHITE_LIST
=
copy
.
copy
(
BF16_WHITE_LIST
)
PURE_BF16_BLACK_LIST
=
set
()
AMP_LEVEL
=
core
.
AmpLevel
_g_amp_state_
=
None
...
...
@@ -126,20 +68,12 @@ def _update_list(
"""
Update black and white list according to users' custom list.
"""
if
dtype
==
'float16'
:
if
level
==
'O1'
:
_white_list
=
copy
.
copy
(
FP16_WHITE_LIST
)
_black_list
=
copy
.
copy
(
FP16_BLACK_LIST
)
else
:
_white_list
=
copy
.
copy
(
PURE_FP16_WHITE_LIST
)
_black_list
=
copy
.
copy
(
PURE_FP16_BLACK_LIST
)
else
:
if
level
==
'O1'
:
_white_list
=
copy
.
copy
(
BF16_WHITE_LIST
)
_black_list
=
copy
.
copy
(
BF16_BLACK_LIST
)
else
:
_white_list
=
copy
.
copy
(
PURE_BF16_WHITE_LIST
)
_black_list
=
copy
.
copy
(
PURE_BF16_BLACK_LIST
)
if
level
==
'O0'
:
_white_list
=
set
()
_black_list
=
set
()
return
_white_list
,
_black_list
_white_list
=
copy
.
copy
(
white_list
()[
dtype
][
level
])
_black_list
=
copy
.
copy
(
black_list
()[
dtype
][
level
])
if
custom_white_list
and
custom_black_list
:
for
op_name
in
custom_white_list
:
if
op_name
in
custom_black_list
:
...
...
@@ -453,34 +387,14 @@ def amp_guard(
if
level
==
'O1'
:
amp_level
=
AMP_LEVEL
.
O1
if
dtype
==
'float16'
:
_white_list
=
FP16_WHITE_LIST
_black_list
=
FP16_BLACK_LIST
elif
dtype
==
'bfloat16'
:
_white_list
=
BF16_WHITE_LIST
_black_list
=
BF16_BLACK_LIST
elif
level
==
'O2'
:
amp_level
=
AMP_LEVEL
.
O2
if
dtype
==
'float16'
:
_white_list
=
PURE_FP16_WHITE_LIST
_black_list
=
PURE_FP16_BLACK_LIST
elif
dtype
==
'bfloat16'
:
_white_list
=
BF16_WHITE_LIST
_black_list
=
BF16_BLACK_LIST
elif
level
==
'O0'
:
amp_level
=
AMP_LEVEL
.
O0
if
dtype
==
'float16'
:
_white_list
=
FP16_WHITE_LIST
_black_list
=
FP16_BLACK_LIST
elif
dtype
==
'bfloat16'
:
_white_list
=
BF16_WHITE_LIST
_black_list
=
BF16_BLACK_LIST
if
custom_white_list
or
custom_black_list
:
_white_list
,
_black_list
=
_update_list
(
custom_white_list
,
custom_black_list
,
level
,
dtype
)
_white_list
,
_black_list
=
_update_list
(
custom_white_list
,
custom_black_list
,
level
,
dtype
)
if
not
enable
:
amp_level
=
AMP_LEVEL
.
O0
...
...
python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py
浏览文件 @
41e90283
...
...
@@ -88,8 +88,8 @@ class TestAutoCast(unittest.TestCase):
def
custom_op_list
(
self
):
with
fluid
.
dygraph
.
guard
():
tracer
=
fluid
.
framework
.
_dygraph_tracer
()
base_white_list
=
paddle
.
amp
.
FP16_WHITE_LIST
base_black_list
=
paddle
.
amp
.
FP16_BLACK_LIST
base_white_list
=
paddle
.
amp
.
white_list
()[
"float16"
][
"O1"
]
base_black_list
=
paddle
.
amp
.
black_list
()[
"float16"
][
"O1"
]
with
paddle
.
amp
.
amp_guard
(
custom_white_list
=
[
"log"
],
custom_black_list
=
[
"conv2d"
]
):
...
...
@@ -104,8 +104,8 @@ class TestAutoCast(unittest.TestCase):
==
(
set
(
base_black_list
)
-
{
"log"
})
|
{
"conv2d"
}
)
base_white_list
=
paddle
.
amp
.
PURE_FP16_WHITE_LIST
base_black_list
=
paddle
.
amp
.
PURE_FP16_BLACK_LIST
base_white_list
=
paddle
.
amp
.
white_list
()[
"float16"
][
"O2"
]
base_black_list
=
paddle
.
amp
.
black_list
()[
"float16"
][
"O2"
]
with
paddle
.
amp
.
amp_guard
(
custom_white_list
=
[
"log"
],
custom_black_list
=
[
"conv2d"
],
...
...
test/amp/test_amp_list.py
浏览文件 @
41e90283
...
...
@@ -14,32 +14,63 @@
import
unittest
import
paddle
from
paddle.fluid
import
core
from
paddle.static.amp
import
fp16_lists
from
paddle.static.amp.fp16_lists
import
AutoMixedPrecisionLists
from
paddle.static.amp
import
AutoMixedPrecisionLists
,
fp16_lists
class
TestAMPList
(
unittest
.
TestCase
):
def
test_main
(
self
):
custom_white_list
=
[
'lookup_table'
,
'lookup_table_v2'
,
]
amp_list
=
AutoMixedPrecisionLists
(
custom_white_list
=
custom_white_list
)
for
op
in
custom_white_list
:
self
.
assertTrue
(
op
in
amp_list
.
white_list
)
self
.
assertTrue
(
op
not
in
amp_list
.
black_list
)
self
.
assertTrue
(
op
not
in
amp_list
.
unsupported_list
)
default_black_list
=
[
def
setUp
(
self
):
self
.
default_black_list
=
[
'linear_interp_v2'
,
'nearest_interp_v2'
,
'bilinear_interp_v2'
,
'bicubic_interp_v2'
,
'trilinear_interp_v2'
,
]
for
op
in
default_black_list
:
self
.
assertTrue
(
op
in
amp_list
.
black_list
)
self
.
custom_white_list
=
[
'lookup_table'
,
'lookup_table_v2'
,
]
def
check_if_op_in_list
(
self
,
op_list
,
amp_list
):
for
op
in
op_list
:
self
.
assertTrue
(
op
in
amp_list
)
def
check_if_op_not_in_list
(
self
,
op_list
,
amp_list
):
for
op
in
op_list
:
self
.
assertTrue
(
op
not
in
amp_list
)
def
test_static
(
self
):
amp_list
=
AutoMixedPrecisionLists
(
custom_white_list
=
self
.
custom_white_list
)
self
.
check_if_op_in_list
(
self
.
default_black_list
,
amp_list
.
black_list
)
self
.
check_if_op_in_list
(
self
.
custom_white_list
,
amp_list
.
white_list
)
self
.
check_if_op_not_in_list
(
self
.
custom_white_list
,
amp_list
.
black_list
)
self
.
check_if_op_not_in_list
(
self
.
custom_white_list
,
amp_list
.
unsupported_list
)
def
test_eager
(
self
):
if
not
paddle
.
amp
.
is_float16_supported
():
return
white_list
=
paddle
.
amp
.
white_list
()
black_list
=
paddle
.
amp
.
black_list
()
self
.
check_if_op_in_list
(
self
.
default_black_list
,
black_list
[
"float16"
][
"O2"
]
)
self
.
check_if_op_not_in_list
([
'log'
,
'elementwise_add'
],
white_list
)
with
paddle
.
amp
.
auto_cast
(
custom_white_list
=
{
'elementwise_add'
}):
out1
=
paddle
.
rand
([
2
,
3
])
+
paddle
.
rand
([
2
,
3
])
out2
=
out1
.
mean
()
out3
=
paddle
.
log
(
out2
)
self
.
check_if_op_not_in_list
([
'log'
,
'elementwise_add'
],
white_list
)
self
.
assertEqual
(
out1
.
dtype
,
paddle
.
float16
)
self
.
assertEqual
(
out2
.
dtype
,
paddle
.
float32
)
self
.
assertEqual
(
out3
.
dtype
,
paddle
.
float32
)
def
test_apis
(
self
):
def
_run_check_dtype
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录