Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
befeaeb5
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
befeaeb5
编写于
8月 04, 2022
作者:
S
shangliang Xu
提交者:
GitHub
8月 04, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[dev] add white and black list for amp train (#6576)
上级
3e4d5697
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
34 addition
and
9 deletion
+34
-9
ppdet/engine/trainer.py
ppdet/engine/trainer.py
+17
-6
ppdet/optimizer/ema.py
ppdet/optimizer/ema.py
+4
-1
ppdet/utils/checkpoint.py
ppdet/utils/checkpoint.py
+13
-2
未找到文件。
ppdet/engine/trainer.py
浏览文件 @
befeaeb5
...
@@ -69,6 +69,8 @@ class Trainer(object):
...
@@ -69,6 +69,8 @@ class Trainer(object):
self
.
is_loaded_weights
=
False
self
.
is_loaded_weights
=
False
self
.
use_amp
=
self
.
cfg
.
get
(
'amp'
,
False
)
self
.
use_amp
=
self
.
cfg
.
get
(
'amp'
,
False
)
self
.
amp_level
=
self
.
cfg
.
get
(
'amp_level'
,
'O1'
)
self
.
amp_level
=
self
.
cfg
.
get
(
'amp_level'
,
'O1'
)
self
.
custom_white_list
=
self
.
cfg
.
get
(
'custom_white_list'
,
None
)
self
.
custom_black_list
=
self
.
cfg
.
get
(
'custom_black_list'
,
None
)
# build data loader
# build data loader
capital_mode
=
self
.
mode
.
capitalize
()
capital_mode
=
self
.
mode
.
capitalize
()
...
@@ -155,8 +157,10 @@ class Trainer(object):
...
@@ -155,8 +157,10 @@ class Trainer(object):
self
.
pruner
=
create
(
'UnstructuredPruner'
)(
self
.
model
,
self
.
pruner
=
create
(
'UnstructuredPruner'
)(
self
.
model
,
steps_per_epoch
)
steps_per_epoch
)
if
self
.
use_amp
and
self
.
amp_level
==
'O2'
:
if
self
.
use_amp
and
self
.
amp_level
==
'O2'
:
self
.
model
=
paddle
.
amp
.
decorate
(
self
.
model
,
self
.
optimizer
=
paddle
.
amp
.
decorate
(
models
=
self
.
model
,
level
=
self
.
amp_level
)
models
=
self
.
model
,
optimizers
=
self
.
optimizer
,
level
=
self
.
amp_level
)
self
.
use_ema
=
(
'use_ema'
in
cfg
and
cfg
[
'use_ema'
])
self
.
use_ema
=
(
'use_ema'
in
cfg
and
cfg
[
'use_ema'
])
if
self
.
use_ema
:
if
self
.
use_ema
:
ema_decay
=
self
.
cfg
.
get
(
'ema_decay'
,
0.9998
)
ema_decay
=
self
.
cfg
.
get
(
'ema_decay'
,
0.9998
)
...
@@ -456,7 +460,9 @@ class Trainer(object):
...
@@ -456,7 +460,9 @@ class Trainer(object):
DataParallel
)
and
use_fused_allreduce_gradients
:
DataParallel
)
and
use_fused_allreduce_gradients
:
with
model
.
no_sync
():
with
model
.
no_sync
():
with
paddle
.
amp
.
auto_cast
(
with
paddle
.
amp
.
auto_cast
(
enable
=
self
.
cfg
.
use_gpus
,
enable
=
self
.
cfg
.
use_gpu
,
custom_white_list
=
self
.
custom_white_list
,
custom_black_list
=
self
.
custom_black_list
,
level
=
self
.
amp_level
):
level
=
self
.
amp_level
):
# model forward
# model forward
outputs
=
model
(
data
)
outputs
=
model
(
data
)
...
@@ -468,7 +474,10 @@ class Trainer(object):
...
@@ -468,7 +474,10 @@ class Trainer(object):
list
(
model
.
parameters
()),
None
)
list
(
model
.
parameters
()),
None
)
else
:
else
:
with
paddle
.
amp
.
auto_cast
(
with
paddle
.
amp
.
auto_cast
(
enable
=
self
.
cfg
.
use_gpu
,
level
=
self
.
amp_level
):
enable
=
self
.
cfg
.
use_gpu
,
custom_white_list
=
self
.
custom_white_list
,
custom_black_list
=
self
.
custom_black_list
,
level
=
self
.
amp_level
):
# model forward
# model forward
outputs
=
model
(
data
)
outputs
=
model
(
data
)
loss
=
outputs
[
'loss'
]
loss
=
outputs
[
'loss'
]
...
@@ -477,7 +486,6 @@ class Trainer(object):
...
@@ -477,7 +486,6 @@ class Trainer(object):
scaled_loss
.
backward
()
scaled_loss
.
backward
()
# in dygraph mode, optimizer.minimize is equal to optimizer.step
# in dygraph mode, optimizer.minimize is equal to optimizer.step
scaler
.
minimize
(
self
.
optimizer
,
scaled_loss
)
scaler
.
minimize
(
self
.
optimizer
,
scaled_loss
)
else
:
else
:
if
isinstance
(
if
isinstance
(
model
,
paddle
.
model
,
paddle
.
...
@@ -575,7 +583,10 @@ class Trainer(object):
...
@@ -575,7 +583,10 @@ class Trainer(object):
# forward
# forward
if
self
.
use_amp
:
if
self
.
use_amp
:
with
paddle
.
amp
.
auto_cast
(
with
paddle
.
amp
.
auto_cast
(
enable
=
self
.
cfg
.
use_gpu
,
level
=
self
.
amp_level
):
enable
=
self
.
cfg
.
use_gpu
,
custom_white_list
=
self
.
custom_white_list
,
custom_black_list
=
self
.
custom_black_list
,
level
=
self
.
amp_level
):
outs
=
self
.
model
(
data
)
outs
=
self
.
model
(
data
)
else
:
else
:
outs
=
self
.
model
(
data
)
outs
=
self
.
model
(
data
)
...
...
ppdet/optimizer/ema.py
浏览文件 @
befeaeb5
...
@@ -66,7 +66,10 @@ class ModelEMA(object):
...
@@ -66,7 +66,10 @@ class ModelEMA(object):
def
resume
(
self
,
state_dict
,
step
=
0
):
def
resume
(
self
,
state_dict
,
step
=
0
):
for
k
,
v
in
state_dict
.
items
():
for
k
,
v
in
state_dict
.
items
():
if
k
in
self
.
state_dict
:
if
k
in
self
.
state_dict
:
if
self
.
state_dict
[
k
].
dtype
==
v
.
dtype
:
self
.
state_dict
[
k
]
=
v
self
.
state_dict
[
k
]
=
v
else
:
self
.
state_dict
[
k
]
=
v
.
astype
(
self
.
state_dict
[
k
].
dtype
)
self
.
step
=
step
self
.
step
=
step
def
update
(
self
,
model
=
None
):
def
update
(
self
,
model
=
None
):
...
...
ppdet/utils/checkpoint.py
浏览文件 @
befeaeb5
...
@@ -84,9 +84,14 @@ def load_weight(model, weight, optimizer=None, ema=None):
...
@@ -84,9 +84,14 @@ def load_weight(model, weight, optimizer=None, ema=None):
model_weight
=
{}
model_weight
=
{}
incorrect_keys
=
0
incorrect_keys
=
0
for
key
in
model_dict
.
key
s
():
for
key
,
value
in
model_dict
.
item
s
():
if
key
in
param_state_dict
.
keys
():
if
key
in
param_state_dict
.
keys
():
if
isinstance
(
param_state_dict
[
key
],
np
.
ndarray
):
param_state_dict
[
key
]
=
paddle
.
to_tensor
(
param_state_dict
[
key
])
if
value
.
dtype
==
param_state_dict
[
key
].
dtype
:
model_weight
[
key
]
=
param_state_dict
[
key
]
model_weight
[
key
]
=
param_state_dict
[
key
]
else
:
model_weight
[
key
]
=
param_state_dict
[
key
].
astype
(
value
.
dtype
)
else
:
else
:
logger
.
info
(
'Unmatched key: {}'
.
format
(
key
))
logger
.
info
(
'Unmatched key: {}'
.
format
(
key
))
incorrect_keys
+=
1
incorrect_keys
+=
1
...
@@ -209,6 +214,12 @@ def load_pretrain_weight(model, pretrain_weight):
...
@@ -209,6 +214,12 @@ def load_pretrain_weight(model, pretrain_weight):
param_state_dict
=
paddle
.
load
(
weights_path
)
param_state_dict
=
paddle
.
load
(
weights_path
)
param_state_dict
=
match_state_dict
(
model_dict
,
param_state_dict
)
param_state_dict
=
match_state_dict
(
model_dict
,
param_state_dict
)
for
k
,
v
in
param_state_dict
.
items
():
if
isinstance
(
v
,
np
.
ndarray
):
v
=
paddle
.
to_tensor
(
v
)
if
model_dict
[
k
].
dtype
!=
v
.
dtype
:
param_state_dict
[
k
]
=
v
.
astype
(
model_dict
[
k
].
dtype
)
model
.
set_dict
(
param_state_dict
)
model
.
set_dict
(
param_state_dict
)
logger
.
info
(
'Finish loading model weights: {}'
.
format
(
weights_path
))
logger
.
info
(
'Finish loading model weights: {}'
.
format
(
weights_path
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录