Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
e29c2d12
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e29c2d12
编写于
8月 16, 2021
作者:
L
Leo Chen
提交者:
GitHub
8月 16, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[amp] dygraph amp support param_group (#34899)
* dygraph amp support param_group * remove unused code * fix doc
上级
b0cb4148
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
100 addition
and
19 deletion
+100
-19
python/paddle/amp/grad_scaler.py
python/paddle/amp/grad_scaler.py
+43
-0
python/paddle/fluid/dygraph/amp/loss_scaler.py
python/paddle/fluid/dygraph/amp/loss_scaler.py
+13
-4
python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
...d/tests/unittests/test_imperative_auto_mixed_precision.py
+44
-15
未找到文件。
python/paddle/amp/grad_scaler.py
浏览文件 @
e29c2d12
...
...
@@ -146,6 +146,49 @@ class GradScaler(AmpScaler):
"""
return
super
(
GradScaler
,
self
).
minimize
(
optimizer
,
*
args
,
**
kwargs
)
def
step
(
self
,
optimizer
):
"""
This function is similar as `optimizer.step()`, which performs parameters updating.
If the scaled gradients of parameters contains NAN or INF, the parameters updating is skipped.
Otherwise, it first unscales the scaled gradients of parameters, then updates the parameters.
Args:
optimizer(Optimizer): The optimizer used to update parameters.
Examples:
.. code-block:: python
# required: gpu
import paddle
model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
data = paddle.rand([10, 3, 32, 32])
with paddle.amp.auto_cast():
conv = model(data)
loss = paddle.mean(conv)
scaled = scaler.scale(loss) # scale the loss
scaled.backward() # do backward
scaler.step(optimizer)
optimizer.clear_grad()
"""
if
not
self
.
_enable
:
return
optimizer
.
step
()
# unscale the grad
self
.
_unscale
(
optimizer
)
if
self
.
_found_inf
:
self
.
_cache_founf_inf
=
True
else
:
optimizer
.
step
()
self
.
_cache_founf_inf
=
False
if
self
.
_use_dynamic_loss_scaling
:
# uopdate the scale
self
.
_update
()
def
is_enable
(
self
):
"""
Enable loss scaling or not.
...
...
python/paddle/fluid/dygraph/amp/loss_scaler.py
浏览文件 @
e29c2d12
...
...
@@ -212,6 +212,15 @@ class AmpScaler(object):
def
_unscale
(
self
,
optimizer
):
if
not
self
.
_enable
:
return
if
getattr
(
optimizer
,
'_param_groups'
,
None
)
and
isinstance
(
optimizer
.
_param_groups
[
0
],
dict
):
param_grads
=
[]
for
group
in
optimizer
.
_param_groups
:
for
param
in
group
[
'params'
]:
if
param
.
_grad_ivar
()
is
not
None
:
param_grads
.
append
(
param
.
_grad_ivar
())
else
:
param_grads
=
[
param
.
_grad_ivar
()
for
param
in
optimizer
.
_parameter_list
if
param
.
_grad_ivar
()
is
not
None
...
...
python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
浏览文件 @
e29c2d12
...
...
@@ -19,6 +19,9 @@ import numpy as np
import
six
from
test_imperative_resnet
import
ResNet
,
BottleneckBlock
,
ConvBNLayer
,
train_parameters
,
optimizer_setting
if
fluid
.
core
.
is_compiled_with_cuda
():
fluid
.
set_flags
({
"FLAGS_cudnn_deterministic"
:
True
})
class
SimpleConv
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
...
...
@@ -373,8 +376,6 @@ class TestGradScalerStateDict(unittest.TestCase):
return
dy_out
,
dy_param_value
,
dy_grad_value
def
test_with_state_dict
(
self
):
if
fluid
.
core
.
is_compiled_with_cuda
():
fluid
.
set_flags
({
"FLAGS_cudnn_deterministic"
:
True
})
with
fluid
.
dygraph
.
guard
():
out_use_state_dict
=
self
.
train_resnet
(
enable_amp
=
True
,
use_data_loader
=
True
,
use_save_load
=
True
)
...
...
@@ -390,18 +391,43 @@ class TestResnet2(unittest.TestCase):
Use paddle-2.0 API
"""
def
train_resnet
(
self
,
enable_amp
=
True
,
use_data_loader
=
False
):
def
train_resnet
(
self
,
enable_amp
=
True
,
use_data_loader
=
False
,
use_param_group
=
False
):
seed
=
90
batch_size
=
train_parameters
[
"batch_size"
]
batch_num
=
1
batch_num
=
1
0
paddle
.
seed
(
seed
)
paddle
.
framework
.
random
.
_manual_program_seed
(
seed
)
resnet
=
ResNet
(
use_cudnn
=
True
)
optimizer
=
optimizer_setting
(
train_parameters
,
parameter_list
=
resnet
.
parameters
())
if
use_param_group
:
conv_params
=
resnet
.
conv
.
parameters
()
other_params
=
[]
for
p
in
resnet
.
parameters
():
contains
=
False
for
q
in
conv_params
:
if
p
is
q
:
contains
=
True
if
not
contains
:
other_params
.
append
(
p
)
# NOTE(zhiqiu): The Membership test operations(in / not in) calls "is" and "equal",
# see details: https://docs.python.org/3/reference/expressions.html#membership-test-operations.
# So do not use other_params = [p for p in resnet.parameters() if p not in conv_params]
optimizer
=
paddle
.
optimizer
.
Momentum
(
parameters
=
[{
'params'
:
conv_params
,
'learning_rate'
:
0.01
},
{
'params'
:
other_params
,
'learning_rate'
:
0.001
}])
else
:
optimizer
=
paddle
.
optimizer
.
SGD
(
parameters
=
resnet
.
parameters
())
np
.
random
.
seed
(
seed
)
train_reader
=
paddle
.
batch
(
paddle
.
dataset
.
flowers
.
train
(
use_xmap
=
False
),
batch_size
=
batch_size
)
...
...
@@ -456,7 +482,7 @@ class TestResnet2(unittest.TestCase):
scaled_loss
=
scaler
.
scale
(
avg_loss
)
scaled_loss
.
backward
()
scaler
.
minimize
(
optimizer
,
scaled_loss
)
scaler
.
step
(
optimizer
)
dy_grad_value
=
{}
for
param
in
resnet
.
parameters
():
...
...
@@ -475,22 +501,27 @@ class TestResnet2(unittest.TestCase):
return
dy_out
,
dy_param_value
,
dy_grad_value
def
test_resnet
(
self
):
if
fluid
.
core
.
is_compiled_with_cuda
():
fluid
.
set_flags
({
"FLAGS_cudnn_deterministic"
:
True
})
with
fluid
.
dygraph
.
guard
():
out_fp32
=
self
.
train_resnet
(
enable_amp
=
False
)
out_amp
=
self
.
train_resnet
(
enable_amp
=
True
)
print
(
out_fp32
[
0
],
out_amp
[
0
])
self
.
assertTrue
(
np
.
allclose
(
out_fp32
[
0
],
out_amp
[
0
],
atol
=
1.e-
2
))
self
.
assertTrue
(
np
.
allclose
(
out_fp32
[
0
],
out_amp
[
0
],
atol
=
1.e-
5
))
def
test_with_data_loader
(
self
):
if
fluid
.
core
.
is_compiled_with_cuda
():
fluid
.
set_flags
({
"FLAGS_cudnn_deterministic"
:
True
})
with
fluid
.
dygraph
.
guard
():
out_fp32
=
self
.
train_resnet
(
enable_amp
=
False
,
use_data_loader
=
True
)
out_amp
=
self
.
train_resnet
(
enable_amp
=
True
,
use_data_loader
=
True
)
print
(
out_fp32
[
0
],
out_amp
[
0
])
self
.
assertTrue
(
np
.
allclose
(
out_fp32
[
0
],
out_amp
[
0
],
atol
=
1.e-2
))
self
.
assertTrue
(
np
.
allclose
(
out_fp32
[
0
],
out_amp
[
0
],
atol
=
1.e-5
))
def
test_param_group
(
self
):
with
fluid
.
dygraph
.
guard
():
out_fp32
=
self
.
train_resnet
(
enable_amp
=
False
,
use_data_loader
=
True
,
use_param_group
=
True
)
out_amp
=
self
.
train_resnet
(
enable_amp
=
True
,
use_data_loader
=
True
,
use_param_group
=
True
)
print
(
out_fp32
[
0
],
out_amp
[
0
])
self
.
assertTrue
(
np
.
allclose
(
out_fp32
[
0
],
out_amp
[
0
],
atol
=
1.e-5
))
class
TestResnet
(
unittest
.
TestCase
):
...
...
@@ -566,8 +597,6 @@ class TestResnet(unittest.TestCase):
return
dy_out
,
dy_param_value
,
dy_grad_value
def
test_resnet
(
self
):
if
fluid
.
core
.
is_compiled_with_cuda
():
fluid
.
set_flags
({
"FLAGS_cudnn_deterministic"
:
True
})
out_fp32
=
self
.
train_resnet
(
enable_amp
=
False
)
out_amp
=
self
.
train_resnet
(
enable_amp
=
True
)
print
(
out_fp32
[
0
],
out_amp
[
0
])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录