Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
32ce6837
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 1 年 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
32ce6837
编写于
5月 06, 2020
作者:
S
shippingwang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add ema
上级
bfcd9e4d
变更
4
隐藏空白更改
内联
并排
Showing
4 changed files
with
225 additions
and
2 deletions
+225
-2
tools/ema.py
tools/ema.py
+165
-0
tools/ema_clean.py
tools/ema_clean.py
+42
-0
tools/program.py
tools/program.py
+10
-2
tools/train.py
tools/train.py
+8
-0
未找到文件。
tools/ema.py
0 → 100644
浏览文件 @
32ce6837
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.wrapped_decorator
import
signature_safe_contextmanager
from
paddle.fluid.framework
import
Program
,
program_guard
,
name_scope
,
default_main_program
from
paddle.fluid
import
unique_name
,
layers
class ExponentialMovingAverage(object):
    """Maintain exponential moving averages (EMA) of all trainable parameters.

    ``update()`` appends ops to the current main program that refresh a
    persistable shadow variable per parameter::

        ema = decay * ema + (1 - decay) * param

    ``__init__`` additionally builds two standalone fluid ``Program``s:

    * ``apply_program``   -- stashes each parameter into a temporary variable
      and overwrites the parameter with its EMA value (optionally
      bias-corrected), for evaluation.
    * ``restore_program`` -- copies the stashed values back into the
      parameters after evaluation.

    Args:
        decay (float): Exponential decay rate, typically close to 1.
        thres_steps (Variable|None): Optional global-step variable. When
            given, the effective decay rate becomes
            ``min(decay, (thres_steps + 1) / (thres_steps + 10))`` so the
            average warms up over the first training steps.
        zero_debias (bool): If True, divide the EMA by ``1 - decay**t``
            (bias correction) before applying it to the parameters.
        name (str|None): Optional prefix used when naming the created
            EMA and temporary variables.
    """

    def __init__(self, decay=0.999, thres_steps=None, zero_debias=False,
                 name=None):
        self._decay = decay
        self._thres_steps = thres_steps
        self._name = name if name is not None else ''
        # Global variable holding the (possibly step-scheduled) decay rate.
        self._decay_var = self._get_ema_decay()

        # One (param, tmp) pair per averaged parameter; `tmp` temporarily
        # holds the original value while the EMA value is applied.
        self._params_tmps = []
        for param in default_main_program().global_block().all_parameters():
            # `!= False` (rather than `is True`) keeps parameters whose
            # do_model_average attribute is unset/None as well as True.
            if param.do_model_average != False:
                tmp = param.block.create_var(
                    name=unique_name.generate(".".join(
                        [self._name + param.name, 'ema_tmp'])),
                    dtype=param.dtype,
                    persistable=False,
                    stop_gradient=True)
                self._params_tmps.append((param, tmp))

        # Persistable EMA shadow variable per parameter, initialized to 0.
        self._ema_vars = {}
        for param, tmp in self._params_tmps:
            with param.block.program._optimized_guard(
                [param, tmp]), name_scope('moving_average'):
                self._ema_vars[param.name] = self._create_ema_vars(param)

        # Program that swaps EMA values into the parameters for evaluation.
        self.apply_program = Program()
        block = self.apply_program.global_block()
        with program_guard(main_program=self.apply_program):
            decay_pow = self._get_decay_pow(block)
            for param, tmp in self._params_tmps:
                param = block._clone_variable(param)
                tmp = block._clone_variable(tmp)
                ema = block._clone_variable(self._ema_vars[param.name])
                # Stash the current parameter value so restore() can undo.
                layers.assign(input=param, output=tmp)
                # bias correction
                if zero_debias:
                    ema = ema / (1.0 - decay_pow)
                layers.assign(input=ema, output=param)

        # Program that restores the stashed parameter values.
        self.restore_program = Program()
        block = self.restore_program.global_block()
        with program_guard(main_program=self.restore_program):
            for param, tmp in self._params_tmps:
                tmp = block._clone_variable(tmp)
                param = block._clone_variable(param)
                layers.assign(input=tmp, output=param)

    def _get_ema_decay(self):
        """Create the global variable holding the decay rate.

        With ``thres_steps`` set, a fluid ``Switch`` stores
        ``min(self._decay, (thres_steps + 1) / (thres_steps + 10))``;
        otherwise the variable simply holds ``self._decay``.
        """
        with default_main_program()._lr_schedule_guard():
            decay_var = layers.tensor.create_global_var(
                shape=[1],
                value=self._decay,
                dtype='float32',
                persistable=True,
                name="scheduled_ema_decay_rate")

            if self._thres_steps is not None:
                decay_t = (self._thres_steps + 1.0) / (self._thres_steps +
                                                       10.0)
                with layers.control_flow.Switch() as switch:
                    with switch.case(decay_t < self._decay):
                        layers.tensor.assign(decay_t, decay_var)
                    with switch.default():
                        layers.tensor.assign(
                            np.array(
                                [self._decay], dtype=np.float32),
                            decay_var)
        return decay_var

    def _get_decay_pow(self, block):
        """Return ``decay ** (global_step + 1)``, used for bias correction."""
        global_steps = layers.learning_rate_scheduler._decay_step_counter()
        decay_var = block._clone_variable(self._decay_var)
        decay_pow_acc = layers.elementwise_pow(decay_var, global_steps + 1)
        return decay_pow_acc

    def _create_ema_vars(self, param):
        """Create the persistable, zero-initialized shadow variable of `param`."""
        param_ema = layers.create_global_var(
            name=unique_name.generate(self._name + param.name + '_ema'),
            shape=param.shape,
            value=0.0,
            dtype=param.dtype,
            persistable=True)
        return param_ema

    def update(self):
        """
        Update Exponential Moving Average. Should only call this method in
        train program.
        """
        param_master_emas = []
        for param, tmp in self._params_tmps:
            with param.block.program._optimized_guard(
                [param, tmp]), name_scope('moving_average'):
                param_ema = self._ema_vars[param.name]
                if param.name + '.master' in self._ema_vars:
                    # NOTE(review): '.master' keys are never inserted by
                    # __init__ in this file; presumably an fp16 master-weight
                    # pass would add them externally — confirm before relying
                    # on this branch.
                    master_ema = self._ema_vars[param.name + '.master']
                    param_master_emas.append([param_ema, master_ema])
                else:
                    # ema = decay * ema + (1 - decay) * param
                    ema_t = param_ema * self._decay_var + param * (
                        1 - self._decay_var)
                    layers.assign(input=ema_t, output=param_ema)

        # for fp16 params
        for param_ema, master_ema in param_master_emas:
            default_main_program().global_block().append_op(
                type="cast",
                inputs={"X": master_ema},
                outputs={"Out": param_ema},
                attrs={
                    "in_dtype": master_ema.dtype,
                    "out_dtype": param_ema.dtype
                })

    @signature_safe_contextmanager
    def apply(self, executor, need_restore=True):
        """
        Apply moving average to parameters for evaluation.

        Args:
            executor (Executor): The Executor to execute applying.
            need_restore (bool): Whether to restore parameters after applying.
                Restoration runs in a ``finally`` block, so it also happens
                when the managed body raises.
        """
        executor.run(self.apply_program)
        try:
            yield
        finally:
            if need_restore:
                self.restore(executor)

    def restore(self, executor):
        """Restore parameters.

        Args:
            executor (Executor): The Executor to execute restoring.
        """
        executor.run(self.restore_program)
