PaddlePaddle / PaddleClas

Commit ed820223
Authored Apr 28, 2022 by flytocc
Parent: e1943f9a

add EMA code
Showing 5 changed files with 88 additions and 46 deletions (+88 −46).
ppcls/configs/ImageNet/ConvNeXt/convnext_tiny.yaml   +5  −0
ppcls/engine/engine.py                               +44 −1
ppcls/engine/train/train.py                          +3  −0
ppcls/utils/ema.py                                   +22 −43
ppcls/utils/save_load.py                             +14 −2
ppcls/configs/ImageNet/ConvNeXt/convnext_tiny.yaml

@@ -17,6 +17,11 @@ Global:
   to_static: False
 
+# model ema
+EMA:
+  decay: 0.9999
+
 # model architecture
 Arch:
   name: ConvNext_tiny
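The new top-level EMA block is what Engine checks for when deciding whether to maintain an averaged copy of the weights, and decay is its only knob. Each update blends the shadow weights into the live ones as ema = decay * ema + (1 - decay) * model (see ppcls/utils/ema.py below), so decay: 0.9999 corresponds to an effective averaging window of roughly 1 / (1 - decay) = 10,000 steps; omitting the key falls back to the same 0.9999 default.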
ppcls/engine/engine.py

@@ -34,6 +34,7 @@ from ppcls.arch import apply_to_static
 from ppcls.loss import build_loss
 from ppcls.metric import build_metrics
 from ppcls.optimizer import build_optimizer
+from ppcls.utils.ema import ExponentialMovingAverage
 from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
 from ppcls.utils.save_load import init_model
 from ppcls.utils import save_load
@@ -115,6 +116,9 @@ class Engine(object):
             })
             paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
 
+        # EMA model
+        self.ema = "EMA" in self.config and self.mode == "train"
+
         if "class_num" in config["Global"]:
             global_class_num = config["Global"]["class_num"]
             if "class_num" not in config["Arch"]:
@@ -250,6 +254,11 @@ class Engine(object):
                 level=amp_level,
                 save_dtype='float32')
 
+        # build EMA model
+        if self.ema:
+            self.model_ema = ExponentialMovingAverage(
+                self.model, self.config['EMA'].get("decay", 0.9999))
+
         # for distributed
         world_size = dist.get_world_size()
         self.config["Global"]["distributed"] = world_size != 1
@@ -278,6 +287,10 @@ class Engine(object):
             "metric": 0.0,
             "epoch": 0,
         }
+        ema_module = None
+        if self.ema:
+            best_metric_ema = 0.0
+            ema_module = self.model_ema.module
         # key:
         # val: metrics list word
         self.output_info = dict()
@@ -292,7 +305,8 @@ class Engine(object):
         if self.config.Global.checkpoints is not None:
             metric_info = init_model(self.config.Global, self.model,
-                                     self.optimizer, self.train_loss_func)
+                                     self.optimizer, self.train_loss_func,
+                                     ema_module)
             if metric_info is not None:
                 best_metric.update(metric_info)
@@ -327,6 +341,7 @@ class Engine(object):
                         self.optimizer,
                         best_metric,
                         self.output_dir,
+                        ema=ema_module,
                         model_name=self.config["Arch"]["name"],
                         prefix="best_model",
                         loss=self.train_loss_func)
@@ -340,6 +355,32 @@ class Engine(object):
                 self.model.train()
 
+                if self.ema:
+                    ori_model, self.model = self.model, ema_module
+                    acc_ema = self.eval(epoch_id)
+                    self.model = ori_model
+                    ema_module.eval()
+
+                    if acc_ema > best_metric_ema:
+                        best_metric_ema = acc_ema
+                        save_load.save_model(
+                            self.model,
+                            self.optimizer, {
+                                "metric": acc_ema,
+                                "epoch": epoch_id
+                            },
+                            self.output_dir,
+                            ema=ema_module,
+                            model_name=self.config["Arch"]["name"],
+                            prefix="best_model_ema",
+                            loss=self.train_loss_func)
+                    logger.info("[Eval][Epoch {}][best metric ema: {}]".format(
+                        epoch_id, best_metric_ema))
+                    logger.scaler(
+                        name="eval_acc_ema",
+                        value=acc_ema,
+                        step=epoch_id,
+                        writer=self.vdl_writer)
+
             # save model
             if epoch_id % save_interval == 0:
                 save_load.save_model(
@@ -347,6 +388,7 @@ class Engine(object):
                     self.optimizer, {
                         "metric": acc,
                         "epoch": epoch_id
                     },
                     self.output_dir,
+                    ema=ema_module,
                     model_name=self.config["Arch"]["name"],
                     prefix="epoch_{}".format(epoch_id),
                     loss=self.train_loss_func)
@@ -356,6 +398,7 @@ class Engine(object):
             self.optimizer, {
                 "metric": acc,
                 "epoch": epoch_id
             },
             self.output_dir,
+            ema=ema_module,
             model_name=self.config["Arch"]["name"],
             prefix="latest",
             loss=self.train_loss_func)
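One detail worth calling out: the EMA evaluation above never copies weights back and forth. It temporarily rebinds self.model to the frozen EMA copy, runs the ordinary eval path, then rebinds the live model. A minimal standalone sketch of the same pattern (the trainer object and its eval method are hypothetical stand-ins for Engine):

    def eval_with_ema(trainer, epoch_id):
        # swap in the EMA copy so the regular eval path sees averaged weights
        ori_model, trainer.model = trainer.model, trainer.model_ema.module
        acc_ema = trainer.eval(epoch_id)
        # swap the live training model back before the next train step
        trainer.model = ori_model
        trainer.model_ema.module.eval()  # keep the EMA copy in eval mode
        return acc_ema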
ppcls/engine/train/train.py

@@ -69,6 +69,9 @@ def train_epoch(engine, epoch_id, print_batch_step):
         # step lr
         for i in range(len(engine.lr_sch)):
             engine.lr_sch[i].step()
+        # update ema
+        if engine.ema:
+            engine.model_ema.update(engine.model)
 
         # below code just for logging
         # update metric_for_logger
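The EMA update therefore runs once per iteration, after the optimizer and LR-scheduler steps, so the shadow weights always track post-step parameters. A condensed sketch of the resulting loop body (forward_backward is a hypothetical stand-in for the loss/backward/step code elided above):

    for iter_id, batch in enumerate(engine.train_dataloader):
        forward_backward(engine, batch)   # loss, backward, optimizer step
        for lr_sch in engine.lr_sch:      # step lr
            lr_sch.step()
        if engine.ema:                    # update ema from the live model
            engine.model_ema.update(engine.model)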
ppcls/utils/ema.py

@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,52 +12,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from copy import deepcopy
+
 import paddle
-import numpy as np
 
 
 class ExponentialMovingAverage():
     """
     Exponential Moving Average
-    Code was heavily based on https://github.com/Wanger-SJTU/SegToolbox.Pytorch/blob/master/lib/utils/ema.py
+    Code was heavily based on https://github.com/rwightman/pytorch-image-models/blob/master/timm/utils/model_ema.py
     """
 
-    def __init__(self, model, decay, thres_steps=True):
-        self._model = model
-        self._decay = decay
-        self._thres_steps = thres_steps
-        self._shadow = {}
-        self._backup = {}
+    def __init__(self, model, decay=0.9999):
+        super().__init__()
+        # make a copy of the model for accumulating moving average of weights
+        self.module = deepcopy(model)
+        self.module.eval()
+        self.decay = decay
 
-    def register(self):
-        self._update_step = 0
-        for name, param in self._model.named_parameters():
-            if param.stop_gradient is False:
-                self._shadow[name] = param.numpy().copy()
+    @paddle.no_grad()
+    def _update(self, model, update_fn):
+        for ema_v, model_v in zip(self.module.state_dict().values(),
+                                  model.state_dict().values()):
+            ema_v.set_value(update_fn(ema_v, model_v))
 
-    def update(self):
-        decay = min(self._decay, (1 + self._update_step) / (
-            10 + self._update_step)) if self._thres_steps else self._decay
-        for name, param in self._model.named_parameters():
-            if param.stop_gradient is False:
-                assert name in self._shadow
-                new_val = np.array(param.numpy().copy())
-                old_val = np.array(self._shadow[name])
-                new_average = decay * old_val + (1 - decay) * new_val
-                self._shadow[name] = new_average
-        self._update_step += 1
-        return decay
+    def update(self, model):
+        self._update(
+            model,
+            update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
 
-    def apply(self):
-        for name, param in self._model.named_parameters():
-            if param.stop_gradient is False:
-                assert name in self._shadow
-                self._backup[name] = np.array(param.numpy().copy())
-                param.set_value(np.array(self._shadow[name]))
-
-    def restore(self):
-        for name, param in self._model.named_parameters():
-            if param.stop_gradient is False:
-                assert name in self._backup
-                param.set_value(self._backup[name])
-        self._backup = {}
+    def set(self, model):
+        self._update(model, update_fn=lambda e, m: m)
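The rewrite replaces the shadow-dict plus apply()/restore() bookkeeping with a frozen deep copy of the model, so callers evaluate ema.module directly instead of patching weights in and out of the live model. A self-contained usage sketch (the toy Linear model is just for illustration):

    import paddle
    from ppcls.utils.ema import ExponentialMovingAverage

    model = paddle.nn.Linear(4, 2)
    ema = ExponentialMovingAverage(model, decay=0.9999)

    # after every optimizer step:
    ema.update(model)   # ema_v = decay * ema_v + (1 - decay) * model_v

    # ema.module is the averaged, eval-mode copy; use it for validation,
    # or re-sync it from the live model in one shot:
    ema.set(model)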
ppcls/utils/save_load.py

@@ -87,7 +87,11 @@ def load_distillation_model(model, pretrained_model):
         pretrained_model))
 
 
-def init_model(config, net, optimizer=None, loss: paddle.nn.Layer=None):
+def init_model(config,
+               net,
+               optimizer=None,
+               loss: paddle.nn.Layer=None,
+               ema=None):
     """
     load model from checkpoint or pretrained_model
     """
@@ -105,7 +109,12 @@ def init_model(config, net, optimizer=None, loss: paddle.nn.Layer=None):
         net.set_state_dict(para_dict)
         loss.set_state_dict(para_dict)
         for i in range(len(optimizer)):
-            optimizer[i].set_state_dict(opti_dict)
+            optimizer[i].set_state_dict(opti_dict[i])
+        if ema is not None:
+            assert os.path.exists(checkpoints + ".ema.pdparams"), \
+                "Given dir {}.ema.pdparams not exist.".format(checkpoints)
+            para_ema_dict = paddle.load(checkpoints + ".ema.pdparams")
+            ema.set_state_dict(para_ema_dict)
         logger.info("Finish load checkpoints from {}".format(checkpoints))
         return metric_dict
@@ -125,6 +134,7 @@ def save_model(net,
                optimizer,
                metric_info,
                model_path,
+               ema=None,
                model_name="",
                prefix='ppcls',
                loss: paddle.nn.Layer=None):
@@ -145,6 +155,8 @@ def save_model(net,
     params_state_dict.update(loss_state_dict)
     paddle.save(params_state_dict, model_path + ".pdparams")
+    if ema is not None:
+        paddle.save(ema.state_dict(), model_path + ".ema.pdparams")
     paddle.save([opt.state_dict() for opt in optimizer],
                 model_path + ".pdopt")
     paddle.save(metric_info, model_path + ".pdstates")
     logger.info("Already save model in {}".format(model_path))
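Two things happen in this file: save_model now writes the EMA weights as an extra artifact, and a drive-by fix makes optimizer loading index into the saved list (opti_dict[i]), matching the list that paddle.save([opt.state_dict() for opt in optimizer], ...) produces. For a checkpoint saved with prefix="best_model", the output directory would then hold, for example:

    best_model.pdparams       # model (and loss) parameters
    best_model.ema.pdparams   # EMA shadow parameters (only when ema is passed)
    best_model.pdopt          # list of optimizer state dicts
    best_model.pdstates       # metric_info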