Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
c36eb9f6
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c36eb9f6
编写于
4月 25, 2021
作者:
W
wangxinxin08
提交者:
GitHub
4月 25, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
eval with ema weight while training (#2748)
上级
85a82d9b
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
19 addition
and
17 deletion
+19
-17
ppdet/engine/callbacks.py
ppdet/engine/callbacks.py
+2
-17
ppdet/engine/trainer.py
ppdet/engine/trainer.py
+17
-0
未找到文件。
ppdet/engine/callbacks.py
浏览文件 @
c36eb9f6
...
...
@@ -26,7 +26,6 @@ import paddle
import
paddle.distributed
as
dist
from
ppdet.utils.checkpoint
import
save_model
from
ppdet.optimizer
import
ModelEMA
from
ppdet.utils.logger
import
setup_logger
logger
=
setup_logger
(
'ppdet.engine'
)
...
...
@@ -143,20 +142,12 @@ class Checkpointer(Callback):
super
(
Checkpointer
,
self
).
__init__
(
model
)
cfg
=
self
.
model
.
cfg
self
.
best_ap
=
0.
self
.
use_ema
=
(
'use_ema'
in
cfg
and
cfg
[
'use_ema'
])
self
.
save_dir
=
os
.
path
.
join
(
self
.
model
.
cfg
.
save_dir
,
self
.
model
.
cfg
.
filename
)
if
hasattr
(
self
.
model
.
model
,
'student_model'
):
self
.
weight
=
self
.
model
.
model
.
student_model
else
:
self
.
weight
=
self
.
model
.
model
if
self
.
use_ema
:
self
.
ema
=
ModelEMA
(
cfg
[
'ema_decay'
],
self
.
weight
,
use_thres_step
=
True
)
def
on_step_end
(
self
,
status
):
if
self
.
use_ema
:
self
.
ema
.
update
(
self
.
weight
)
def
on_epoch_end
(
self
,
status
):
# Checkpointer only performed during training
...
...
@@ -170,10 +161,7 @@ class Checkpointer(Callback):
if
epoch_id
%
self
.
model
.
cfg
.
snapshot_epoch
==
0
or
epoch_id
==
end_epoch
-
1
:
save_name
=
str
(
epoch_id
)
if
epoch_id
!=
end_epoch
-
1
else
"model_final"
if
self
.
use_ema
:
weight
=
self
.
ema
.
apply
()
else
:
weight
=
self
.
weight
weight
=
self
.
weight
elif
mode
==
'eval'
:
if
'save_best_model'
in
status
and
status
[
'save_best_model'
]:
for
metric
in
self
.
model
.
_metrics
:
...
...
@@ -187,10 +175,7 @@ class Checkpointer(Callback):
if
map_res
[
key
][
0
]
>
self
.
best_ap
:
self
.
best_ap
=
map_res
[
key
][
0
]
save_name
=
'best_model'
if
self
.
use_ema
:
weight
=
self
.
ema
.
apply
()
else
:
weight
=
self
.
weight
weight
=
self
.
weight
logger
.
info
(
"Best test {} ap is {:0.3f}."
.
format
(
key
,
self
.
best_ap
))
if
weight
:
...
...
ppdet/engine/trainer.py
浏览文件 @
c36eb9f6
...
...
@@ -28,6 +28,7 @@ import paddle.distributed as dist
from
paddle.distributed
import
fleet
from
paddle
import
amp
from
paddle.static
import
InputSpec
from
ppdet.optimizer
import
ModelEMA
from
ppdet.core.workspace
import
create
from
ppdet.utils.checkpoint
import
load_weight
,
load_pretrain_weight
...
...
@@ -61,6 +62,11 @@ class Trainer(object):
self
.
model
=
self
.
cfg
.
model
self
.
is_loaded_weights
=
True
self
.
use_ema
=
(
'use_ema'
in
cfg
and
cfg
[
'use_ema'
])
if
self
.
use_ema
:
self
.
ema
=
ModelEMA
(
cfg
[
'ema_decay'
],
self
.
model
,
use_thres_step
=
True
)
# build data loader
self
.
dataset
=
cfg
[
'{}Dataset'
.
format
(
self
.
mode
.
capitalize
())]
if
self
.
mode
==
'train'
:
...
...
@@ -281,8 +287,15 @@ class Trainer(object):
self
.
status
[
'batch_time'
].
update
(
time
.
time
()
-
iter_tic
)
self
.
_compose_callback
.
on_step_end
(
self
.
status
)
if
self
.
use_ema
:
self
.
ema
.
update
(
self
.
model
)
iter_tic
=
time
.
time
()
# apply ema weight on model
if
self
.
use_ema
:
weight
=
self
.
model
.
state_dict
()
self
.
model
.
set_dict
(
self
.
ema
.
apply
())
self
.
_compose_callback
.
on_epoch_end
(
self
.
status
)
if
validate
and
(
self
.
_nranks
<
2
or
self
.
_local_rank
==
0
)
\
...
...
@@ -303,6 +316,10 @@ class Trainer(object):
self
.
status
[
'save_best_model'
]
=
True
self
.
_eval_with_loader
(
self
.
_eval_loader
)
# restore origin weight on model
if
self
.
use_ema
:
self
.
model
.
set_dict
(
weight
)
def
_eval_with_loader
(
self
,
loader
):
sample_num
=
0
tic
=
time
.
time
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录