tools/ema_clean.py
0 → 100644
浏览文件 @
32ce6837
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
argparse
import
functools
import
shutil
import
sys
def main():
    """Turn an EMA checkpoint directory into a plain, loadable model dir.

    Usage: python ema_clean.py <cleaned_model_dir> <ema_model_dir>

    For every file in ``ema_model_dir``:
      * files whose name contains ``ema`` (the EMA shadow variables) are
        copied into ``cleaned_model_dir`` with the ``_ema_0`` suffix removed,
        so they take the place of the original parameters;
      * files whose name contains ``mean`` or ``variance`` (batch-norm
        running statistics, which are not averaged) are copied unchanged;
      * everything else is skipped.
    """
    # Fail with a usage message instead of an opaque IndexError.
    if len(sys.argv) != 3:
        sys.stderr.write(
            'Usage: python %s <cleaned_model_dir> <ema_model_dir>\n' %
            sys.argv[0])
        sys.exit(1)

    cleaned_model_dir = sys.argv[1]
    ema_model_dir = sys.argv[2]
    if not os.path.exists(cleaned_model_dir):
        os.makedirs(cleaned_model_dir)

    for item in os.listdir(ema_model_dir):
        if 'ema' in item:
            # EMA shadow variable: drop the "_ema_0" suffix so the file
            # name matches the original parameter name.
            item_clean = item.replace('_ema_0', '')
            shutil.copyfile(
                os.path.join(ema_model_dir, item),
                os.path.join(cleaned_model_dir, item_clean))
        elif 'mean' in item or 'variance' in item:
            # Batch-norm statistics have no EMA counterpart; copy as-is.
            shutil.copyfile(
                os.path.join(ema_model_dir, item),
                os.path.join(cleaned_model_dir, item))


