提交 4bc70d81 编写于 作者: W wuzewu

Save infer model when saving checkpoint

上级 c63f0722
......@@ -27,6 +27,7 @@ import cv2
import yaml
import shutil
import paddleslim as slim
import paddle
import utils
import utils.logging as logging
......@@ -37,6 +38,15 @@ from nets import DeepLabv3p, ShuffleSeg, HRNet
import transforms as T
def save_infer_program(test_program, ckpt_dir):
_test_program = test_program.clone()
_test_program.desc.flush()
_test_program.desc._set_version()
paddle.fluid.core.save_op_compatible_info(_test_program.desc)
with open(os.path.join(ckpt_dir, 'model') + ".pdmodel", "wb") as f:
f.write(_test_program.desc.serialize_to_string())
def dict2str(dict_input):
out = ''
for k, v in dict_input.items():
......@@ -244,6 +254,7 @@ class SegModel(object):
if self.status == 'Normal':
fluid.save(self.train_prog, osp.join(save_dir, 'model'))
save_infer_program(self.test_prog, save_dir)
model_info['status'] = 'Normal'
elif self.status == 'Quant':
fluid.save(self.test_prog, osp.join(save_dir, 'model'))
......
......@@ -21,5 +21,3 @@ import readers
from utils.utils import get_environ_info
env_info = get_environ_info()
log_level = 2
......@@ -30,6 +30,16 @@ from utils.utils import seconds_to_hms, get_environ_info
from utils.metrics import ConfusionMatrix
import transforms.transforms as T
import utils
import paddle
def save_infer_program(test_program, ckpt_dir):
_test_program = test_program.clone()
_test_program.desc.flush()
_test_program.desc._set_version()
paddle.fluid.core.save_op_compatible_info(_test_program.desc)
with open(os.path.join(ckpt_dir, 'model') + ".pdmodel", "wb") as f:
f.write(_test_program.desc.serialize_to_string())
def dict2str(dict_input):
......@@ -238,6 +248,7 @@ class BaseModel(object):
if self.status == 'Normal':
fluid.save(self.train_prog, osp.join(save_dir, 'model'))
save_infer_program(self.test_prog, save_dir)
model_info['status'] = self.status
with open(
......
......@@ -16,7 +16,6 @@
import time
import os
import sys
import __init__
levels = {0: 'ERROR', 1: 'WARNING', 2: 'INFO', 3: 'DEBUG'}
......@@ -25,10 +24,9 @@ def log(level=2, message=""):
current_time = time.time()
time_array = time.localtime(current_time)
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time_array)
if __init__.log_level >= level:
print("{} [{}]\t{}".format(current_time, levels[level],
message).encode("utf-8").decode("latin1"))
sys.stdout.flush()
print("{} [{}]\t{}".format(current_time, levels[level],
message).encode("utf-8").decode("latin1"))
sys.stdout.flush()
def debug(message=""):
......
......@@ -27,6 +27,7 @@ import pprint
import random
import shutil
import paddle
import numpy as np
import paddle.fluid as fluid
from paddle.fluid import profiler
......@@ -158,6 +159,15 @@ def load_checkpoint(exe, program):
return begin_epoch
def save_infer_program(test_program, ckpt_dir):
_test_program = test_program.clone()
_test_program.desc.flush()
_test_program.desc._set_version()
paddle.fluid.core.save_op_compatible_info(_test_program.desc)
with open(os.path.join(ckpt_dir, 'model') + ".pdmodel", "wb") as f:
f.write(_test_program.desc.serialize_to_string())
def update_best_model(ckpt_dir):
best_model_dir = os.path.join(cfg.TRAIN.MODEL_SAVE_DIR, 'best_model')
if os.path.exists(best_model_dir):
......@@ -173,6 +183,7 @@ def print_info(*msg):
def train(cfg):
startup_prog = fluid.Program()
train_prog = fluid.Program()
test_prog = fluid.Program()
if args.enable_ce:
startup_prog.random_seed = 1000
train_prog.random_seed = 1000
......@@ -224,6 +235,7 @@ def train(cfg):
data_loader, avg_loss, lr, pred, grts, masks = build_model(
train_prog, startup_prog, phase=ModelPhase.TRAIN)
build_model(test_prog, fluid.Program(), phase=ModelPhase.EVAL)
data_loader.set_sample_generator(
data_generator, batch_size=batch_size_per_dev, drop_last=drop_last)
......@@ -387,6 +399,7 @@ def train(cfg):
if (epoch % cfg.TRAIN.SNAPSHOT_EPOCH == 0
or epoch == cfg.SOLVER.NUM_EPOCHS) and cfg.TRAINER_ID == 0:
ckpt_dir = save_checkpoint(train_prog, epoch)
save_infer_program(test_prog, ckpt_dir)
if args.do_eval:
print("Evaluation start")
......@@ -419,7 +432,8 @@ def train(cfg):
# save final model
if cfg.TRAINER_ID == 0:
save_checkpoint(train_prog, 'final')
ckpt_dir = save_checkpoint(train_prog, 'final')
save_infer_program(test_prog, ckpt_dir)
def main(args):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册