PaddlePaddle / PaddleClas — commit c962e805
Unverified commit c962e805, authored Jun 04, 2020 by ruri, committed by GitHub on Jun 04, 2020.
Merge pull request #101 from shippingwang/add_efficientnet
add ema and EfficientNet
Parents: 27052fd0, fdc40d1e
Showing 9 changed files with 431 additions and 26 deletions (+431 −26).
.pre-commit-config.yaml                        +0   −3
configs/EfficientNet/EfficientNetB0.yaml      +85   −0
ppcls/data/imaug/operators.py                 +21   −2
ppcls/modeling/architectures/efficientnet.py   +5   −3
ppcls/optimizer/learning_rate.py              +69  −10
tools/ema.py                                 +165   −0
tools/ema_clean.py                            +48   −0
tools/program.py                              +21   −3
tools/train.py                                +17   −5
.pre-commit-config.yaml

```diff
@@ -33,6 +33,3 @@
 -   id: trailing-whitespace
     files: \.(md|yml)$
--   id: check-case-conflict
--   id: flake8
-    args: ['--ignore=E265']
```
configs/EfficientNet/EfficientNetB0.yaml (new file, mode 100644)

```yaml
mode: 'train'
ARCHITECTURE:
    name: "EfficientNetB0"
    drop_connect_rate: 0.1
    padding_type: "SAME"
pretrained_model: ""
model_save_dir: "./output/"
classes_num: 1000
total_images: 1281167
save_interval: 1
validate: True
valid_interval: 1
epochs: 360
topk: 5
image_shape: [3, 224, 224]
use_ema: True
ema_decay: 0.9999
use_aa: True
ls_epsilon: 0.1

LEARNING_RATE:
    function: 'ExponentialWarmup'
    params:
        lr: 0.032

OPTIMIZER:
    function: 'RMSProp'
    params:
        momentum: 0.9
        rho: 0.9
        epsilon: 0.001
    regularizer:
        function: 'L2'
        factor: 0.00001

TRAIN:
    batch_size: 512
    num_workers: 4
    file_list: "./dataset/ILSVRC2012/train_list.txt"
    data_dir: "./dataset/ILSVRC2012/"
    shuffle_seed: 0
    transforms:
        - DecodeImage:
            to_rgb: True
            to_np: False
            channel_first: False
        - RandCropImage:
            size: 224
            interpolation: 2
        - RandFlipImage:
            flip_code: 1
        - AutoAugment:
        - NormalizeImage:
            scale: 1./255.
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - ToCHWImage:

VALID:
    batch_size: 128
    num_workers: 4
    file_list: "./dataset/ILSVRC2012/val_list.txt"
    data_dir: "./dataset/ILSVRC2012/"
    shuffle_seed: 0
    transforms:
        - DecodeImage:
            to_rgb: True
            to_np: False
            channel_first: False
        - ResizeImage:
            interpolation: 2
            resize_short: 256
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - ToCHWImage:
```
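Each entry under `transforms` names an imaug operator and its constructor arguments, applied in order to every sample. As a rough illustration of how such a list becomes a callable pipeline (a generic sketch, not PaddleClas's actual factory code; `build_pipeline` and the registry here are made up):

```python
# Generic sketch: turn a YAML-style operator list into a chain of callables.
import numpy as np

class ToCHWImage(object):
    def __call__(self, img):
        return img.transpose((2, 0, 1))  # HWC -> CHW

def build_pipeline(op_specs, registry):
    ops = []
    for spec in op_specs:
        (name, params), = spec.items()   # each entry is {OpName: {...} or None}
        ops.append(registry[name](**(params or {})))
    return ops

pipeline = build_pipeline([{'ToCHWImage': None}], {'ToCHWImage': ToCHWImage})
img = np.zeros((224, 224, 3), dtype='float32')
for op in pipeline:
    img = op(img)
print(img.shape)   # (3, 224, 224)
```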
ppcls/data/imaug/operators.py

```diff
@@ -25,6 +25,8 @@ import random
 import cv2
 import numpy as np
 
+from .autoaugment import ImageNetPolicy
+
 
 class OperatorParamError(ValueError):
     """ OperatorParamError
@@ -115,7 +117,9 @@ class CropImage(object):
 class RandCropImage(object):
     """ random crop image """
 
-    def __init__(self, size, scale=None, ratio=None):
+    def __init__(self, size, scale=None, ratio=None, interpolation=-1):
+
+        self.interpolation = interpolation if interpolation >= 0 else None
         if type(size) is int:
             self.size = (size, size)  # (h, w)
         else:
@@ -149,7 +153,10 @@ class RandCropImage(object):
         j = random.randint(0, img_h - h)
 
         img = img[j:j + h, i:i + w, :]
-        return cv2.resize(img, size)
+        if self.interpolation is None:
+            return cv2.resize(img, size)
+        else:
+            return cv2.resize(img, size, interpolation=self.interpolation)
 
 
 class RandFlipImage(object):
@@ -172,6 +179,18 @@ class RandFlipImage(object):
         return img
 
 
+class AutoAugment(object):
+    def __init__(self):
+        self.policy = ImageNetPolicy()
+
+    def __call__(self, img):
+        from PIL import Image
+        img = np.ascontiguousarray(img)
+        img = Image.fromarray(img)
+        img = self.policy(img)
+        img = np.asarray(img)
+        return img
+
 
 class NormalizeImage(object):
     """ normalize image such as subtract mean, divide std
     """
```
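The new `interpolation` argument stays backward compatible: the default `-1` maps to `None` internally, keeping the old `cv2.resize` behaviour, while a non-negative value is forwarded to `cv2.resize`. A quick usage sketch (assumes the repository root is on `PYTHONPATH`):

```python
import numpy as np
from ppcls.data.imaug.operators import RandCropImage  # assumes repo root on PYTHONPATH

img = (np.random.rand(256, 256, 3) * 255).astype('uint8')

crop_default = RandCropImage(size=224)                 # old behaviour: cv2 default interpolation
crop_cubic = RandCropImage(size=224, interpolation=2)  # 2 == cv2.INTER_CUBIC, as in the config

print(crop_default(img).shape)  # (224, 224, 3)
print(crop_cubic(img).shape)    # (224, 224, 3)
```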
ppcls/modeling/architectures/efficientnet.py

```diff
@@ -383,7 +383,9 @@ class EfficientNet():
                               use_bias=True,
                               padding_type=self.padding_type,
                               name=name + '_se_expand')
-        se_out = inputs * fluid.layers.sigmoid(x_squeezed)
+        #se_out = inputs * fluid.layers.sigmoid(x_squeezed)
+        se_out = fluid.layers.elementwise_mul(
+            inputs, fluid.layers.sigmoid(x_squeezed), axis=-1)
         return se_out
 
     def extract_features(self, inputs, is_test):
@@ -467,8 +469,8 @@ class BlockDecoder(object):
         # Check stride
         cond_1 = ('s' in options and len(options['s']) == 1)
         cond_2 = ((len(options['s']) == 2) and
                   (options['s'][0] == options['s'][1]))
         assert (cond_1 or cond_2)
 
         return BlockArgs(
```
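For reference, the squeeze-and-excitation gating this hunk rewrites is just a per-channel scaling of the feature map; the explicit `elementwise_mul` makes the broadcast unambiguous. A NumPy sketch of the computation (shapes follow the usual SE layout and are assumed, not taken from the file):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

inputs = np.random.rand(2, 8, 7, 7).astype('float32')      # (N, C, H, W)
x_squeezed = np.random.rand(2, 8, 1, 1).astype('float32')  # per-channel gate logits

se_out = inputs * sigmoid(x_squeezed)  # gate broadcasts over H and W
print(se_out.shape)                    # (2, 8, 7, 7)
```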
ppcls/optimizer/learning_rate.py

```diff
-#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
 #
-#Licensed under the Apache License, Version 2.0 (the "License");
-#you may not use this file except in compliance with the License.
-#You may obtain a copy of the License at
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
-#Unless required by applicable law or agreed to in writing, software
-#distributed under the License is distributed on an "AS IS" BASIS,
-#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#See the License for the specific language governing permissions and
-#limitations under the License.
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 from __future__ import absolute_import
 from __future__ import division
@@ -130,7 +130,7 @@ class CosineWarmup(object):
         with fluid.layers.control_flow.Switch() as switch:
             with switch.case(epoch < self.warmup_epoch):
                 decayed_lr = self.lr * \
-                        (global_step / (self.step_each_epoch * self.warmup_epoch))
+                    (global_step / (self.step_each_epoch * self.warmup_epoch))
                 fluid.layers.tensor.assign(
                     input=decayed_lr, output=learning_rate)
             with switch.default():
@@ -145,6 +145,65 @@ class CosineWarmup(object):
         return learning_rate
 
 
+class ExponentialWarmup(object):
+    """
+    Exponential learning rate decay with warmup
+    [0, warmup_epoch): linear warmup
+    [warmup_epoch, epochs): Exponential decay
+
+    Args:
+        lr(float): initial learning rate
+        step_each_epoch(int): steps each epoch
+        decay_epochs(float): decay epochs
+        decay_rate(float): decay rate
+        warmup_epoch(int): epoch num of warmup
+    """
+
+    def __init__(self,
+                 lr,
+                 step_each_epoch,
+                 decay_epochs=2.4,
+                 decay_rate=0.97,
+                 warmup_epoch=5,
+                 **kwargs):
+        super(ExponentialWarmup, self).__init__()
+        self.lr = lr
+        self.step_each_epoch = step_each_epoch
+        self.decay_epochs = decay_epochs * self.step_each_epoch
+        self.decay_rate = decay_rate
+        self.warmup_epoch = fluid.layers.fill_constant(
+            shape=[1],
+            value=float(warmup_epoch),
+            dtype='float32',
+            force_cpu=True)
+
+    def __call__(self):
+        global_step = _decay_step_counter()
+        learning_rate = fluid.layers.tensor.create_global_var(
+            shape=[1],
+            value=0.0,
+            dtype='float32',
+            persistable=True,
+            name="learning_rate")
+        epoch = ops.floor(global_step / self.step_each_epoch)
+        with fluid.layers.control_flow.Switch() as switch:
+            with switch.case(epoch < self.warmup_epoch):
+                decayed_lr = self.lr * \
+                    (global_step / (self.step_each_epoch * self.warmup_epoch))
+                fluid.layers.tensor.assign(
+                    input=decayed_lr, output=learning_rate)
+            with switch.default():
+                rest_step = global_step - self.warmup_epoch * self.step_each_epoch
+                div_res = ops.floor(rest_step / self.decay_epochs)
+                decayed_lr = self.lr * (self.decay_rate**div_res)
+                fluid.layers.tensor.assign(
+                    input=decayed_lr, output=learning_rate)
+
+        return learning_rate
+
+
 class LearningRateBuilder():
     """
     Build learning rate variable
```
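To make the schedule concrete, here is a plain-Python sketch that evaluates the same rule step by step. Constants follow EfficientNetB0.yaml (lr=0.032, 5 warmup epochs, decay by 0.97 every 2.4 epochs); `step_each_epoch` is a small made-up value so the trace stays short — on ImageNet it would be `total_images // batch_size = 1281167 // 512 = 2502`:

```python
import math

def exponential_warmup(step, lr=0.032, step_each_epoch=10,
                       decay_epochs=2.4, decay_rate=0.97, warmup_epoch=5):
    epoch = step // step_each_epoch
    if epoch < warmup_epoch:                       # linear warmup
        return lr * step / (step_each_epoch * warmup_epoch)
    rest_step = step - warmup_epoch * step_each_epoch
    decay_steps = decay_epochs * step_each_epoch   # decay period in steps
    return lr * decay_rate ** math.floor(rest_step / decay_steps)

for s in (0, 25, 49, 50, 74, 98):
    print(s, round(exponential_warmup(s), 6))
# 0 -> 0.0, 25 -> 0.016, 49 -> 0.03136 (warmup),
# 50 -> 0.032, 74 -> 0.03104, 98 -> 0.030109 (stepped exponential decay)
```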
tools/ema.py (new file, mode 100644)

```python
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

import paddle
import paddle.fluid as fluid
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
from paddle.fluid.framework import Program, program_guard, name_scope, default_main_program
from paddle.fluid import unique_name, layers


class ExponentialMovingAverage(object):
    def __init__(self,
                 decay=0.999,
                 thres_steps=None,
                 zero_debias=False,
                 name=None):
        self._decay = decay
        self._thres_steps = thres_steps
        self._name = name if name is not None else ''
        self._decay_var = self._get_ema_decay()

        self._params_tmps = []
        for param in default_main_program().global_block().all_parameters():
            if param.do_model_average != False:
                tmp = param.block.create_var(
                    name=unique_name.generate(".".join(
                        [self._name + param.name, 'ema_tmp'])),
                    dtype=param.dtype,
                    persistable=False,
                    stop_gradient=True)
                self._params_tmps.append((param, tmp))

        self._ema_vars = {}
        for param, tmp in self._params_tmps:
            with param.block.program._optimized_guard(
                [param, tmp]), name_scope('moving_average'):
                self._ema_vars[param.name] = self._create_ema_vars(param)

        self.apply_program = Program()
        block = self.apply_program.global_block()
        with program_guard(main_program=self.apply_program):
            decay_pow = self._get_decay_pow(block)
            for param, tmp in self._params_tmps:
                param = block._clone_variable(param)
                tmp = block._clone_variable(tmp)
                ema = block._clone_variable(self._ema_vars[param.name])
                layers.assign(input=param, output=tmp)
                # bias correction
                if zero_debias:
                    ema = ema / (1.0 - decay_pow)
                layers.assign(input=ema, output=param)

        self.restore_program = Program()
        block = self.restore_program.global_block()
        with program_guard(main_program=self.restore_program):
            for param, tmp in self._params_tmps:
                tmp = block._clone_variable(tmp)
                param = block._clone_variable(param)
                layers.assign(input=tmp, output=param)

    def _get_ema_decay(self):
        with default_main_program()._lr_schedule_guard():
            decay_var = layers.tensor.create_global_var(
                shape=[1],
                value=self._decay,
                dtype='float32',
                persistable=True,
                name="scheduled_ema_decay_rate")

            if self._thres_steps is not None:
                decay_t = (self._thres_steps + 1.0) / (self._thres_steps + 10.0)
                with layers.control_flow.Switch() as switch:
                    with switch.case(decay_t < self._decay):
                        layers.tensor.assign(decay_t, decay_var)
                    with switch.default():
                        layers.tensor.assign(
                            np.array([self._decay], dtype=np.float32),
                            decay_var)
        return decay_var

    def _get_decay_pow(self, block):
        global_steps = layers.learning_rate_scheduler._decay_step_counter()
        decay_var = block._clone_variable(self._decay_var)
        decay_pow_acc = layers.elementwise_pow(decay_var, global_steps + 1)
        return decay_pow_acc

    def _create_ema_vars(self, param):
        param_ema = layers.create_global_var(
            name=unique_name.generate(self._name + param.name + '_ema'),
            shape=param.shape,
            value=0.0,
            dtype=param.dtype,
            persistable=True)
        return param_ema

    def update(self):
        """
        Update Exponential Moving Average. Should only call this method in
        train program.
        """
        param_master_emas = []
        for param, tmp in self._params_tmps:
            with param.block.program._optimized_guard(
                [param, tmp]), name_scope('moving_average'):
                param_ema = self._ema_vars[param.name]
                if param.name + '.master' in self._ema_vars:
                    master_ema = self._ema_vars[param.name + '.master']
                    param_master_emas.append([param_ema, master_ema])
                else:
                    ema_t = param_ema * self._decay_var + param * (
                        1 - self._decay_var)
                    layers.assign(input=ema_t, output=param_ema)

        # for fp16 params
        for param_ema, master_ema in param_master_emas:
            default_main_program().global_block().append_op(
                type="cast",
                inputs={"X": master_ema},
                outputs={"Out": param_ema},
                attrs={
                    "in_dtype": master_ema.dtype,
                    "out_dtype": param_ema.dtype
                })

    @signature_safe_contextmanager
    def apply(self, executor, need_restore=True):
        """
        Apply moving average to parameters for evaluation.

        Args:
            executor (Executor): The Executor to execute applying.
            need_restore (bool): Whether to restore parameters after applying.
        """
        executor.run(self.apply_program)
        try:
            yield
        finally:
            if need_restore:
                self.restore(executor)

    def restore(self, executor):
        """Restore parameters.

        Args:
            executor (Executor): The Executor to execute restoring.
        """
        executor.run(self.restore_program)
```
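Stripped of the Program machinery, `update`/`apply`/`restore` implement a standard shadow-variable EMA. A plain-Python sketch of the arithmetic (constants follow the class defaults; this is illustrative, not Paddle code):

```python
# update(): shadow = d * shadow + (1 - d) * param, where the thres_steps ramp
#           d = min(decay, (t + 1) / (t + 10)) lets early steps average faster.
# apply():  evaluate with the shadow value; restore() puts the raw param back.
def ema_trace(params, decay=0.999, zero_debias=False):
    shadow, out = 0.0, []
    for t, p in enumerate(params):
        d = min(decay, (t + 1.0) / (t + 10.0))           # thres_steps ramp
        shadow = d * shadow + (1.0 - d) * p              # update()
        out.append(shadow / (1.0 - d ** (t + 1)) if zero_debias else shadow)
    return out

print([round(v, 4) for v in ema_trace([1.0, 1.0, 1.0, 1.0])])
# -> [0.9, 0.9818, 0.9955, 0.9986]: the average warms up toward the parameter
```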
tools/ema_clean.py (new file, mode 100644)

```python
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import argparse
import functools
import shutil
import sys


def main():
    """
    When training with the use_ema flag, the saved model must be cleaned
    before the EMA model can be evaluated. To generate the cleaned model:

        python ema_clean.py ema_model_dir cleaned_model_dir
    """
    ema_model_dir = sys.argv[1]
    cleaned_model_dir = sys.argv[2]

    if not os.path.exists(cleaned_model_dir):
        os.makedirs(cleaned_model_dir)

    items = os.listdir(ema_model_dir)
    for item in items:
        if item.find('ema') > -1:
            item_clean = item.replace('_ema_0', '')
            shutil.copyfile(
                os.path.join(ema_model_dir, item),
                os.path.join(cleaned_model_dir, item_clean))
        elif item.find('mean') > -1 or item.find('variance') > -1:
            shutil.copyfile(
                os.path.join(ema_model_dir, item),
                os.path.join(cleaned_model_dir, item))


if __name__ == '__main__':
    main()
```
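The cleaning rule in action, on hypothetical parameter file names (the names below are made up for illustration; only the renaming logic mirrors the tool):

```python
# EMA shadow copies ('<param>_ema_0') lose the suffix; batch-norm statistics
# ('mean'/'variance') pass through unchanged; raw parameters are not copied.
files = ['conv1_weights_ema_0', 'fc_0.w_0_ema_0',
         'bn_conv1_moving_mean', 'bn_conv1_moving_variance', 'conv1_weights']

for item in files:
    if item.find('ema') > -1:
        print(item, '->', item.replace('_ema_0', ''))
    elif item.find('mean') > -1 or item.find('variance') > -1:
        print(item, '->', item)
```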
tools/program.py

```diff
@@ -36,6 +36,8 @@ from ppcls.utils import logger
 from paddle.fluid.incubate.fleet.collective import fleet
 from paddle.fluid.incubate.fleet.collective import DistributedStrategy
 
+from ema import ExponentialMovingAverage
+
 
 def create_feeds(image_shape, use_mix=None):
     """
@@ -86,7 +88,7 @@ def create_dataloader(feeds):
     return dataloader
 
 
-def create_model(architecture, image, classes_num):
+def create_model(architecture, image, classes_num, is_train):
     """
     Create a model
@@ -101,6 +103,7 @@ def create_model(architecture, image, classes_num):
     """
     name = architecture["name"]
     params = architecture.get("params", {})
+    params['is_test'] = not is_train
     model = architectures.__dict__[name](**params)
     out = model.net(input=image, class_dim=classes_num)
     return out
@@ -336,7 +339,7 @@ def build(config, main_prog, startup_prog, is_train=True):
             feeds = create_feeds(config.image_shape, use_mix=use_mix)
             dataloader = create_dataloader(feeds.values())
             out = create_model(config.ARCHITECTURE, feeds['image'],
-                               config.classes_num)
+                               config.classes_num, is_train)
             fetchs = create_fetchs(
                 out,
                 feeds,
@@ -354,6 +357,14 @@ def build(config, main_prog, startup_prog, is_train=True):
                 optimizer = mixed_precision_optimizer(config, optimizer)
                 optimizer = dist_optimizer(config, optimizer)
                 optimizer.minimize(fetchs['loss'][0])
+                if config.get('use_ema'):
+                    global_steps = fluid.layers.learning_rate_scheduler._decay_step_counter(
+                    )
+                    ema = ExponentialMovingAverage(
+                        config.get('ema_decay'), thres_steps=global_steps)
+                    ema.update()
+                    return dataloader, fetchs, ema
 
     return dataloader, fetchs
@@ -387,7 +398,13 @@ def compile(config, program, loss_name=None):
 total_step = 0
 
 
-def run(dataloader, exe, program, fetchs, epoch=0, mode='train', vdl_writer=None):
+def run(dataloader,
+        exe,
+        program,
+        fetchs,
+        epoch=0,
+        mode='train',
+        vdl_writer=None):
     """
     Feed data to the model and fetch the measures and loss
@@ -401,6 +418,7 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train', vdl_writer=None
     Returns:
     """
+    print(fetchs)
     fetch_list = [f[0] for f in fetchs.values()]
     metric_list = [f[1] for f in fetchs.values()]
     for m in metric_list:
```
tools/train.py

```diff
@@ -70,8 +70,12 @@ def main(args):
         best_top1_acc = 0.0  # best top1 acc record
-        train_dataloader, train_fetchs = program.build(
-            config, train_prog, startup_prog, is_train=True)
+        if not config.get('use_ema'):
+            train_dataloader, train_fetchs = program.build(
+                config, train_prog, startup_prog, is_train=True)
+        else:
+            train_dataloader, train_fetchs, ema = program.build(
+                config, train_prog, startup_prog, is_train=True)
 
         if config.validate:
             valid_prog = fluid.Program()
@@ -81,11 +85,11 @@ def main(args):
             valid_prog = valid_prog.clone(for_test=True)
 
     # create the "Executor" with the statement of which place
-    exe = fluid.Executor(place=place)
-    # only run startup_prog once to init
+    exe = fluid.Executor(place)
+    # Parameter initialization
     exe.run(startup_prog)
 
-    # load model from checkpoint or pretrained model
+    # load model from 1. checkpoint to resume training, 2. pretrained model to finetune
     init_model(config, train_prog, exe)
 
     train_reader = Reader(config, 'train')()
@@ -106,6 +110,14 @@ def main(args):
         if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0:
             # 2. validate with validate dataset
             if config.validate and epoch_id % config.valid_interval == 0:
+                if config.get('use_ema'):
+                    logger.info(logger.coloring("EMA validate start..."))
+                    with ema.apply(exe):
+                        top1_acc = program.run(valid_dataloader, exe,
+                                               compiled_valid_prog,
+                                               valid_fetchs, epoch_id, 'valid')
+                    logger.info(logger.coloring("EMA validate over!"))
                 top1_acc = program.run(valid_dataloader, exe,
                                        compiled_valid_prog, valid_fetchs,
                                        epoch_id, 'valid')
```