PaddlePaddle / Paddle
Commit c2c3bd43 (unverified)
Authored by niuliling123 on May 16, 2023
Committed by GitHub on May 16, 2023
[AMP] support OD level for static (#53768)
Parent: 52889e38

Showing 4 changed files with 113 additions and 27 deletions (+113 -27)
python/paddle/static/amp/decorator.py   +25 -19
python/paddle/static/amp/fp16_lists.py  +10 -6
python/paddle/static/amp/fp16_utils.py  +9  -1
test/amp/test_amp_api.py                +69 -1
python/paddle/static/amp/decorator.py
@@ -61,12 +61,11 @@ class OptimizerWithMixedPrecision:
     Args:
         optimizer (Optimizer): A common Optimizer object.
         amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object.
-        level(str): Auto mixed precision level. Accepted values are
-            "O1" and "O2": O1 represent mixed precision, the input data type
-            of each operator will be casted by white_list and black_list;
-            O2 represent Pure fp16 or bf16, all operators parameters and input
-            data will be casted to fp16 or bf16, except operators in black_list,
-            don't support fp16 or bf16 kernel and batch_norm.
+        level(str): Auto mixed precision level. Accepted values are "O1", "O2" and "OD": At the O1 level, operators in the white list
+            will use float16/bfloat16 inputs for calculations, and operators in the black list will use float32 inputs for calculations. At the O2
+            level, model's parameters will be casted to float16/bfloat16 by using `decorator`, and operators that have all float16/bfloat16 inputs
+            will be converted to float16/bfloat16, and that have any float32 input will be converted to float32. For the OD level, operators in
+            default white list will compute in float16/bfloat16.
         dtype(str): Whether to use 'float16' or 'bfloat16'.
         init_loss_scaling (float): The initial loss scaling factor.
         use_dynamic_loss_scaling (bool): Whether to use dynamic loss scaling.
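Note: the white list and black list referred to in this docstring are the default AutoMixedPrecisionLists sets from python/paddle/static/amp/fp16_lists.py (also changed in this commit). A small, purely illustrative way to inspect them on a GPU build of Paddle that contains this commit; contents and sizes vary by device and version:

# Inspect the default AMP op lists referenced by the docstring above.
# Illustrative only; list contents depend on the Paddle build and device.
from paddle.static.amp.fp16_lists import AutoMixedPrecisionLists

lists = AutoMixedPrecisionLists(dtype="float16")
print("white ops:", sorted(lists.white_list))           # e.g. conv2d, matmul_v2, ...
print("black ops (sample):", sorted(lists.black_list)[:5])
print("unsupported ops:", len(lists.unsupported_list))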
@@ -123,6 +122,7 @@ class OptimizerWithMixedPrecision:
         self._learning_rate = optimizer._learning_rate
         self._learning_rate_map = optimizer._learning_rate_map
         self._use_pure_fp16 = level == "O2"
+        self._amp_level = level
         self._use_fp16_guard = use_amp_guard
         self._to_fp16_var_names = None
         if self._use_dynamic_loss_scaling:
@@ -241,7 +241,7 @@ class OptimizerWithMixedPrecision:
                 self._amp_lists,
                 use_fp16_guard=False,
                 dest_type=self._amp_vartype,
-                level='O1',
+                level=self._amp_level,
                 use_promote=self.use_promote,
             )
@@ -380,7 +380,7 @@ class OptimizerWithMixedPrecision:
                 self._amp_lists,
                 use_fp16_guard=False,
                 dest_type=self._amp_vartype,
-                level='O1',
+                level=self._amp_level,
                 use_promote=self.use_promote,
             )
@@ -773,12 +773,11 @@ def decorate(
         amp_lists(CustomOpLists, optional): An CustomOpLists object. The default
             white_list and black_list will be used for AMP training when it is
             not set. Default is None.
-        level(str, optional): Auto mixed precision level. Accepted values are
-            "O1" and "O2": O1 represent mixed precision, the input data type of
-            each operator will be casted by white_list and black_list;
-            O2 represent pure FP16 / BF16 training, all operators parameters
-            and input data will be casted to FP16 / BF16, except operators in
-            black_list, don't support FP16 / BF16 kernel and batch_norm. Default is O1.
+        level(str, optional): Auto mixed precision level. Accepted values are "O1", "O2" and "OD": At the O1 level, operators in the white list
+            will use float16/bfloat16 inputs for calculations, and operators in the black list will use float32 inputs for calculations. At the O2
+            level, model's parameters will be casted to float16/bfloat16 by using `decorator`, and operators that have all float16/bfloat16 inputs
+            will be converted to float16/bfloat16, and that have any float32 input will be converted to float32. For the OD level, operators in
+            default white list will compute in float16/bfloat16, and the others will compute in float32. Default is O1.
         dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
         master_weight(bool, optinal): For level='O2', whether to use multi-precision
             during weight updating. If master_weight is None, in O2 level optimizer
@@ -847,15 +846,22 @@ def decorate(
     """
     # check amp_level: O0-O2
     level = level.upper()
-    if not (level in ['O0', 'O1', 'O2']):
-        raise ValueError(
-            "level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode."
-        )
+    if not (level in ['O0', 'OD', 'O1', 'O2']):
+        raise ValueError("level should be O0, OD, O1 or O2.")
     amp_dtype = check_amp_dtype(dtype)
-    if amp_lists is None:
+    if amp_lists is None or level == 'OD':
         amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype)
+
+    if level == 'OD':
+        if amp_lists is not None:
+            warnings.warn(
+                "If the Amp level is set to OD, the amp list will not be used."
+            )
+        amp_lists.white_list = {"conv2d", "matmul_v2"}
+        amp_lists.black_list = amp_lists.all_list - amp_lists.white_list
+
     if use_dynamic_loss_scaling is None:
         use_dynamic_loss_scaling = dtype == "float16"
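For context, a minimal usage sketch of the new OD level through paddle.static.amp.decorate, mirroring the pattern used in test/amp/test_amp_api.py below. The toy network, names and hyperparameters are illustrative only (not part of the commit) and assume a GPU build of Paddle that contains this change:

import paddle

paddle.enable_static()
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    # Hypothetical toy network: one white-list op (conv2d) plus a reduction.
    x = paddle.static.data(name='x', shape=[None, 1, 28, 28], dtype='float32')
    y = paddle.static.nn.conv2d(input=x, num_filters=6, filter_size=3)
    loss = paddle.mean(y)

    opt = paddle.fluid.optimizer.Adadelta(learning_rate=0.001)
    # New in this commit: level='OD' keeps only the default white-list ops
    # (conv2d, matmul_v2) in float16; every other op stays in float32.
    opt = paddle.static.amp.decorate(
        opt,
        init_loss_scaling=128.0,
        use_dynamic_loss_scaling=True,
        level='OD',
    )
    opt.minimize(loss)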
python/paddle/static/amp/fp16_lists.py
@@ -99,7 +99,7 @@ def _get_sys_unsupported_list(dtype):
         device = 'XPU'
     else:
         device = 'GPU'
-    _, _, sys_unsupported_list = core.op_supported_infos(device, var_type)
+    all_ops, _, sys_unsupported_list = core.op_supported_infos(device, var_type)

     # sys_unsupported_list will include the following ops.
     supported_fp16_list = {
@@ -114,13 +114,13 @@ def _get_sys_unsupported_list(dtype):
     }
     sys_unsupported_list -= supported_fp16_list

-    return device, sys_unsupported_list
+    return device, sys_unsupported_list, all_ops


 def _get_unsupported_list(dtype):
     # The set of ops that don't support fp16 calculation
-    _, _sys_unsupported_list = _get_sys_unsupported_list(dtype)
-    return _sys_unsupported_list
+    _, _sys_unsupported_list, _sys_all_list = _get_sys_unsupported_list(dtype)
+    return _sys_unsupported_list, _sys_all_list


 # The three sets listed below are changed dynamiclly. They don't contain all
@@ -200,7 +200,9 @@ class AutoMixedPrecisionLists:
         self.white_list = copy.copy(_get_white_list(self.amp_dtype))
         self.black_list = copy.copy(_get_black_list())
         self.gray_list = copy.copy(gray_list)
-        self.unsupported_list = copy.copy(_get_unsupported_list(self.amp_dtype))
+        unsupported_list, sys_all_list = _get_unsupported_list(self.amp_dtype)
+        self.unsupported_list = copy.copy(unsupported_list)
+        self.all_list = copy.copy(sys_all_list)
         self.black_varnames = copy.copy(custom_black_varnames)
         self._update_list()
@@ -232,7 +234,9 @@ class AutoMixedPrecisionLists:
                 self.gray_list.remove(op_name)
             self.black_list.add(op_name)
             self.unsupported_list.add(op_name)
-        device, sys_unsupported_list = _get_sys_unsupported_list(self.amp_dtype)
+        device, sys_unsupported_list, _ = _get_sys_unsupported_list(
+            self.amp_dtype
+        )
         actual_unsupported_list = []
         for op_name in sys_unsupported_list:
             if op_name in self.white_list:
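A short sketch of what the new all_list attribute enables: under OD, decorate() shrinks the white list to {"conv2d", "matmul_v2"} and turns every other registered op into a black-list op by plain set subtraction. Illustrative only; list sizes depend on the device and the Paddle build (this requires a build containing this commit):

from paddle.static.amp.fp16_lists import AutoMixedPrecisionLists

lists = AutoMixedPrecisionLists(dtype="float16")
od_white_list = {"conv2d", "matmul_v2"}
od_black_list = lists.all_list - od_white_list  # what decorate() does for level='OD'

print("default white list size:", len(lists.white_list))
print("OD black list size:", len(od_black_list))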
python/paddle/static/amp/fp16_utils.py
@@ -426,7 +426,7 @@ def set_var_dst_dtype(
 def set_param_dtype(program, dtype, amp_lists, use_fp16_guard, level):
     keep_fp32_var_names = set()
-    if level == "O1":
+    if level == "O1" or level == "OD":
         return keep_fp32_var_names
     all_parameters = []
     for block in program.blocks:
@@ -618,6 +618,14 @@ def cast_model_to_fp16(
     if level == 'O2':
         amp_lists.black_list = amp_lists.black_list - black_list

+    if level == 'OD':
+        if amp_lists is not None:
+            dtype = get_low_precision_dtypestr(dest_type)
+            amp_lists = AutoMixedPrecisionLists(dtype)
+
+        amp_lists.white_list = {"conv2d", "matmul_v2"}
+        amp_lists.black_list = amp_lists.all_list - amp_lists.white_list
+
     global_block = program.global_block()
     keep_fp32_ops = set()
     keep_fp16_ops = set()
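The set_param_dtype change above is what keeps parameters in float32 under OD, exactly as under O1. A simplified, self-contained sketch of that early-return logic (not the real implementation, which for the O2 case also walks program.blocks and casts the parameters):

def set_param_dtype_sketch(level):
    # Mirrors the early return added above: under O1 and the new OD level the
    # parameters are not cast, so there are no fp32 parameter names to track.
    keep_fp32_var_names = set()
    if level in ("O1", "OD"):
        return keep_fp32_var_names
    # Under O2 the real function casts parameters to float16/bfloat16 here and
    # records the ones that must stay in float32.
    return keep_fp32_var_names

assert set_param_dtype_sketch("OD") == set()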
test/amp/test_amp_api.py
@@ -14,9 +14,11 @@
 import unittest

-from amp_base_models import AmpTestBase
+import numpy as np
+from amp_base_models import AmpTestBase, build_conv_model

 import paddle
+from paddle.static import amp


 class TestAutoCast(AmpTestBase):
@@ -37,6 +39,72 @@ class TestAutoCast(AmpTestBase):
         self.assertEqual(out3.dtype, paddle.float32)


+class TestStaticDecorate(AmpTestBase):
+    def check_results(
+        self, use_amp, dtype, level, use_promote, expected_op_calls
+    ):
+        (
+            main_program,
+            startup_program,
+            optimizer,
+            feed_vars,
+            fetch_vars,
+        ) = build_conv_model(use_amp, dtype, level, use_promote)
+        self.assertEqual(main_program.num_blocks, 1)
+
+        optimizer = paddle.fluid.optimizer.Adadelta(learning_rate=0.001)
+        optimizer = paddle.static.amp.decorate(
+            optimizer,
+            init_loss_scaling=128.0,
+            use_dynamic_loss_scaling=True,
+            level=level,
+        )
+
+        amp.debugging.collect_operator_stats(main_program)
+        op_stats_list = amp.debugging._get_op_stats_list(main_program)
+
+        self._check_op_calls(
+            op_stats_list[0], expected_fp16_calls=expected_op_calls
+        )
+
+        place = paddle.CUDAPlace(0)
+        exe = paddle.static.Executor(place)
+
+        max_iters = 2
+        x_fp32 = np.random.random(size=[1, 1, 6, 6]).astype("float32")
+        losses_o1 = self.run_program(
+            main_program,
+            startup_program,
+            optimizer,
+            feed_vars,
+            fetch_vars,
+            place,
+            exe,
+            x_fp32,
+            max_iters,
+            level,
+        )
+
+    def test_static_amp_o1(self):
+        paddle.enable_static()
+        expected_fp16_calls = {
+            "conv2d": 1,
+            "elementwise_add": 0,
+            "relu": 0,
+            "matmul_v2": 1,
+            "softmax": 0,
+            "reduce_mean": 0,
+            "adamw": 0,
+        }
+        self.check_results(
+            True,
+            'float16',
+            'OD',
+            use_promote=True,
+            expected_op_calls=expected_fp16_calls,
+        )
+        paddle.disable_static()
+
+
 class TestGradScaler(AmpTestBase):
     def test_amp_grad_scaler(self):
         model = paddle.nn.Conv2D(3, 2, 3)
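To exercise only the new test class, a standard unittest invocation like the following should work; it assumes a CUDA device (check_results uses paddle.CUDAPlace(0)) and that amp_base_models is importable, e.g. when launched from test/amp/:

import unittest

# Load and run just the TestStaticDecorate case added in this commit.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "test_amp_api.TestStaticDecorate"
)
unittest.TextTestRunner(verbosity=2).run(suite)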