diff --git a/contrib/HumanSeg/models/humanseg.py b/contrib/HumanSeg/models/humanseg.py index 413be8925e62cce21c55fbbf209573e8e0be16f8..1790e31e363197ebc337e57cac3d418ed0e40d2f 100644 --- a/contrib/HumanSeg/models/humanseg.py +++ b/contrib/HumanSeg/models/humanseg.py @@ -27,6 +27,7 @@ import cv2 import yaml import shutil import paddleslim as slim +import paddle import utils import utils.logging as logging @@ -37,6 +38,15 @@ from nets import DeepLabv3p, ShuffleSeg, HRNet import transforms as T +def save_infer_program(test_program, ckpt_dir): + _test_program = test_program.clone() + _test_program.desc.flush() + _test_program.desc._set_version() + paddle.fluid.core.save_op_compatible_info(_test_program.desc) + with open(os.path.join(ckpt_dir, 'model') + ".pdmodel", "wb") as f: + f.write(_test_program.desc.serialize_to_string()) + + def dict2str(dict_input): out = '' for k, v in dict_input.items(): @@ -244,6 +254,7 @@ class SegModel(object): if self.status == 'Normal': fluid.save(self.train_prog, osp.join(save_dir, 'model')) + save_infer_program(self.test_prog, save_dir) model_info['status'] = 'Normal' elif self.status == 'Quant': fluid.save(self.test_prog, osp.join(save_dir, 'model')) diff --git a/contrib/RemoteSensing/__init__.py b/contrib/RemoteSensing/__init__.py index 6406dd3bed1e16be04df21067d188d70fae98026..fc8620cab54b60afdc992b93c388b115c19503e2 100644 --- a/contrib/RemoteSensing/__init__.py +++ b/contrib/RemoteSensing/__init__.py @@ -21,5 +21,3 @@ import readers from utils.utils import get_environ_info env_info = get_environ_info() - -log_level = 2 diff --git a/contrib/RemoteSensing/models/base.py b/contrib/RemoteSensing/models/base.py index 556c9ee0e7c3163930b605dad87c9fd22d1423bf..0b5c858f171626c8d95fe3934da992644de2e515 100644 --- a/contrib/RemoteSensing/models/base.py +++ b/contrib/RemoteSensing/models/base.py @@ -30,6 +30,16 @@ from utils.utils import seconds_to_hms, get_environ_info from utils.metrics import ConfusionMatrix import transforms.transforms as T import utils +import paddle + + +def save_infer_program(test_program, ckpt_dir): + _test_program = test_program.clone() + _test_program.desc.flush() + _test_program.desc._set_version() + paddle.fluid.core.save_op_compatible_info(_test_program.desc) + with open(os.path.join(ckpt_dir, 'model') + ".pdmodel", "wb") as f: + f.write(_test_program.desc.serialize_to_string()) def dict2str(dict_input): @@ -238,6 +248,7 @@ class BaseModel(object): if self.status == 'Normal': fluid.save(self.train_prog, osp.join(save_dir, 'model')) + save_infer_program(self.test_prog, save_dir) model_info['status'] = self.status with open( diff --git a/contrib/RemoteSensing/utils/logging.py b/contrib/RemoteSensing/utils/logging.py index 64532505534ed8a238ff1b1c9ff4e53a59cba2c7..16670ca1d52293b5378480abda43a6e7b6456841 100644 --- a/contrib/RemoteSensing/utils/logging.py +++ b/contrib/RemoteSensing/utils/logging.py @@ -16,7 +16,6 @@ import time import os import sys -import __init__ levels = {0: 'ERROR', 1: 'WARNING', 2: 'INFO', 3: 'DEBUG'} @@ -25,10 +24,9 @@ def log(level=2, message=""): current_time = time.time() time_array = time.localtime(current_time) current_time = time.strftime("%Y-%m-%d %H:%M:%S", time_array) - if __init__.log_level >= level: - print("{} [{}]\t{}".format(current_time, levels[level], - message).encode("utf-8").decode("latin1")) - sys.stdout.flush() + print("{} [{}]\t{}".format(current_time, levels[level], + message).encode("utf-8").decode("latin1")) + sys.stdout.flush() def debug(message=""): diff --git a/docs/model_export.md b/docs/model_export.md index 0c5247a5c0de775b54a695dfe4a425bb71bae23d..e2127546bafe83de31db46be74f8644ccbacb257 100644 --- a/docs/model_export.md +++ b/docs/model_export.md @@ -1,6 +1,6 @@ # 模型导出 -通过训练得到一个满足要求的模型后,如果想要将该模型接入到C++预测库或者Serving服务,我们需要通过`pdseg/export_model.py`来导出该模型。 +通过训练得到一个满足要求的模型后,如果想要将该模型接入到C++预测库或者Serving服务,我们需要通过[`pdseg/export_model.py`](../../pdseg/export_model.py)来导出该模型。 该脚本的使用方法和`train.py/eval.py/vis.py`完全一样。 diff --git a/pdseg/train.py b/pdseg/train.py index 7021ed9d6ae137c204e05e00d5a165f9859b056e..aa29de92ded7c79acd4ca41fa74ee6eaeb8993c2 100644 --- a/pdseg/train.py +++ b/pdseg/train.py @@ -27,6 +27,7 @@ import pprint import random import shutil +import paddle import numpy as np import paddle.fluid as fluid from paddle.fluid import profiler @@ -158,6 +159,15 @@ def load_checkpoint(exe, program): return begin_epoch +def save_infer_program(test_program, ckpt_dir): + _test_program = test_program.clone() + _test_program.desc.flush() + _test_program.desc._set_version() + paddle.fluid.core.save_op_compatible_info(_test_program.desc) + with open(os.path.join(ckpt_dir, 'model') + ".pdmodel", "wb") as f: + f.write(_test_program.desc.serialize_to_string()) + + def update_best_model(ckpt_dir): best_model_dir = os.path.join(cfg.TRAIN.MODEL_SAVE_DIR, 'best_model') if os.path.exists(best_model_dir): @@ -173,6 +183,7 @@ def print_info(*msg): def train(cfg): startup_prog = fluid.Program() train_prog = fluid.Program() + test_prog = fluid.Program() if args.enable_ce: startup_prog.random_seed = 1000 train_prog.random_seed = 1000 @@ -224,6 +235,7 @@ def train(cfg): data_loader, avg_loss, lr, pred, grts, masks = build_model( train_prog, startup_prog, phase=ModelPhase.TRAIN) + build_model(test_prog, fluid.Program(), phase=ModelPhase.EVAL) data_loader.set_sample_generator( data_generator, batch_size=batch_size_per_dev, drop_last=drop_last) @@ -387,6 +399,7 @@ def train(cfg): if (epoch % cfg.TRAIN.SNAPSHOT_EPOCH == 0 or epoch == cfg.SOLVER.NUM_EPOCHS) and cfg.TRAINER_ID == 0: ckpt_dir = save_checkpoint(train_prog, epoch) + save_infer_program(test_prog, ckpt_dir) if args.do_eval: print("Evaluation start") @@ -419,7 +432,8 @@ def train(cfg): # save final model if cfg.TRAINER_ID == 0: - save_checkpoint(train_prog, 'final') + ckpt_dir = save_checkpoint(train_prog, 'final') + save_infer_program(test_prog, ckpt_dir) def main(args): diff --git a/requirements.txt b/requirements.txt index cacf4a6ca2a8b77869efa2f9dfcdca1f545fa211..6200a94b71c2ce40d02bf68d97f642cde42836e6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,4 @@ pre-commit yapf == 0.26.0 flake8 pyyaml >= 5.1 -visualdl == 2.0.0b4 +visualdl >= 2.0.0