Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
befeaeb5
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
befeaeb5
编写于
8月 04, 2022
作者:
S
shangliang Xu
提交者:
GitHub
8月 04, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[dev] add white and black list for amp train (#6576)
上级
3e4d5697
变更
3
隐藏空白更改
内联
并排
Showing
3 changed files
with
34 additions
and
9 deletions
+34
-9
ppdet/engine/trainer.py
ppdet/engine/trainer.py
+17
-6
ppdet/optimizer/ema.py
ppdet/optimizer/ema.py
+4
-1
ppdet/utils/checkpoint.py
ppdet/utils/checkpoint.py
+13
-2
未找到文件。
ppdet/engine/trainer.py
浏览文件 @
befeaeb5
...
...
@@ -69,6 +69,8 @@ class Trainer(object):
self
.
is_loaded_weights
=
False
self
.
use_amp
=
self
.
cfg
.
get
(
'amp'
,
False
)
self
.
amp_level
=
self
.
cfg
.
get
(
'amp_level'
,
'O1'
)
self
.
custom_white_list
=
self
.
cfg
.
get
(
'custom_white_list'
,
None
)
self
.
custom_black_list
=
self
.
cfg
.
get
(
'custom_black_list'
,
None
)
# build data loader
capital_mode
=
self
.
mode
.
capitalize
()
...
...
@@ -155,8 +157,10 @@ class Trainer(object):
self
.
pruner
=
create
(
'UnstructuredPruner'
)(
self
.
model
,
steps_per_epoch
)
if
self
.
use_amp
and
self
.
amp_level
==
'O2'
:
self
.
model
=
paddle
.
amp
.
decorate
(
models
=
self
.
model
,
level
=
self
.
amp_level
)
self
.
model
,
self
.
optimizer
=
paddle
.
amp
.
decorate
(
models
=
self
.
model
,
optimizers
=
self
.
optimizer
,
level
=
self
.
amp_level
)
self
.
use_ema
=
(
'use_ema'
in
cfg
and
cfg
[
'use_ema'
])
if
self
.
use_ema
:
ema_decay
=
self
.
cfg
.
get
(
'ema_decay'
,
0.9998
)
...
...
@@ -456,7 +460,9 @@ class Trainer(object):
DataParallel
)
and
use_fused_allreduce_gradients
:
with
model
.
no_sync
():
with
paddle
.
amp
.
auto_cast
(
enable
=
self
.
cfg
.
use_gpus
,
enable
=
self
.
cfg
.
use_gpu
,
custom_white_list
=
self
.
custom_white_list
,
custom_black_list
=
self
.
custom_black_list
,
level
=
self
.
amp_level
):
# model forward
outputs
=
model
(
data
)
...
...
@@ -468,7 +474,10 @@ class Trainer(object):
list
(
model
.
parameters
()),
None
)
else
:
with
paddle
.
amp
.
auto_cast
(
enable
=
self
.
cfg
.
use_gpu
,
level
=
self
.
amp_level
):
enable
=
self
.
cfg
.
use_gpu
,
custom_white_list
=
self
.
custom_white_list
,
custom_black_list
=
self
.
custom_black_list
,
level
=
self
.
amp_level
):
# model forward
outputs
=
model
(
data
)
loss
=
outputs
[
'loss'
]
...
...
@@ -477,7 +486,6 @@ class Trainer(object):
scaled_loss
.
backward
()
# in dygraph mode, optimizer.minimize is equal to optimizer.step
scaler
.
minimize
(
self
.
optimizer
,
scaled_loss
)
else
:
if
isinstance
(
model
,
paddle
.
...
...
@@ -575,7 +583,10 @@ class Trainer(object):
# forward
if
self
.
use_amp
:
with
paddle
.
amp
.
auto_cast
(
enable
=
self
.
cfg
.
use_gpu
,
level
=
self
.
amp_level
):
enable
=
self
.
cfg
.
use_gpu
,
custom_white_list
=
self
.
custom_white_list
,
custom_black_list
=
self
.
custom_black_list
,
level
=
self
.
amp_level
):
outs
=
self
.
model
(
data
)
else
:
outs
=
self
.
model
(
data
)
...
...
ppdet/optimizer/ema.py
浏览文件 @
befeaeb5
...
...
@@ -66,7 +66,10 @@ class ModelEMA(object):
def
resume
(
self
,
state_dict
,
step
=
0
):
for
k
,
v
in
state_dict
.
items
():
if
k
in
self
.
state_dict
:
self
.
state_dict
[
k
]
=
v
if
self
.
state_dict
[
k
].
dtype
==
v
.
dtype
:
self
.
state_dict
[
k
]
=
v
else
:
self
.
state_dict
[
k
]
=
v
.
astype
(
self
.
state_dict
[
k
].
dtype
)
self
.
step
=
step
def
update
(
self
,
model
=
None
):
...
...
ppdet/utils/checkpoint.py
浏览文件 @
befeaeb5
...
...
@@ -84,9 +84,14 @@ def load_weight(model, weight, optimizer=None, ema=None):
model_weight
=
{}
incorrect_keys
=
0
for
key
in
model_dict
.
key
s
():
for
key
,
value
in
model_dict
.
item
s
():
if
key
in
param_state_dict
.
keys
():
model_weight
[
key
]
=
param_state_dict
[
key
]
if
isinstance
(
param_state_dict
[
key
],
np
.
ndarray
):
param_state_dict
[
key
]
=
paddle
.
to_tensor
(
param_state_dict
[
key
])
if
value
.
dtype
==
param_state_dict
[
key
].
dtype
:
model_weight
[
key
]
=
param_state_dict
[
key
]
else
:
model_weight
[
key
]
=
param_state_dict
[
key
].
astype
(
value
.
dtype
)
else
:
logger
.
info
(
'Unmatched key: {}'
.
format
(
key
))
incorrect_keys
+=
1
...
...
@@ -209,6 +214,12 @@ def load_pretrain_weight(model, pretrain_weight):
param_state_dict
=
paddle
.
load
(
weights_path
)
param_state_dict
=
match_state_dict
(
model_dict
,
param_state_dict
)
for
k
,
v
in
param_state_dict
.
items
():
if
isinstance
(
v
,
np
.
ndarray
):
v
=
paddle
.
to_tensor
(
v
)
if
model_dict
[
k
].
dtype
!=
v
.
dtype
:
param_state_dict
[
k
]
=
v
.
astype
(
model_dict
[
k
].
dtype
)
model
.
set_dict
(
param_state_dict
)
logger
.
info
(
'Finish loading model weights: {}'
.
format
(
weights_path
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录