提交 4bc70d81 编写于 作者: W wuzewu

Save the inference model when saving a checkpoint

上级 c63f0722
...@@ -27,6 +27,7 @@ import cv2 ...@@ -27,6 +27,7 @@ import cv2
import yaml import yaml
import shutil import shutil
import paddleslim as slim import paddleslim as slim
import paddle
import utils import utils
import utils.logging as logging import utils.logging as logging
...@@ -37,6 +38,15 @@ from nets import DeepLabv3p, ShuffleSeg, HRNet ...@@ -37,6 +38,15 @@ from nets import DeepLabv3p, ShuffleSeg, HRNet
import transforms as T import transforms as T
def save_infer_program(test_program, ckpt_dir):
    """Write a serialized copy of ``test_program`` into ``ckpt_dir``.

    The program is cloned first, so the flush / version-stamping calls
    below operate on the clone's ``desc`` rather than on the object the
    caller passed in. The output file is ``<ckpt_dir>/model.pdmodel``.

    Args:
        test_program: Object providing ``clone()``; the clone must expose a
            ``desc`` with ``flush``, ``_set_version`` and
            ``serialize_to_string`` (presumably a ``fluid.Program`` --
            confirm against callers).
        ckpt_dir: Directory that receives ``model.pdmodel``. It must already
            exist, because ``open`` does not create missing directories.
    """
    program_copy = test_program.clone()
    program_desc = program_copy.desc

    # Prepare the description before it is serialized.
    program_desc.flush()
    program_desc._set_version()
    paddle.fluid.core.save_op_compatible_info(program_desc)

    model_path = os.path.join(ckpt_dir, 'model') + ".pdmodel"
    with open(model_path, "wb") as model_file:
        model_file.write(program_desc.serialize_to_string())
def dict2str(dict_input): def dict2str(dict_input):
out = '' out = ''
for k, v in dict_input.items(): for k, v in dict_input.items():
...@@ -244,6 +254,7 @@ class SegModel(object): ...@@ -244,6 +254,7 @@ class SegModel(object):
if self.status == 'Normal': if self.status == 'Normal':
fluid.save(self.train_prog, osp.join(save_dir, 'model')) fluid.save(self.train_prog, osp.join(save_dir, 'model'))
save_infer_program(self.test_prog, save_dir)
model_info['status'] = 'Normal' model_info['status'] = 'Normal'
elif self.status == 'Quant': elif self.status == 'Quant':
fluid.save(self.test_prog, osp.join(save_dir, 'model')) fluid.save(self.test_prog, osp.join(save_dir, 'model'))
......
...@@ -21,5 +21,3 @@ import readers ...@@ -21,5 +21,3 @@ import readers
from utils.utils import get_environ_info from utils.utils import get_environ_info
env_info = get_environ_info() env_info = get_environ_info()
log_level = 2
...@@ -30,6 +30,16 @@ from utils.utils import seconds_to_hms, get_environ_info ...@@ -30,6 +30,16 @@ from utils.utils import seconds_to_hms, get_environ_info
from utils.metrics import ConfusionMatrix from utils.metrics import ConfusionMatrix
import transforms.transforms as T import transforms.transforms as T
import utils import utils
import paddle
def save_infer_program(test_program, ckpt_dir):
    """Write a serialized copy of ``test_program`` into ``ckpt_dir``.

    The program is cloned first, so the flush / version-stamping calls
    below operate on the clone's ``desc`` rather than on the object the
    caller passed in. The output file is ``<ckpt_dir>/model.pdmodel``.

    Args:
        test_program: Object providing ``clone()``; the clone must expose a
            ``desc`` with ``flush``, ``_set_version`` and
            ``serialize_to_string`` (presumably a ``fluid.Program`` --
            confirm against callers).
        ckpt_dir: Directory that receives ``model.pdmodel``. It must already
            exist, because ``open`` does not create missing directories.
    """
    program_copy = test_program.clone()
    program_desc = program_copy.desc

    # Prepare the description before it is serialized.
    program_desc.flush()
    program_desc._set_version()
    paddle.fluid.core.save_op_compatible_info(program_desc)

    model_path = os.path.join(ckpt_dir, 'model') + ".pdmodel"
    with open(model_path, "wb") as model_file:
        model_file.write(program_desc.serialize_to_string())