if __name__ == '__main__':
    main()
tools/program.py
浏览文件 @
32ce6837
...
...
@@ -86,7 +86,7 @@ def create_dataloader(feeds):
return
dataloader
def
create_model
(
architecture
,
image
,
classes_num
):
def
create_model
(
architecture
,
image
,
classes_num
,
is_train
):
"""
Create a model
...
...
@@ -101,6 +101,8 @@ def create_model(architecture, image, classes_num):
"""
name
=
architecture
[
"name"
]
params
=
architecture
.
get
(
"params"
,
{})
params
[
'is_test'
]
=
not
is_train
print
(
params
)
model
=
architectures
.
__dict__
[
name
](
**
params
)
out
=
model
.
net
(
input
=
image
,
class_dim
=
classes_num
)
return
out
...
...
@@ -323,7 +325,7 @@ def build(config, main_prog, startup_prog, is_train=True):
feeds
=
create_feeds
(
config
.
image_shape
,
use_mix
=
use_mix
)
dataloader
=
create_dataloader
(
feeds
.
values
())
out
=
create_model
(
config
.
ARCHITECTURE
,
feeds
[
'image'
],
config
.
classes_num
)
config
.
classes_num
,
is_train
)
fetchs
=
create_fetchs
(
out
,
feeds
,
...
...
@@ -339,6 +341,12 @@ def build(config, main_prog, startup_prog, is_train=True):
fetchs
[
'lr'
]
=
(
lr
,
AverageMeter
(
'lr'
,
'f'
,
need_avg
=
False
))
optimizer
=
dist_optimizer
(
config
,
optimizer
)
optimizer
.
minimize
(
fetchs
[
'loss'
][
0
])
if
config
.
get
(
'use_ema'
):
global_steps
=
fluid
.
layers
.
learning_rate_scheduler
.
_decay_step_counter
()
ema
=
ExponentialMovingAverage
(
config
.
get
(
'ema_decay'
),
thres_steps
=
global_steps
)
ema
.
update
()
fetchs
[
'ema'
]
=
ema
return
dataloader
,
fetchs
...
...
tools/train.py
浏览文件 @
32ce6837
...
...
@@ -98,6 +98,14 @@ def main(args):
if
int
(
os
.
getenv
(
"PADDLE_TRAINER_ID"
,
0
))
==
0
:
# 2. validate with validate dataset
if
config
.
validate
and
epoch_id
%
config
.
valid_interval
==
0
:
if
config
.
get
(
'use_ema'
):
logger
.
info
(
logger
.
coloring
(
"EMA validate start..."
))
with
train_fetchs
(
'ema'
).
apply
(
exe
):
top1_acc
=
program
.
run
(
valid_dataloader
,
exe
,
compiled_valid_prog
,
valid_fetchs
,
epoch_id
,
'valid'
)
logger
.
info
(
logger
.
coloring
(
"EMA validate over!"
))
top1_acc
=
program
.
run
(
valid_dataloader
,
exe
,
compiled_valid_prog
,
valid_fetchs
,
epoch_id
,
'valid'
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录