Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
ee25d701
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
ee25d701
编写于
4月 25, 2021
作者:
W
wangxinxin08
提交者:
GitHub
4月 25, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
eval with ema weight while training (#2747)
上级
f48cab63
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
19 addition
and
17 deletion
+19
-17
ppdet/engine/callbacks.py
ppdet/engine/callbacks.py
+2
-17
ppdet/engine/trainer.py
ppdet/engine/trainer.py
+17
-0
未找到文件。
ppdet/engine/callbacks.py
浏览文件 @
ee25d701
...
@@ -26,7 +26,6 @@ import paddle
...
@@ -26,7 +26,6 @@ import paddle
import
paddle.distributed
as
dist
import
paddle.distributed
as
dist
from
ppdet.utils.checkpoint
import
save_model
from
ppdet.utils.checkpoint
import
save_model
from
ppdet.optimizer
import
ModelEMA
from
ppdet.utils.logger
import
setup_logger
from
ppdet.utils.logger
import
setup_logger
logger
=
setup_logger
(
'ppdet.engine'
)
logger
=
setup_logger
(
'ppdet.engine'
)
...
@@ -143,20 +142,12 @@ class Checkpointer(Callback):
...
@@ -143,20 +142,12 @@ class Checkpointer(Callback):
super
(
Checkpointer
,
self
).
__init__
(
model
)
super
(
Checkpointer
,
self
).
__init__
(
model
)
cfg
=
self
.
model
.
cfg
cfg
=
self
.
model
.
cfg
self
.
best_ap
=
0.
self
.
best_ap
=
0.
self
.
use_ema
=
(
'use_ema'
in
cfg
and
cfg
[
'use_ema'
])
self
.
save_dir
=
os
.
path
.
join
(
self
.
model
.
cfg
.
save_dir
,
self
.
save_dir
=
os
.
path
.
join
(
self
.
model
.
cfg
.
save_dir
,
self
.
model
.
cfg
.
filename
)
self
.
model
.
cfg
.
filename
)
if
hasattr
(
self
.
model
.
model
,
'student_model'
):
if
hasattr
(
self
.
model
.
model
,
'student_model'
):
self
.
weight
=
self
.
model
.
model
.
student_model
self
.
weight
=
self
.
model
.
model
.
student_model
else
:
else
:
self
.
weight
=
self
.
model
.
model
self
.
weight
=
self
.
model
.
model
if
self
.
use_ema
:
self
.
ema
=
ModelEMA
(
cfg
[
'ema_decay'
],
self
.
weight
,
use_thres_step
=
True
)
def
on_step_end
(
self
,
status
):
if
self
.
use_ema
:
self
.
ema
.
update
(
self
.
weight
)
def
on_epoch_end
(
self
,
status
):
def
on_epoch_end
(
self
,
status
):
# Checkpointer only performed during training
# Checkpointer only performed during training
...
@@ -170,9 +161,6 @@ class Checkpointer(Callback):
...
@@ -170,9 +161,6 @@ class Checkpointer(Callback):
if
epoch_id
%
self
.
model
.
cfg
.
snapshot_epoch
==
0
or
epoch_id
==
end_epoch
-
1
:
if
epoch_id
%
self
.
model
.
cfg
.
snapshot_epoch
==
0
or
epoch_id
==
end_epoch
-
1
:
save_name
=
str
(
save_name
=
str
(
epoch_id
)
if
epoch_id
!=
end_epoch
-
1
else
"model_final"
epoch_id
)
if
epoch_id
!=
end_epoch
-
1
else
"model_final"
if
self
.
use_ema
:
weight
=
self
.
ema
.
apply
()
else
:
weight
=
self
.
weight
weight
=
self
.
weight
elif
mode
==
'eval'
:
elif
mode
==
'eval'
:
if
'save_best_model'
in
status
and
status
[
'save_best_model'
]:
if
'save_best_model'
in
status
and
status
[
'save_best_model'
]:
...
@@ -187,9 +175,6 @@ class Checkpointer(Callback):
...
@@ -187,9 +175,6 @@ class Checkpointer(Callback):
if
map_res
[
key
][
0
]
>
self
.
best_ap
:
if
map_res
[
key
][
0
]
>
self
.
best_ap
:
self
.
best_ap
=
map_res
[
key
][
0
]
self
.
best_ap
=
map_res
[
key
][
0
]
save_name
=
'best_model'
save_name
=
'best_model'
if
self
.
use_ema
:
weight
=
self
.
ema
.
apply
()
else
:
weight
=
self
.
weight
weight
=
self
.
weight
logger
.
info
(
"Best test {} ap is {:0.3f}."
.
format
(
logger
.
info
(
"Best test {} ap is {:0.3f}."
.
format
(
key
,
self
.
best_ap
))
key
,
self
.
best_ap
))
...
...
ppdet/engine/trainer.py
浏览文件 @
ee25d701
...
@@ -28,6 +28,7 @@ import paddle.distributed as dist
...
@@ -28,6 +28,7 @@ import paddle.distributed as dist
from
paddle.distributed
import
fleet
from
paddle.distributed
import
fleet
from
paddle
import
amp
from
paddle
import
amp
from
paddle.static
import
InputSpec
from
paddle.static
import
InputSpec
from
ppdet.optimizer
import
ModelEMA
from
ppdet.core.workspace
import
create
from
ppdet.core.workspace
import
create
from
ppdet.utils.checkpoint
import
load_weight
,
load_pretrain_weight
from
ppdet.utils.checkpoint
import
load_weight
,
load_pretrain_weight
...
@@ -61,6 +62,11 @@ class Trainer(object):
...
@@ -61,6 +62,11 @@ class Trainer(object):
self
.
model
=
self
.
cfg
.
model
self
.
model
=
self
.
cfg
.
model
self
.
is_loaded_weights
=
True
self
.
is_loaded_weights
=
True
self
.
use_ema
=
(
'use_ema'
in
cfg
and
cfg
[
'use_ema'
])
if
self
.
use_ema
:
self
.
ema
=
ModelEMA
(
cfg
[
'ema_decay'
],
self
.
model
,
use_thres_step
=
True
)
# build data loader
# build data loader
self
.
dataset
=
cfg
[
'{}Dataset'
.
format
(
self
.
mode
.
capitalize
())]
self
.
dataset
=
cfg
[
'{}Dataset'
.
format
(
self
.
mode
.
capitalize
())]
if
self
.
mode
==
'train'
:
if
self
.
mode
==
'train'
:
...
@@ -281,8 +287,15 @@ class Trainer(object):
...
@@ -281,8 +287,15 @@ class Trainer(object):
self
.
status
[
'batch_time'
].
update
(
time
.
time
()
-
iter_tic
)
self
.
status
[
'batch_time'
].
update
(
time
.
time
()
-
iter_tic
)
self
.
_compose_callback
.
on_step_end
(
self
.
status
)
self
.
_compose_callback
.
on_step_end
(
self
.
status
)
if
self
.
use_ema
:
self
.
ema
.
update
(
self
.
model
)
iter_tic
=
time
.
time
()
iter_tic
=
time
.
time
()
# apply ema weight on model
if
self
.
use_ema
:
weight
=
self
.
model
.
state_dict
()
self
.
model
.
set_dict
(
self
.
ema
.
apply
())
self
.
_compose_callback
.
on_epoch_end
(
self
.
status
)
self
.
_compose_callback
.
on_epoch_end
(
self
.
status
)
if
validate
and
(
self
.
_nranks
<
2
or
self
.
_local_rank
==
0
)
\
if
validate
and
(
self
.
_nranks
<
2
or
self
.
_local_rank
==
0
)
\
...
@@ -303,6 +316,10 @@ class Trainer(object):
...
@@ -303,6 +316,10 @@ class Trainer(object):
self
.
status
[
'save_best_model'
]
=
True
self
.
status
[
'save_best_model'
]
=
True
self
.
_eval_with_loader
(
self
.
_eval_loader
)
self
.
_eval_with_loader
(
self
.
_eval_loader
)
# restore origin weight on model
if
self
.
use_ema
:
self
.
model
.
set_dict
(
weight
)
def
_eval_with_loader
(
self
,
loader
):
def
_eval_with_loader
(
self
,
loader
):
sample_num
=
0
sample_num
=
0
tic
=
time
.
time
()
tic
=
time
.
time
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录