Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSeg
提交
4bc70d81
P
PaddleSeg
项目概览
PaddlePaddle
/
PaddleSeg
通知
286
Star
8
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSeg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4bc70d81
编写于
6月 18, 2020
作者:
W
wuzewu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Save infer model when saving checkpoint
上级
c63f0722
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
40 addition
and
8 deletion
+40
-8
contrib/HumanSeg/models/humanseg.py
contrib/HumanSeg/models/humanseg.py
+11
-0
contrib/RemoteSensing/__init__.py
contrib/RemoteSensing/__init__.py
+0
-2
contrib/RemoteSensing/models/base.py
contrib/RemoteSensing/models/base.py
+11
-0
contrib/RemoteSensing/utils/logging.py
contrib/RemoteSensing/utils/logging.py
+3
-5
pdseg/train.py
pdseg/train.py
+15
-1
未找到文件。
contrib/HumanSeg/models/humanseg.py
浏览文件 @
4bc70d81
...
...
@@ -27,6 +27,7 @@ import cv2
import
yaml
import
shutil
import
paddleslim
as
slim
import
paddle
import
utils
import
utils.logging
as
logging
...
...
@@ -37,6 +38,15 @@ from nets import DeepLabv3p, ShuffleSeg, HRNet
import
transforms
as
T
def
save_infer_program
(
test_program
,
ckpt_dir
):
_test_program
=
test_program
.
clone
()
_test_program
.
desc
.
flush
()
_test_program
.
desc
.
_set_version
()
paddle
.
fluid
.
core
.
save_op_compatible_info
(
_test_program
.
desc
)
with
open
(
os
.
path
.
join
(
ckpt_dir
,
'model'
)
+
".pdmodel"
,
"wb"
)
as
f
:
f
.
write
(
_test_program
.
desc
.
serialize_to_string
())
def
dict2str
(
dict_input
):
out
=
''
for
k
,
v
in
dict_input
.
items
():
...
...
@@ -244,6 +254,7 @@ class SegModel(object):
if
self
.
status
==
'Normal'
:
fluid
.
save
(
self
.
train_prog
,
osp
.
join
(
save_dir
,
'model'
))
save_infer_program
(
self
.
test_prog
,
save_dir
)
model_info
[
'status'
]
=
'Normal'
elif
self
.
status
==
'Quant'
:
fluid
.
save
(
self
.
test_prog
,
osp
.
join
(
save_dir
,
'model'
))
...
...
contrib/RemoteSensing/__init__.py
浏览文件 @
4bc70d81
...
...
@@ -21,5 +21,3 @@ import readers
from
utils.utils
import
get_environ_info
env_info
=
get_environ_info
()
log_level
=
2
contrib/RemoteSensing/models/base.py
浏览文件 @
4bc70d81
...
...
@@ -30,6 +30,16 @@ from utils.utils import seconds_to_hms, get_environ_info
from
utils.metrics
import
ConfusionMatrix
import
transforms.transforms
as
T
import
utils
import
paddle
def
save_infer_program
(
test_program
,
ckpt_dir
):
_test_program
=
test_program
.
clone
()
_test_program
.
desc
.
flush
()
_test_program
.
desc
.
_set_version
()
paddle
.
fluid
.
core
.
save_op_compatible_info
(
_test_program
.
desc
)
with
open
(
os
.
path
.
join
(
ckpt_dir
,
'model'
)
+
".pdmodel"
,
"wb"
)
as
f
:
f
.
write
(
_test_program
.
desc
.
serialize_to_string
())
def
dict2str
(
dict_input
):
...
...
@@ -238,6 +248,7 @@ class BaseModel(object):
if
self
.
status
==
'Normal'
:
fluid
.
save
(
self
.
train_prog
,
osp
.
join
(
save_dir
,
'model'
))
save_infer_program
(
self
.
test_prog
,
save_dir
)
model_info
[
'status'
]
=
self
.
status
with
open
(
...
...
contrib/RemoteSensing/utils/logging.py
浏览文件 @
4bc70d81
...
...
@@ -16,7 +16,6 @@
import
time
import
os
import
sys
import
__init__
levels
=
{
0
:
'ERROR'
,
1
:
'WARNING'
,
2
:
'INFO'
,
3
:
'DEBUG'
}
...
...
@@ -25,7 +24,6 @@ def log(level=2, message=""):
current_time
=
time
.
time
()
time_array
=
time
.
localtime
(
current_time
)
current_time
=
time
.
strftime
(
"%Y-%m-%d %H:%M:%S"
,
time_array
)
if
__init__
.
log_level
>=
level
:
print
(
"{} [{}]
\t
{}"
.
format
(
current_time
,
levels
[
level
],
message
).
encode
(
"utf-8"
).
decode
(
"latin1"
))
sys
.
stdout
.
flush
()
...
...
pdseg/train.py
浏览文件 @
4bc70d81
...
...
@@ -27,6 +27,7 @@ import pprint
import
random
import
shutil
import
paddle
import
numpy
as
np
import
paddle.fluid
as
fluid
from
paddle.fluid
import
profiler
...
...
@@ -158,6 +159,15 @@ def load_checkpoint(exe, program):
return
begin_epoch
def
save_infer_program
(
test_program
,
ckpt_dir
):
_test_program
=
test_program
.
clone
()
_test_program
.
desc
.
flush
()
_test_program
.
desc
.
_set_version
()
paddle
.
fluid
.
core
.
save_op_compatible_info
(
_test_program
.
desc
)
with
open
(
os
.
path
.
join
(
ckpt_dir
,
'model'
)
+
".pdmodel"
,
"wb"
)
as
f
:
f
.
write
(
_test_program
.
desc
.
serialize_to_string
())
def
update_best_model
(
ckpt_dir
):
best_model_dir
=
os
.
path
.
join
(
cfg
.
TRAIN
.
MODEL_SAVE_DIR
,
'best_model'
)
if
os
.
path
.
exists
(
best_model_dir
):
...
...
@@ -173,6 +183,7 @@ def print_info(*msg):
def
train
(
cfg
):
startup_prog
=
fluid
.
Program
()
train_prog
=
fluid
.
Program
()
test_prog
=
fluid
.
Program
()
if
args
.
enable_ce
:
startup_prog
.
random_seed
=
1000
train_prog
.
random_seed
=
1000
...
...
@@ -224,6 +235,7 @@ def train(cfg):
data_loader
,
avg_loss
,
lr
,
pred
,
grts
,
masks
=
build_model
(
train_prog
,
startup_prog
,
phase
=
ModelPhase
.
TRAIN
)
build_model
(
test_prog
,
fluid
.
Program
(),
phase
=
ModelPhase
.
EVAL
)
data_loader
.
set_sample_generator
(
data_generator
,
batch_size
=
batch_size_per_dev
,
drop_last
=
drop_last
)
...
...
@@ -387,6 +399,7 @@ def train(cfg):
if
(
epoch
%
cfg
.
TRAIN
.
SNAPSHOT_EPOCH
==
0
or
epoch
==
cfg
.
SOLVER
.
NUM_EPOCHS
)
and
cfg
.
TRAINER_ID
==
0
:
ckpt_dir
=
save_checkpoint
(
train_prog
,
epoch
)
save_infer_program
(
test_prog
,
ckpt_dir
)
if
args
.
do_eval
:
print
(
"Evaluation start"
)
...
...
@@ -419,7 +432,8 @@ def train(cfg):
# save final model
if
cfg
.
TRAINER_ID
==
0
:
save_checkpoint
(
train_prog
,
'final'
)
ckpt_dir
=
save_checkpoint
(
train_prog
,
'final'
)
save_infer_program
(
test_prog
,
ckpt_dir
)
def
main
(
args
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录