def dict2str(dict_input): def dict2str(dict_input):
...@@ -238,6 +248,7 @@ class BaseModel(object): ...@@ -238,6 +248,7 @@ class BaseModel(object):
if self.status == 'Normal': if self.status == 'Normal':
fluid.save(self.train_prog, osp.join(save_dir, 'model')) fluid.save(self.train_prog, osp.join(save_dir, 'model'))
save_infer_program(self.test_prog, save_dir)
model_info['status'] = self.status model_info['status'] = self.status
with open( with open(
......
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
import time import time
import os import os
import sys import sys
import __init__
levels = {0: 'ERROR', 1: 'WARNING', 2: 'INFO', 3: 'DEBUG'} levels = {0: 'ERROR', 1: 'WARNING', 2: 'INFO', 3: 'DEBUG'}
...@@ -25,10 +24,9 @@ def log(level=2, message=""): ...@@ -25,10 +24,9 @@ def log(level=2, message=""):
current_time = time.time() current_time = time.time()
time_array = time.localtime(current_time) time_array = time.localtime(current_time)
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time_array) current_time = time.strftime("%Y-%m-%d %H:%M:%S", time_array)
if __init__.log_level >= level: print("{} [{}]\t{}".format(current_time, levels[level],
print("{} [{}]\t{}".format(current_time, levels[level], message).encode("utf-8").decode("latin1"))
message).encode("utf-8").decode("latin1")) sys.stdout.flush()
sys.stdout.flush()
def debug(message=""): def debug(message=""):
......
...@@ -27,6 +27,7 @@ import pprint ...@@ -27,6 +27,7 @@ import pprint
import random import random
import shutil import shutil
import paddle
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import profiler from paddle.fluid import profiler
...@@ -158,6 +159,15 @@ def load_checkpoint(exe, program): ...@@ -158,6 +159,15 @@ def load_checkpoint(exe, program):
return begin_epoch return begin_epoch
def save_infer_program(test_program, ckpt_dir):
    """Write a serialized copy of ``test_program`` into ``ckpt_dir``.

    The program is cloned first, so the flush / version-stamping calls
    below operate on the clone's ``desc`` rather than on the object the
    caller passed in. The output file is ``<ckpt_dir>/model.pdmodel``.

    Args:
        test_program: Object providing ``clone()``; the clone must expose a
            ``desc`` with ``flush``, ``_set_version`` and
            ``serialize_to_string`` (presumably a ``fluid.Program`` --
            confirm against callers).
        ckpt_dir: Directory that receives ``model.pdmodel``. It must already
            exist, because ``open`` does not create missing directories.
    """
    program_copy = test_program.clone()
    program_desc = program_copy.desc

    # Prepare the description before it is serialized.
    program_desc.flush()
    program_desc._set_version()
    paddle.fluid.core.save_op_compatible_info(program_desc)

    model_path = os.path.join(ckpt_dir, 'model') + ".pdmodel"
    with open(model_path, "wb") as model_file:
        model_file.write(program_desc.serialize_to_string())
def update_best_model(ckpt_dir): def update_best_model(ckpt_dir):
best_model_dir = os.path.join(cfg.TRAIN.MODEL_SAVE_DIR, 'best_model') best_model_dir = os.path.join(cfg.TRAIN.MODEL_SAVE_DIR, 'best_model')
if os.path.exists(best_model_dir): if os.path.exists(best_model_dir):
...@@ -173,6 +183,7 @@ def print_info(*msg): ...@@ -173,6 +183,7 @@ def print_info(*msg):
def train(cfg): def train(cfg):
startup_prog = fluid.Program() startup_prog = fluid.Program()
train_prog = fluid.Program() train_prog = fluid.Program()
test_prog = fluid.Program()
if args.enable_ce: if args.enable_ce:
startup_prog.random_seed = 1000 startup_prog.random_seed = 1000
train_prog.random_seed = 1000 train_prog.random_seed = 1000
...@@ -224,6 +235,7 @@ def train(cfg): ...@@ -224,6 +235,7 @@ def train(cfg):
data_loader, avg_loss, lr, pred, grts, masks = build_model( data_loader, avg_loss, lr, pred, grts, masks = build_model(
train_prog, startup_prog, phase=ModelPhase.TRAIN) train_prog, startup_prog, phase=ModelPhase.TRAIN)
build_model(test_prog, fluid.Program(), phase=ModelPhase.EVAL)
data_loader.set_sample_generator( data_loader.set_sample_generator(
data_generator, batch_size=batch_size_per_dev, drop_last=drop_last) data_generator, batch_size=batch_size_per_dev, drop_last=drop_last)
...@@ -387,6 +399,7 @@ def train(cfg): ...@@ -387,6 +399,7 @@ def train(cfg):
if (epoch % cfg.TRAIN.SNAPSHOT_EPOCH == 0 if (epoch % cfg.TRAIN.SNAPSHOT_EPOCH == 0
or epoch == cfg.SOLVER.NUM_EPOCHS) and cfg.TRAINER_ID == 0: or epoch == cfg.SOLVER.NUM_EPOCHS) and cfg.TRAINER_ID == 0:
ckpt_dir = save_checkpoint(train_prog, epoch) ckpt_dir = save_checkpoint(train_prog, epoch)
save_infer_program(test_prog, ckpt_dir)
if args.do_eval: if args.do_eval:
print("Evaluation start") print("Evaluation start")
...@@ -419,7 +432,8 @@ def train(cfg): ...@@ -419,7 +432,8 @@ def train(cfg):
# save final model # save final model
if cfg.TRAINER_ID == 0: if cfg.TRAINER_ID == 0:
save_checkpoint(train_prog, 'final') ckpt_dir = save_checkpoint(train_prog, 'final')
save_infer_program(test_prog, ckpt_dir)
def main(args): def main(args):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册