Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
7a65af0c
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7a65af0c
编写于
11月 16, 2020
作者:
W
wangguanzhong
提交者:
GitHub
11月 16, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update save load (#1702)
上级
48e21f3c
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
160 addition
and
78 deletion
+160
-78
ppdet/utils/checkpoint.py
ppdet/utils/checkpoint.py
+108
-59
tools/eval.py
tools/eval.py
+16
-2
tools/infer.py
tools/infer.py
+2
-2
tools/train.py
tools/train.py
+34
-15
未找到文件。
ppdet/utils/checkpoint.py
浏览文件 @
7a65af0c
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
@@ -11,89 +25,124 @@ import numpy as np
import
paddle
import
paddle.fluid
as
fluid
from
.download
import
get_weights_path
import
logging
logger
=
logging
.
getLogger
(
__name__
)
def
get_ckpt_path
(
path
):
if
path
.
startswith
(
'http://'
)
or
path
.
startswith
(
'https://'
):
env
=
os
.
environ
if
'PADDLE_TRAINERS_NUM'
in
env
and
'PADDLE_TRAINER_ID'
in
env
:
trainer_id
=
int
(
env
[
'PADDLE_TRAINER_ID'
])
num_trainers
=
int
(
env
[
'PADDLE_TRAINERS_NUM'
])
if
num_trainers
<=
1
:
path
=
get_weights_path
(
path
)
else
:
from
ppdet.utils.download
import
map_path
,
WEIGHTS_HOME
weight_path
=
map_path
(
path
,
WEIGHTS_HOME
)
lock_path
=
weight_path
+
'.lock'
if
not
os
.
path
.
exists
(
weight_path
):
try
:
os
.
makedirs
(
os
.
path
.
dirname
(
weight_path
))
except
OSError
as
e
:
if
e
.
errno
!=
errno
.
EEXIST
:
raise
with
open
(
lock_path
,
'w'
):
# touch
os
.
utime
(
lock_path
,
None
)
if
trainer_id
==
0
:
get_weights_path
(
path
)
os
.
remove
(
lock_path
)
else
:
while
os
.
path
.
exists
(
lock_path
):
time
.
sleep
(
1
)
path
=
weight_path
else
:
def
is_url
(
path
):
"""
Whether path is URL.
Args:
path (string): URL string or not.
"""
return
path
.
startswith
(
'http://'
)
or
path
.
startswith
(
'https://'
)
def
get_weight_path
(
path
):
env
=
os
.
environ
if
'PADDLE_TRAINERS_NUM'
in
env
and
'PADDLE_TRAINER_ID'
in
env
:
trainer_id
=
int
(
env
[
'PADDLE_TRAINER_ID'
])
num_trainers
=
int
(
env
[
'PADDLE_TRAINERS_NUM'
])
if
num_trainers
<=
1
:
path
=
get_weights_path
(
path
)
else
:
from
ppdet.utils.download
import
map_path
,
WEIGHTS_HOME
weight_path
=
map_path
(
path
,
WEIGHTS_HOME
)
lock_path
=
weight_path
+
'.lock'
if
not
os
.
path
.
exists
(
weight_path
):
try
:
os
.
makedirs
(
os
.
path
.
dirname
(
weight_path
))
except
OSError
as
e
:
if
e
.
errno
!=
errno
.
EEXIST
:
raise
with
open
(
lock_path
,
'w'
):
# touch
os
.
utime
(
lock_path
,
None
)
if
trainer_id
==
0
:
get_weights_path
(
path
)
os
.
remove
(
lock_path
)
else
:
while
os
.
path
.
exists
(
lock_path
):
time
.
sleep
(
1
)
path
=
weight_path
else
:
path
=
get_weights_path
(
path
)
return
path
def
_strip_postfix
(
path
):
path
,
ext
=
os
.
path
.
splitext
(
path
)
assert
ext
in
[
''
,
'.pdparams'
,
'.pdopt'
,
'.pdmodel'
],
\
"Unknown postfix {} from weights"
.
format
(
ext
)
return
path
def
load_dygraph_ckpt
(
model
,
optimizer
=
None
,
pretrain_ckpt
=
None
,
ckpt
=
None
,
ckpt_type
=
None
,
exclude_params
=
[],
load_static_weights
=
False
):
def
load_weight
(
model
,
weight
,
optimizer
=
None
):
if
is_url
(
weight
):
weight
=
get_weight_path
(
weight
)
path
=
_strip_postfix
(
weight
)
pdparam_path
=
path
+
'.pdparams'
if
not
os
.
path
.
exists
(
pdparam_path
):
raise
ValueError
(
"Model pretrain path {} does not "
"exists."
.
format
(
pdparam_path
))
param_state_dict
=
paddle
.
load
(
pdparam_path
)
model
.
set_dict
(
param_state_dict
)
if
optimizer
is
not
None
and
os
.
path
.
exists
(
path
+
'.pdopt'
):
optim_state_dict
=
paddle
.
load
(
path
+
'.pdopt'
)
optimizer
.
set_state_dict
(
optim_state_dict
)
return
def
load_pretrain_weight
(
model
,
pretrain_weight
,
load_static_weights
=
False
,
weight_type
=
'pretrain'
):
assert
weight_type
in
[
'pretrain'
,
'finetune'
]
if
is_url
(
pretrain_weight
):
pretrain_weight
=
get_weight_path
(
pretrain_weight
)
path
=
_strip_postfix
(
pretrain_weight
)
if
not
(
os
.
path
.
isdir
(
path
)
or
os
.
path
.
isfile
(
path
)
or
os
.
path
.
exists
(
path
+
'.pdparams'
)):
raise
ValueError
(
"Model pretrain path {} does not "
"exists."
.
format
(
path
))
model_dict
=
model
.
state_dict
()
assert
ckpt_type
in
[
'pretrain'
,
'resume'
,
'finetune'
,
None
]
if
ckpt_type
==
'pretrain'
and
ckpt
is
None
:
ckpt
=
pretrain_ckpt
ckpt
=
get_ckpt_path
(
ckpt
)
assert
os
.
path
.
exists
(
ckpt
),
"Path {} does not exist."
.
format
(
ckpt
)
if
load_static_weights
:
pre_state_dict
=
fluid
.
load_program_state
(
ckpt
)
pre_state_dict
=
paddle
.
static
.
load_program_state
(
path
)
param_state_dict
=
{}
model_dict
=
model
.
state_dict
()
for
key
in
model_dict
.
keys
():
weight_name
=
model_dict
[
key
].
name
if
weight_name
in
pre_state_dict
.
keys
():
print
(
'Load weight: {}, shape: {}'
.
format
(
logger
.
info
(
'Load weight: {}, shape: {}'
.
format
(
weight_name
,
pre_state_dict
[
weight_name
].
shape
))
param_state_dict
[
key
]
=
pre_state_dict
[
weight_name
]
else
:
param_state_dict
[
key
]
=
model_dict
[
key
]
model
.
set_dict
(
param_state_dict
)
return
model
param_state_dict
,
optim_state_dict
=
fluid
.
load_dygraph
(
ckpt
)
return
if
len
(
exclude_params
)
!=
0
:
for
k
in
exclude_params
:
param_state_dict
.
pop
(
k
,
None
)
if
ckpt_type
==
'pretrain'
:
param_state_dict
=
paddle
.
load
(
path
+
'.pdparams'
)
if
weight_type
==
'pretrain'
:
model
.
backbone
.
set_dict
(
param_state_dict
)
else
:
ignore_set
=
set
()
for
name
,
weight
in
model_dict
:
if
name
in
param_state_dict
:
if
weight
.
shape
!=
param_state_dict
[
name
].
shape
:
param_state_dict
.
pop
(
name
,
None
)
model
.
set_dict
(
param_state_dict
)
if
ckpt_type
==
'resume'
:
assert
optim_state_dict
,
"Can't Resume Last Training's Optimizer State!!!"
optimizer
.
set_dict
(
optim_state_dict
)
return
model
return
def
save_
dygraph_ckpt
(
model
,
optimizer
,
save_dir
,
save_name
):
def
save_
model
(
model
,
optimizer
,
save_dir
,
save_name
):
if
not
os
.
path
.
exists
(
save_dir
):
os
.
makedirs
(
save_dir
)
save_path
=
os
.
path
.
join
(
save_dir
,
save_name
)
fluid
.
dygraph
.
save_dygraph
(
model
.
state_dict
(),
save_path
)
fluid
.
dygraph
.
save_dygraph
(
optimizer
.
state_dict
(),
save_path
)
print
(
"Save checkpoint:"
,
save_dir
)
paddle
.
save
(
model
.
state_dict
(),
save_path
+
".pdparams"
)
paddle
.
save
(
optimizer
.
state_dict
(),
save_path
+
".pdopt"
)
logger
.
info
(
"Save checkpoint: {}"
.
format
(
save_dir
)
)
tools/eval.py
浏览文件 @
7a65af0c
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
@@ -19,7 +33,7 @@ from ppdet.core.workspace import load_config, merge_config, create
from
ppdet.utils.check
import
check_gpu
,
check_version
,
check_config
from
ppdet.utils.cli
import
ArgsParser
from
ppdet.utils.eval_utils
import
get_infer_results
,
eval_results
from
ppdet.utils.checkpoint
import
load_
dygraph_ckpt
,
save_dygraph_ckp
t
from
ppdet.utils.checkpoint
import
load_
weigh
t
import
logging
FORMAT
=
'%(asctime)s-%(levelname)s: %(message)s'
logging
.
basicConfig
(
level
=
logging
.
INFO
,
format
=
FORMAT
)
...
...
@@ -51,7 +65,7 @@ def run(FLAGS, cfg, place):
model
=
create
(
cfg
.
architecture
)
# Init Model
model
=
load_dygraph_ckpt
(
model
,
ckpt
=
cfg
.
weights
)
load_weight
(
model
,
cfg
.
weights
)
# Data Reader
dataset
=
cfg
.
EvalDataset
...
...
tools/infer.py
浏览文件 @
7a65af0c
...
...
@@ -34,7 +34,7 @@ from ppdet.utils.check import check_gpu, check_version, check_config
from
ppdet.utils.visualizer
import
visualize_results
from
ppdet.utils.cli
import
ArgsParser
from
ppdet.data.reader
import
create_reader
from
ppdet.utils.checkpoint
import
load_
dygraph_ckp
t
from
ppdet.utils.checkpoint
import
load_
weigh
t
from
ppdet.utils.eval_utils
import
get_infer_results
import
logging
FORMAT
=
'%(asctime)s-%(levelname)s: %(message)s'
...
...
@@ -141,7 +141,7 @@ def run(FLAGS, cfg):
use_default_label
)
# Init Model
model
=
load_dygraph_ckpt
(
model
,
ckpt
=
cfg
.
weights
)
load_weight
(
model
,
cfg
.
weights
)
# Data Reader
test_reader
=
create_reader
(
cfg
.
TestDataset
,
cfg
.
TestReader
)
...
...
tools/train.py
浏览文件 @
7a65af0c
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
@@ -21,7 +35,7 @@ from ppdet.core.workspace import load_config, merge_config, create
from
ppdet.utils.stats
import
TrainingStats
from
ppdet.utils.check
import
check_gpu
,
check_version
,
check_config
from
ppdet.utils.cli
import
ArgsParser
from
ppdet.utils.checkpoint
import
load_
dygraph_ckpt
,
save_dygraph_ckpt
from
ppdet.utils.checkpoint
import
load_
weight
,
load_pretrain_weight
,
save_model
from
paddle.distributed
import
ParallelEnv
import
logging
FORMAT
=
'%(asctime)s-%(levelname)s: %(message)s'
...
...
@@ -32,7 +46,7 @@ logger = logging.getLogger(__name__)
def
parse_args
():
parser
=
ArgsParser
()
parser
.
add_argument
(
"-
ckp
t_type"
,
"-
-weigh
t_type"
,
default
=
'pretrain'
,
type
=
str
,
help
=
"Loading Checkpoints only support 'pretrain', 'finetune', 'resume'."
...
...
@@ -116,12 +130,12 @@ def run(FLAGS, cfg, place):
optimizer
=
create
(
'OptimizerBuilder'
)(
lr
,
model
.
parameters
())
# Init Model & Optimzer
model
=
load_dygraph_ckpt
(
model
,
optimizer
,
cfg
.
pretrain_weights
,
ckpt_type
=
FLAGS
.
ckpt_type
,
load_static_weights
=
cfg
.
get
(
'load_static_weights'
,
False
)
)
if
FLAGS
.
weight_type
==
'resume'
:
load_weight
(
model
,
cfg
.
pretrain_weights
,
optimizer
)
else
:
load_pretrain_weight
(
model
,
cfg
.
pretrain_weights
,
cfg
.
get
(
'load_static_weights'
,
False
)
,
FLAGS
.
weight_type
)
# Parallel Model
if
ParallelEnv
().
nranks
>
1
:
...
...
@@ -132,13 +146,17 @@ def run(FLAGS, cfg, place):
time_stat
=
deque
(
maxlen
=
cfg
.
log_iter
)
start_time
=
time
.
time
()
end_time
=
time
.
time
()
# Run Train
start_epoch
=
optimizer
.
state_dict
()[
'LR_Scheduler'
][
'last_epoch'
]
for
e_id
in
range
(
int
(
cfg
.
epoch
)):
cur_eid
=
e_id
+
start_epoch
for
iter_id
,
data
in
enumerate
(
train_loader
):
start_time
=
end_time
end_time
=
time
.
time
()
time_stat
.
append
(
end_time
-
start_time
)
time_cost
=
np
.
mean
(
time_stat
)
eta_sec
=
(
cfg
.
epoch
*
step_per_epoch
-
iter_id
)
*
time_cost
eta_sec
=
(
(
cfg
.
epoch
-
cur_eid
)
*
step_per_epoch
-
iter_id
)
*
time_cost
eta
=
str
(
datetime
.
timedelta
(
seconds
=
int
(
eta_sec
)))
# Model Forward
...
...
@@ -162,22 +180,23 @@ def run(FLAGS, cfg, place):
if
ParallelEnv
().
nranks
<
2
or
ParallelEnv
().
local_rank
==
0
:
# Log state
if
iter_id
==
0
:
if
e_id
==
0
and
iter_id
==
0
:
train_stats
=
TrainingStats
(
cfg
.
log_iter
,
outputs
.
keys
())
train_stats
.
update
(
outputs
)
logs
=
train_stats
.
log
()
if
iter_id
%
cfg
.
log_iter
==
0
:
strs
=
'Epoch:{}: iter: {}, lr: {:.6f}, {}, time: {:.3f}, eta: {}'
.
format
(
e_id
,
iter_id
,
curr_lr
,
logs
,
time_cost
,
eta
)
ips
=
float
(
cfg
[
'TrainReader'
][
'batch_size'
])
/
time_cost
strs
=
'Epoch:{}: iter: {}, lr: {:.6f}, {}, eta: {}, batch_cost: {:.5f} sec, ips: {:.5f} images/sec'
.
format
(
cur_eid
,
iter_id
,
curr_lr
,
logs
,
eta
,
time_cost
,
ips
)
logger
.
info
(
strs
)
# Save Stage
if
ParallelEnv
().
local_rank
==
0
and
e_
id
%
cfg
.
snapshot_epoch
==
0
:
if
ParallelEnv
().
local_rank
==
0
and
cur_e
id
%
cfg
.
snapshot_epoch
==
0
:
cfg_name
=
os
.
path
.
basename
(
FLAGS
.
config
).
split
(
'.'
)[
0
]
save_name
=
str
(
e_id
+
1
)
if
e_
id
+
1
!=
int
(
save_name
=
str
(
cur_eid
)
if
cur_e
id
+
1
!=
int
(
cfg
.
epoch
)
else
"model_final"
save_dir
=
os
.
path
.
join
(
cfg
.
save_dir
,
cfg_name
)
save_
dygraph_ckpt
(
model
,
optimizer
,
save_dir
,
save_name
)
save_
model
(
model
,
optimizer
,
save_dir
,
save_name
)
def
main
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录