Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
6a5f4626
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 1 年 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6a5f4626
编写于
11月 18, 2020
作者:
L
littletomatodonkey
提交者:
GitHub
11月 18, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add static running in dygraph (#399)
* add static running in dygraph
上级
834cc164
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
651 addition
and
10 deletion
+651
-10
ppcls/data/imaug/batch_operators.py
ppcls/data/imaug/batch_operators.py
+4
-2
ppcls/modeling/architectures/efficientnet.py
ppcls/modeling/architectures/efficientnet.py
+8
-7
ppcls/optimizer/optimizer.py
ppcls/optimizer/optimizer.py
+1
-1
tools/static/program.py
tools/static/program.py
+496
-0
tools/static/train.py
tools/static/train.py
+142
-0
未找到文件。
ppcls/data/imaug/batch_operators.py
浏览文件 @
6a5f4626
...
...
@@ -53,8 +53,9 @@ class MixupOperator(BatchOperator):
imgs
,
labels
,
bs
=
self
.
_unpack
(
batch
)
idx
=
np
.
random
.
permutation
(
bs
)
lam
=
np
.
random
.
beta
(
self
.
_alpha
,
self
.
_alpha
)
lams
=
np
.
array
([
lam
]
*
bs
,
dtype
=
np
.
float32
)
imgs
=
lam
*
imgs
+
(
1
-
lam
)
*
imgs
[
idx
]
return
list
(
zip
(
imgs
,
labels
,
labels
[
idx
],
[
lam
]
*
b
s
))
return
list
(
zip
(
imgs
,
labels
,
labels
[
idx
],
lam
s
))
class
CutmixOperator
(
BatchOperator
):
...
...
@@ -93,7 +94,8 @@ class CutmixOperator(BatchOperator):
imgs
[:,
:,
bbx1
:
bbx2
,
bby1
:
bby2
]
=
imgs
[
idx
,
:,
bbx1
:
bbx2
,
bby1
:
bby2
]
lam
=
1
-
(
float
(
bbx2
-
bbx1
)
*
(
bby2
-
bby1
)
/
(
imgs
.
shape
[
-
2
]
*
imgs
.
shape
[
-
1
]))
return
list
(
zip
(
imgs
,
labels
,
labels
[
idx
],
[
lam
]
*
bs
))
lams
=
np
.
array
([
lam
]
*
bs
,
dtype
=
np
.
float32
)
return
list
(
zip
(
imgs
,
labels
,
labels
[
idx
],
lams
))
class
FmixOperator
(
BatchOperator
):
...
...
ppcls/modeling/architectures/efficientnet.py
浏览文件 @
6a5f4626
...
...
@@ -242,13 +242,14 @@ inp_shape = {
def
_drop_connect
(
inputs
,
prob
,
is_test
):
if
is_test
:
return
inputs
keep_prob
=
1.0
-
prob
inputs_shape
=
paddle
.
shape
(
inputs
)
random_tensor
=
keep_prob
+
paddle
.
rand
(
shape
=
[
inputs_shape
[
0
],
1
,
1
,
1
])
binary_tensor
=
paddle
.
floor
(
random_tensor
)
output
=
paddle
.
multiply
(
inputs
,
binary_tensor
)
/
keep_prob
output
=
inputs
else
:
keep_prob
=
1.0
-
prob
inputs_shape
=
paddle
.
shape
(
inputs
)
random_tensor
=
keep_prob
+
paddle
.
rand
(
shape
=
[
inputs_shape
[
0
],
1
,
1
,
1
])
binary_tensor
=
paddle
.
floor
(
random_tensor
)
output
=
paddle
.
multiply
(
inputs
,
binary_tensor
)
/
keep_prob
return
output
...
...
ppcls/optimizer/optimizer.py
浏览文件 @
6a5f4626
...
...
@@ -154,7 +154,7 @@ class OptimizerBuilder(object):
reg
=
getattr
(
mod
,
reg_func
)(
**
regularizer
)()
self
.
params
[
'regularization'
]
=
reg
def
__call__
(
self
,
learning_rate
,
parameter_list
):
def
__call__
(
self
,
learning_rate
,
parameter_list
=
None
):
mod
=
sys
.
modules
[
__name__
]
opt
=
getattr
(
mod
,
self
.
function
)
return
opt
(
learning_rate
=
learning_rate
,
...
...
tools/static/program.py
0 → 100644
浏览文件 @
6a5f4626
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
time
import
numpy
as
np
from
collections
import
OrderedDict
import
paddle
import
paddle.nn.functional
as
F
from
ppcls.optimizer.learning_rate
import
LearningRateBuilder
from
ppcls.optimizer.optimizer
import
OptimizerBuilder
from
ppcls.modeling
import
architectures
from
ppcls.modeling.loss
import
CELoss
from
ppcls.modeling.loss
import
MixCELoss
from
ppcls.modeling.loss
import
JSDivLoss
from
ppcls.modeling.loss
import
GoogLeNetLoss
from
ppcls.utils.misc
import
AverageMeter
from
ppcls.utils
import
logger
from
paddle.distributed
import
fleet
from
paddle.distributed.fleet
import
DistributedStrategy
def
_mkdir_if_not_exist
(
path
):
"""
mkdir if not exists, ignore the exception when multiprocess mkdir together
"""
if
not
os
.
path
.
exists
(
path
):
try
:
os
.
makedirs
(
path
)
except
OSError
as
e
:
if
e
.
errno
==
errno
.
EEXIST
and
os
.
path
.
isdir
(
path
):
logger
.
warning
(
'be happy if some process has already created {}'
.
format
(
path
))
else
:
raise
OSError
(
'Failed to mkdir {}'
.
format
(
path
))
def
save_model
(
program
,
model_path
,
epoch_id
,
prefix
=
'ppcls'
):
"""
save model to the target path
"""
model_path
=
os
.
path
.
join
(
model_path
,
str
(
epoch_id
))
_mkdir_if_not_exist
(
model_path
)
model_prefix
=
os
.
path
.
join
(
model_path
,
prefix
)
paddle
.
static
.
save
(
program
,
model_prefix
)
logger
.
info
(
"Already save model in {}"
.
format
(
model_path
))
def
create_feeds
(
image_shape
,
use_mix
=
None
):
"""
Create feeds as model input
Args:
image_shape(list[int]): model input shape, such as [3, 224, 224]
use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
Returns:
feeds(dict): dict of model input variables
"""
feeds
=
OrderedDict
()
feeds
[
'image'
]
=
paddle
.
static
.
data
(
name
=
"feed_image"
,
shape
=
[
None
]
+
image_shape
,
dtype
=
"float32"
)
if
use_mix
:
feeds
[
'feed_y_a'
]
=
paddle
.
static
.
data
(
name
=
"feed_y_a"
,
shape
=
[
None
,
1
],
dtype
=
"int64"
)
feeds
[
'feed_y_b'
]
=
paddle
.
static
.
data
(
name
=
"feed_y_b"
,
shape
=
[
None
,
1
],
dtype
=
"int64"
)
feeds
[
'feed_lam'
]
=
paddle
.
static
.
data
(
name
=
"feed_lam"
,
shape
=
[
None
,
1
],
dtype
=
"float32"
)
else
:
feeds
[
'label'
]
=
paddle
.
static
.
data
(
name
=
"feed_label"
,
shape
=
[
None
,
1
],
dtype
=
"int64"
)
return
feeds
def
create_model
(
architecture
,
image
,
classes_num
,
is_train
):
"""
Create a model
Args:
architecture(dict): architecture information,
name(such as ResNet50) is needed
image(variable): model input variable
classes_num(int): num of classes
Returns:
out(variable): model output variable
"""
name
=
architecture
[
"name"
]
params
=
architecture
.
get
(
"params"
,
{})
if
"is_test"
in
params
:
params
[
'is_test'
]
=
not
is_train
model
=
architectures
.
__dict__
[
name
](
class_dim
=
classes_num
,
**
params
)
out
=
model
(
image
)
return
out
def
create_loss
(
out
,
feeds
,
architecture
,
classes_num
=
1000
,
epsilon
=
None
,
use_mix
=
False
,
use_distillation
=
False
):
"""
Create a loss for optimization, such as:
1. CrossEnotry loss
2. CrossEnotry loss with label smoothing
3. CrossEnotry loss with mix(mixup, cutmix, fmix)
4. CrossEnotry loss with label smoothing and (mixup, cutmix, fmix)
5. GoogLeNet loss
Args:
out(variable): model output variable
feeds(dict): dict of model input variables
architecture(dict): architecture information,
name(such as ResNet50) is needed
classes_num(int): num of classes
epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
Returns:
loss(variable): loss variable
"""
if
use_mix
:
feed_y_a
=
paddle
.
reshape
(
feeds
[
'feed_y_a'
],
[
-
1
,
1
])
feed_y_b
=
paddle
.
reshape
(
feeds
[
'feed_y_b'
],
[
-
1
,
1
])
feed_lam
=
paddle
.
reshape
(
feeds
[
'feed_lam'
],
[
-
1
,
1
])
else
:
target
=
paddle
.
reshape
(
feeds
[
'label'
],
[
-
1
,
1
])
if
architecture
[
"name"
]
==
"GoogLeNet"
:
assert
len
(
out
)
==
3
,
"GoogLeNet should have 3 outputs"
loss
=
GoogLeNetLoss
(
class_dim
=
classes_num
,
epsilon
=
epsilon
)
return
loss
(
out
[
0
],
out
[
1
],
out
[
2
],
target
)
if
use_distillation
:
assert
len
(
out
)
==
2
,
(
"distillation output length must be 2, "
"but got {}"
.
format
(
len
(
out
)))
loss
=
JSDivLoss
(
class_dim
=
classes_num
,
epsilon
=
epsilon
)
return
loss
(
out
[
1
],
out
[
0
])
if
use_mix
:
loss
=
MixCELoss
(
class_dim
=
classes_num
,
epsilon
=
epsilon
)
return
loss
(
out
,
feed_y_a
,
feed_y_b
,
feed_lam
)
else
:
loss
=
CELoss
(
class_dim
=
classes_num
,
epsilon
=
epsilon
)
return
loss
(
out
,
target
)
def
create_metric
(
out
,
feeds
,
architecture
,
topk
=
5
,
classes_num
=
1000
,
use_distillation
=
False
):
"""
Create measures of model accuracy, such as top1 and top5
Args:
out(variable): model output variable
feeds(dict): dict of model input variables(included label)
topk(int): usually top5
classes_num(int): num of classes
Returns:
fetchs(dict): dict of measures
"""
label
=
paddle
.
reshape
(
feeds
[
'label'
],
[
-
1
,
1
])
if
architecture
[
"name"
]
==
"GoogLeNet"
:
assert
len
(
out
)
==
3
,
"GoogLeNet should have 3 outputs"
out
=
out
[
0
]
else
:
# just need student label to get metrics
if
use_distillation
:
out
=
out
[
1
]
softmax_out
=
F
.
softmax
(
out
)
fetchs
=
OrderedDict
()
# set top1 to fetchs
top1
=
paddle
.
metric
.
accuracy
(
softmax_out
,
label
=
label
,
k
=
1
)
fetchs
[
'top1'
]
=
(
top1
,
AverageMeter
(
'top1'
,
'.4f'
,
need_avg
=
True
))
# set topk to fetchs
k
=
min
(
topk
,
classes_num
)
topk
=
paddle
.
metric
.
accuracy
(
softmax_out
,
label
=
label
,
k
=
k
)
topk_name
=
'top{}'
.
format
(
k
)
fetchs
[
topk_name
]
=
(
topk
,
AverageMeter
(
topk_name
,
'.4f'
,
need_avg
=
True
))
return
fetchs
def
create_fetchs
(
out
,
feeds
,
architecture
,
topk
=
5
,
classes_num
=
1000
,
epsilon
=
None
,
use_mix
=
False
,
use_distillation
=
False
):
"""
Create fetchs as model outputs(included loss and measures),
will call create_loss and create_metric(if use_mix).
Args:
out(variable): model output variable
feeds(dict): dict of model input variables.
If use mix_up, it will not include label.
architecture(dict): architecture information,
name(such as ResNet50) is needed
topk(int): usually top5
classes_num(int): num of classes
epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
Returns:
fetchs(dict): dict of model outputs(included loss and measures)
"""
fetchs
=
OrderedDict
()
loss
=
create_loss
(
out
,
feeds
,
architecture
,
classes_num
,
epsilon
,
use_mix
,
use_distillation
)
fetchs
[
'loss'
]
=
(
loss
,
AverageMeter
(
'loss'
,
'7.4f'
,
need_avg
=
True
))
if
not
use_mix
:
metric
=
create_metric
(
out
,
feeds
,
architecture
,
topk
,
classes_num
,
use_distillation
)
fetchs
.
update
(
metric
)
return
fetchs
def
create_optimizer
(
config
):
"""
Create an optimizer using config, usually including
learning rate and regularization.
Args:
config(dict): such as
{
'LEARNING_RATE':
{'function': 'Cosine',
'params': {'lr': 0.1}
},
'OPTIMIZER':
{'function': 'Momentum',
'params':{'momentum': 0.9},
'regularizer':
{'function': 'L2', 'factor': 0.0001}
}
}
Returns:
an optimizer instance
"""
# create learning_rate instance
lr_config
=
config
[
'LEARNING_RATE'
]
lr_config
[
'params'
].
update
({
'epochs'
:
config
[
'epochs'
],
'step_each_epoch'
:
config
[
'total_images'
]
//
config
[
'TRAIN'
][
'batch_size'
],
})
lr
=
LearningRateBuilder
(
**
lr_config
)()
# create optimizer instance
opt_config
=
config
[
'OPTIMIZER'
]
opt
=
OptimizerBuilder
(
**
opt_config
)
return
opt
(
lr
),
lr
def
dist_optimizer
(
config
,
optimizer
):
"""
Create a distributed optimizer based on a normal optimizer
Args:
config(dict):
optimizer(): a normal optimizer
Returns:
optimizer: a distributed optimizer
"""
exec_strategy
=
paddle
.
static
.
ExecutionStrategy
()
exec_strategy
.
num_threads
=
3
exec_strategy
.
num_iteration_per_drop_scope
=
10
dist_strategy
=
DistributedStrategy
()
dist_strategy
.
nccl_comm_num
=
1
dist_strategy
.
fuse_all_reduce_ops
=
True
dist_strategy
.
execution_strategy
=
exec_strategy
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
,
strategy
=
dist_strategy
)
return
optimizer
def
mixed_precision_optimizer
(
config
,
optimizer
):
use_fp16
=
config
.
get
(
'use_fp16'
,
False
)
amp_scale_loss
=
config
.
get
(
'amp_scale_loss'
,
1.0
)
use_dynamic_loss_scaling
=
config
.
get
(
'use_dynamic_loss_scaling'
,
False
)
if
use_fp16
:
optimizer
=
fluid
.
contrib
.
mixed_precision
.
decorate
(
optimizer
,
init_loss_scaling
=
amp_scale_loss
,
use_dynamic_loss_scaling
=
use_dynamic_loss_scaling
)
return
optimizer
def
build
(
config
,
main_prog
,
startup_prog
,
is_train
=
True
,
is_distributed
=
True
):
"""
Build a program using a model and an optimizer
1. create feeds
2. create a dataloader
3. create a model
4. create fetchs
5. create an optimizer
Args:
config(dict): config
main_prog(): main program
startup_prog(): startup program
is_train(bool): train or valid
is_distributed(bool): whether to use distributed training method
Returns:
dataloader(): a bridge between the model and the data
fetchs(dict): dict of model outputs(included loss and measures)
"""
with
paddle
.
static
.
program_guard
(
main_prog
,
startup_prog
):
with
paddle
.
utils
.
unique_name
.
guard
():
use_mix
=
config
.
get
(
'use_mix'
)
and
is_train
use_distillation
=
config
.
get
(
'use_distillation'
)
feeds
=
create_feeds
(
config
.
image_shape
,
use_mix
=
use_mix
)
out
=
create_model
(
config
.
ARCHITECTURE
,
feeds
[
'image'
],
config
.
classes_num
,
is_train
)
fetchs
=
create_fetchs
(
out
,
feeds
,
config
.
ARCHITECTURE
,
config
.
topk
,
config
.
classes_num
,
epsilon
=
config
.
get
(
'ls_epsilon'
),
use_mix
=
use_mix
,
use_distillation
=
use_distillation
)
lr_scheduler
=
None
if
is_train
:
optimizer
,
lr_scheduler
=
create_optimizer
(
config
)
optimizer
=
mixed_precision_optimizer
(
config
,
optimizer
)
if
is_distributed
:
optimizer
=
dist_optimizer
(
config
,
optimizer
)
optimizer
.
minimize
(
fetchs
[
'loss'
][
0
])
return
fetchs
,
lr_scheduler
,
feeds
def
compile
(
config
,
program
,
loss_name
=
None
,
share_prog
=
None
):
"""
Compile the program
Args:
config(dict): config
program(): the program which is wrapped by
loss_name(str): loss name
share_prog(): the shared program, used for evaluation during training
Returns:
compiled_program(): a compiled program
"""
build_strategy
=
paddle
.
static
.
BuildStrategy
()
exec_strategy
=
paddle
.
static
.
ExecutionStrategy
()
exec_strategy
.
num_threads
=
1
exec_strategy
.
num_iteration_per_drop_scope
=
10
compiled_program
=
paddle
.
static
.
CompiledProgram
(
program
).
with_data_parallel
(
share_vars_from
=
share_prog
,
loss_name
=
loss_name
,
build_strategy
=
build_strategy
,
exec_strategy
=
exec_strategy
)
return
compiled_program
total_step
=
0
def
run
(
dataloader
,
exe
,
program
,
feeds
,
fetchs
,
epoch
=
0
,
mode
=
'train'
,
config
=
None
,
vdl_writer
=
None
,
lr_scheduler
=
None
):
"""
Feed data to the model and fetch the measures and loss
Args:
dataloader(paddle io dataloader):
exe():
program():
fetchs(dict): dict of measures and the loss
epoch(int): epoch of training or validation
model(str): log only
Returns:
"""
fetch_list
=
[
f
[
0
]
for
f
in
fetchs
.
values
()]
metric_list
=
[
f
[
1
]
for
f
in
fetchs
.
values
()]
if
mode
==
"train"
:
metric_list
.
append
(
AverageMeter
(
'lr'
,
'f'
,
need_avg
=
False
))
for
m
in
metric_list
:
m
.
reset
()
batch_time
=
AverageMeter
(
'elapse'
,
'.3f'
)
tic
=
time
.
time
()
for
idx
,
batch
in
enumerate
(
dataloader
()):
batch_size
=
batch
[
0
].
shape
()[
0
]
feed_dict
=
{
key
.
name
:
batch
[
idx
]
for
idx
,
key
in
enumerate
(
feeds
.
values
())}
metrics
=
exe
.
run
(
program
=
program
,
feed
=
feed_dict
,
fetch_list
=
fetch_list
)
batch_time
.
update
(
time
.
time
()
-
tic
)
tic
=
time
.
time
()
for
i
,
m
in
enumerate
(
metrics
):
metric_list
[
i
].
update
(
np
.
mean
(
m
),
batch_size
)
if
mode
==
"train"
:
metric_list
[
-
1
].
update
(
lr_scheduler
.
get_lr
())
fetchs_str
=
''
.
join
([
str
(
m
.
value
)
+
' '
for
m
in
metric_list
]
+
[
batch_time
.
value
])
+
's'
if
lr_scheduler
is
not
None
:
if
lr_scheduler
.
update_specified
:
curr_global_counter
=
lr_scheduler
.
step_each_epoch
*
epoch
+
idx
update
=
max
(
0
,
curr_global_counter
-
lr_scheduler
.
update_start_step
)
%
lr_scheduler
.
update_step_interval
==
0
if
update
:
lr_scheduler
.
step
()
else
:
lr_scheduler
.
step
()
if
vdl_writer
:
global
total_step
logger
.
scaler
(
'loss'
,
metrics
[
0
][
0
],
total_step
,
vdl_writer
)
total_step
+=
1
if
mode
==
'eval'
:
if
idx
%
config
.
get
(
'print_interval'
,
10
)
==
0
:
logger
.
info
(
"{:s} step:{:<4d} {:s}"
.
format
(
mode
,
idx
,
fetchs_str
))
else
:
epoch_str
=
"epoch:{:<3d}"
.
format
(
epoch
)
step_str
=
"{:s} step:{:<4d}"
.
format
(
mode
,
idx
)
if
idx
%
config
.
get
(
'print_interval'
,
10
)
==
0
:
logger
.
info
(
"{:s} {:s} {:s}"
.
format
(
logger
.
coloring
(
epoch_str
,
"HEADER"
)
if
idx
==
0
else
epoch_str
,
logger
.
coloring
(
step_str
,
"PURPLE"
),
logger
.
coloring
(
fetchs_str
,
'OKGREEN'
)))
end_str
=
''
.
join
([
str
(
m
.
mean
)
+
' '
for
m
in
metric_list
]
+
[
batch_time
.
total
])
+
's'
if
mode
==
'eval'
:
logger
.
info
(
"END {:s} {:s}s"
.
format
(
mode
,
end_str
))
else
:
end_epoch_str
=
"END epoch:{:<3d}"
.
format
(
epoch
)
logger
.
info
(
"{:s} {:s} {:s}"
.
format
(
logger
.
coloring
(
end_epoch_str
,
"RED"
),
logger
.
coloring
(
mode
,
"PURPLE"
),
logger
.
coloring
(
end_str
,
"OKGREEN"
)))
# return top1_acc in order to save the best model
if
mode
==
'valid'
:
return
fetchs
[
"top1"
][
1
].
avg
tools/static/train.py
0 → 100644
浏览文件 @
6a5f4626
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
argparse
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../../'
)))
from
sys
import
version_info
import
paddle
from
paddle.distributed
import
ParallelEnv
from
paddle.distributed
import
fleet
from
ppcls.data
import
Reader
from
ppcls.utils.config
import
get_config
from
ppcls.utils
import
logger
from
tools.static
import
program
from
program
import
save_model
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
"PaddleClas train script"
)
parser
.
add_argument
(
'-c'
,
'--config'
,
type
=
str
,
default
=
'configs/ResNet/ResNet50.yaml'
,
help
=
'config file path'
)
parser
.
add_argument
(
'--vdl_dir'
,
type
=
str
,
default
=
None
,
help
=
'VisualDL logging directory for image.'
)
parser
.
add_argument
(
'-o'
,
'--override'
,
action
=
'append'
,
default
=
[],
help
=
'config options to be overridden'
)
args
=
parser
.
parse_args
()
return
args
def
main
(
args
):
fleet
.
init
(
is_collective
=
True
)
config
=
get_config
(
args
.
config
,
overrides
=
args
.
override
,
show
=
True
)
# assign the place
use_gpu
=
config
.
get
(
"use_gpu"
,
True
)
assert
use_gpu
is
True
,
"gpu must be true in static mode!"
place
=
'gpu:{}'
.
format
(
ParallelEnv
().
dev_id
)
place
=
paddle
.
set_device
(
place
)
# startup_prog is used to do some parameter init work,
# and train prog is used to hold the network
startup_prog
=
paddle
.
static
.
Program
()
train_prog
=
paddle
.
static
.
Program
()
best_top1_acc
=
0.0
# best top1 acc record
train_fetchs
,
lr_scheduler
,
train_feeds
=
program
.
build
(
config
,
train_prog
,
startup_prog
,
is_train
=
True
)
if
config
.
validate
:
valid_prog
=
paddle
.
static
.
Program
()
valid_fetchs
,
_
,
valid_feeds
=
program
.
build
(
config
,
valid_prog
,
startup_prog
,
is_train
=
False
)
# clone to prune some content which is irrelevant in valid_prog
valid_prog
=
valid_prog
.
clone
(
for_test
=
True
)
# create the "Executor" with the statement of which place
exe
=
paddle
.
static
.
Executor
(
place
)
# Parameter initialization
exe
.
run
(
startup_prog
)
# load model from 1. checkpoint to resume training, 2. pretrained model to finetune
train_dataloader
=
Reader
(
config
,
'train'
,
places
=
place
)()
if
config
.
validate
and
ParallelEnv
().
local_rank
==
0
:
valid_dataloader
=
Reader
(
config
,
'valid'
,
places
=
place
)()
compiled_valid_prog
=
program
.
compile
(
config
,
valid_prog
)
vdl_writer
=
None
if
args
.
vdl_dir
:
if
version_info
.
major
==
2
:
logger
.
info
(
"visualdl is just supported for python3, so it is disabled in python2..."
)
else
:
from
visualdl
import
LogWriter
vdl_writer
=
LogWriter
(
args
.
vdl_dir
)
for
epoch_id
in
range
(
config
.
epochs
):
# 1. train with train dataset
program
.
run
(
train_dataloader
,
exe
,
train_prog
,
train_feeds
,
train_fetchs
,
epoch_id
,
'train'
,
config
,
vdl_writer
,
lr_scheduler
)
if
int
(
os
.
getenv
(
"PADDLE_TRAINER_ID"
,
0
))
==
0
:
# 2. validate with validate dataset
if
config
.
validate
and
epoch_id
%
config
.
valid_interval
==
0
:
top1_acc
=
program
.
run
(
valid_dataloader
,
exe
,
compiled_valid_prog
,
valid_feeds
,
valid_fetchs
,
epoch_id
,
'valid'
,
config
)
if
top1_acc
>
best_top1_acc
:
best_top1_acc
=
top1_acc
message
=
"The best top1 acc {:.5f}, in epoch: {:d}"
.
format
(
best_top1_acc
,
epoch_id
)
logger
.
info
(
"{:s}"
.
format
(
logger
.
coloring
(
message
,
"RED"
)))
if
epoch_id
%
config
.
save_interval
==
0
:
model_path
=
os
.
path
.
join
(
config
.
model_save_dir
,
config
.
ARCHITECTURE
[
"name"
])
save_model
(
train_prog
,
model_path
,
"best_model"
)
# 3. save the persistable model
if
epoch_id
%
config
.
save_interval
==
0
:
model_path
=
os
.
path
.
join
(
config
.
model_save_dir
,
config
.
ARCHITECTURE
[
"name"
])
save_model
(
train_prog
,
model_path
,
epoch_id
)
if
__name__
==
'__main__'
:
paddle
.
enable_static
()
args
=
parse_args
()
main
(
args
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录