Unverified commit bd77dc75 authored by Lin Manhui, committed by GitHub

Toward Devkit Consistency (#8360)

* Accommodate UAPI

* Fix bugs

* Set defaults for use_fd_format

* Restore visualize.py

* Rename variable

* Optimize save_dir

* Fix mistaken update

* Add format check
Parent 2fb66706
......@@ -8,9 +8,11 @@ QAT:
'quantizable_layer_type': ['Conv2D', 'Linear']}
print_model: True
epoch: 5
TrainReader:
batch_size: 1
LearningRate:
base_lr: 0.001
schedulers:
......
......@@ -10,6 +10,9 @@ QAT:
epoch: 50
TrainReader:
batch_size: 8
LearningRate:
base_lr: 0.0001
schedulers:
......
......@@ -101,8 +101,9 @@ class Detector(object):
enable_mkldnn_bfloat16=False,
output_dir='output',
threshold=0.5,
delete_shuffle_pass=False):
self.pred_config = self.set_config(model_dir)
delete_shuffle_pass=False,
use_fd_format=False):
self.pred_config = self.set_config(model_dir, use_fd_format=use_fd_format)
self.predictor, self.config = load_predictor(
model_dir,
self.pred_config.arch,
......@@ -125,8 +126,8 @@ class Detector(object):
self.output_dir = output_dir
self.threshold = threshold
def set_config(self, model_dir):
return PredictConfig(model_dir)
def set_config(self, model_dir, use_fd_format):
return PredictConfig(model_dir, use_fd_format=use_fd_format)
def preprocess(self, image_list):
preprocess_ops = []
......@@ -560,7 +561,8 @@ class DetectorSOLOv2(Detector):
enable_mkldnn=False,
enable_mkldnn_bfloat16=False,
output_dir='./',
threshold=0.5, ):
threshold=0.5,
use_fd_format=False):
super(DetectorSOLOv2, self).__init__(
model_dir=model_dir,
device=device,
......@@ -574,7 +576,8 @@ class DetectorSOLOv2(Detector):
enable_mkldnn=enable_mkldnn,
enable_mkldnn_bfloat16=enable_mkldnn_bfloat16,
output_dir=output_dir,
threshold=threshold, )
threshold=threshold,
use_fd_format=use_fd_format)
def predict(self, repeats=1, run_benchmark=False):
'''
......@@ -650,7 +653,8 @@ class DetectorPicoDet(Detector):
enable_mkldnn=False,
enable_mkldnn_bfloat16=False,
output_dir='./',
threshold=0.5, ):
threshold=0.5,
use_fd_format=False):
super(DetectorPicoDet, self).__init__(
model_dir=model_dir,
device=device,
......@@ -664,7 +668,8 @@ class DetectorPicoDet(Detector):
enable_mkldnn=enable_mkldnn,
enable_mkldnn_bfloat16=enable_mkldnn_bfloat16,
output_dir=output_dir,
threshold=threshold, )
threshold=threshold,
use_fd_format=use_fd_format)
def postprocess(self, inputs, result):
# postprocess output of predictor
......@@ -745,7 +750,8 @@ class DetectorCLRNet(Detector):
enable_mkldnn=False,
enable_mkldnn_bfloat16=False,
output_dir='./',
threshold=0.5, ):
threshold=0.5,
use_fd_format=False):
super(DetectorCLRNet, self).__init__(
model_dir=model_dir,
device=device,
......@@ -759,7 +765,8 @@ class DetectorCLRNet(Detector):
enable_mkldnn=enable_mkldnn,
enable_mkldnn_bfloat16=enable_mkldnn_bfloat16,
output_dir=output_dir,
threshold=threshold, )
threshold=threshold,
use_fd_format=use_fd_format)
deploy_file = os.path.join(model_dir, 'infer_cfg.yml')
with open(deploy_file) as f:
......@@ -867,9 +874,24 @@ class PredictConfig():
model_dir (str): root path of the deploy config (infer_cfg.yml or inference.yml)
"""
def __init__(self, model_dir):
def __init__(self, model_dir, use_fd_format=False):
# parsing Yaml config for Preprocess
deploy_file = os.path.join(model_dir, 'infer_cfg.yml')
fd_deploy_file = os.path.join(model_dir, 'inference.yml')
ppdet_deploy_file = os.path.join(model_dir, 'infer_cfg.yml')
if use_fd_format:
if not os.path.exists(fd_deploy_file) and os.path.exists(
ppdet_deploy_file):
raise RuntimeError(
"Non-FD format model detected. Please set `use_fd_format` to False."
)
deploy_file = fd_deploy_file
else:
if not os.path.exists(ppdet_deploy_file) and os.path.exists(
fd_deploy_file):
raise RuntimeError(
"FD format model detected. Please set `use_fd_format` to False."
)
deploy_file = ppdet_deploy_file
with open(deploy_file) as f:
yml_conf = yaml.safe_load(f)
self.check_model(yml_conf)
......@@ -1121,7 +1143,10 @@ def print_arguments(args):
def main():
deploy_file = os.path.join(FLAGS.model_dir, 'infer_cfg.yml')
if FLAGS.use_fd_format:
deploy_file = os.path.join(FLAGS.model_dir, 'inference.yml')
else:
deploy_file = os.path.join(FLAGS.model_dir, 'infer_cfg.yml')
with open(deploy_file) as f:
yml_conf = yaml.safe_load(f)
arch = yml_conf['arch']
......@@ -1146,7 +1171,8 @@ def main():
enable_mkldnn=FLAGS.enable_mkldnn,
enable_mkldnn_bfloat16=FLAGS.enable_mkldnn_bfloat16,
threshold=FLAGS.threshold,
output_dir=FLAGS.output_dir)
output_dir=FLAGS.output_dir,
use_fd_format=FLAGS.use_fd_format)
# predict from video file or camera video stream
if FLAGS.video_file is not None or FLAGS.camera_id != -1:
......
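For orientation, here is a minimal usage sketch of the new flag on the Python side. This is a sketch, not part of the commit: it assumes you run from the PaddleDetection root with `deploy/python` on `sys.path`, and the model directories are placeholders.

```python
# Run from the PaddleDetection root with deploy/python on sys.path.
from infer import Detector, PredictConfig

# FD/UAPI layout: model_dir contains inference.yml alongside the saved model files.
detector = Detector(model_dir='./inference_model', use_fd_format=True)

# Classic layout: model_dir contains infer_cfg.yml (the default, use_fd_format=False).
config = PredictConfig('./output_inference/picodet_s_320_coco')

# Mixing them up fails fast: the format check above raises RuntimeError when the
# requested config file is missing but the other format's file is present.
```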
......@@ -76,7 +76,8 @@ class KeyPointDetector(Detector):
enable_mkldnn=False,
output_dir='output',
threshold=0.5,
use_dark=True):
use_dark=True,
use_fd_format=False):
super(KeyPointDetector, self).__init__(
model_dir=model_dir,
device=device,
......@@ -89,11 +90,12 @@ class KeyPointDetector(Detector):
cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn,
output_dir=output_dir,
threshold=threshold, )
threshold=threshold,
use_fd_format=use_fd_format)
self.use_dark = use_dark
def set_config(self, model_dir):
return PredictConfig_KeyPoint(model_dir)
def set_config(self, model_dir, use_fd_format):
return PredictConfig_KeyPoint(model_dir, use_fd_format=use_fd_format)
def get_person_from_rect(self, image, results):
# crop the person result from image
......@@ -302,9 +304,24 @@ class PredictConfig_KeyPoint():
model_dir (str): root path of the deploy config (infer_cfg.yml or inference.yml)
"""
def __init__(self, model_dir):
def __init__(self, model_dir, use_fd_format=False):
# parsing Yaml config for Preprocess
deploy_file = os.path.join(model_dir, 'infer_cfg.yml')
fd_deploy_file = os.path.join(model_dir, 'inference.yml')
ppdet_deploy_file = os.path.join(model_dir, 'infer_cfg.yml')
if use_fd_format:
if not os.path.exists(fd_deploy_file) and os.path.exists(
ppdet_deploy_file):
raise RuntimeError(
"Non-FD format model detected. Please set `use_fd_format` to False."
)
deploy_file = fd_deploy_file
else:
if not os.path.exists(ppdet_deploy_file) and os.path.exists(
fd_deploy_file):
raise RuntimeError(
"FD format model detected. Please set `use_fd_format` to False."
)
deploy_file = ppdet_deploy_file
with open(deploy_file) as f:
yml_conf = yaml.safe_load(f)
self.check_model(yml_conf)
......@@ -368,7 +385,8 @@ def main():
enable_mkldnn=FLAGS.enable_mkldnn,
threshold=FLAGS.threshold,
output_dir=FLAGS.output_dir,
use_dark=FLAGS.use_dark)
use_dark=FLAGS.use_dark,
use_fd_format=FLAGS.use_fd_format)
# predict from video file or camera video stream
if FLAGS.video_file is not None or FLAGS.camera_id != -1:
......
......@@ -211,6 +211,7 @@ def argsparser():
type=str,
default="shape_range_info.pbtxt",
help="Path of a dynamic shape file for tensorrt.")
parser.add_argument("--use_fd_format", action="store_true")
return parser
......
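On the CLI side, the flag is a bare `store_true` switch: absent means False, present means True. A quick parsing sketch (assuming `argsparser` is imported from `deploy/python/utils.py`; the model path is a placeholder):

```python
from utils import argsparser  # deploy/python/utils.py

parser = argsparser()
flags = parser.parse_args(['--model_dir', './inference_model', '--use_fd_format'])
assert flags.use_fd_format is True  # omit --use_fd_format to get the default False
```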
......@@ -160,8 +160,7 @@ class Checkpointer(Callback):
def __init__(self, model):
super(Checkpointer, self).__init__(model)
self.best_ap = -1000.
self.save_dir = os.path.join(self.model.cfg.save_dir,
self.model.cfg.filename)
self.save_dir = self.model.cfg.save_dir
if hasattr(self.model.model, 'student_model'):
self.weight = self.model.model.student_model
else:
......@@ -323,8 +322,7 @@ class WandbCallback(Callback):
raise e
self.wandb_params = model.cfg.get('wandb', None)
self.save_dir = os.path.join(self.model.cfg.save_dir,
self.model.cfg.filename)
self.save_dir = self.model.cfg.save_dir
if self.wandb_params is None:
self.wandb_params = {}
for k, v in model.cfg.items():
......
......@@ -1096,7 +1096,10 @@ class Trainer(object):
def _get_infer_cfg_and_input_spec(self,
save_dir,
prune_input=True,
kl_quant=False):
kl_quant=False,
yaml_name=None):
if yaml_name is None:
yaml_name = 'infer_cfg.yml'
image_shape = None
im_shape = [None, 2]
scale_factor = [None, 2]
......@@ -1147,7 +1150,7 @@ class Trainer(object):
# Save infer cfg
_dump_infer_config(self.cfg,
os.path.join(save_dir, 'infer_cfg.yml'), image_shape,
os.path.join(save_dir, yaml_name), image_shape,
self.model)
input_spec = [{
......@@ -1203,7 +1206,7 @@ class Trainer(object):
return static_model, pruned_input_spec
def export(self, output_dir='output_inference'):
def export(self, output_dir='output_inference', for_fd=False):
if hasattr(self.model, 'aux_neck'):
self.model.__delattr__('aux_neck')
if hasattr(self.model, 'aux_head'):
......@@ -1211,23 +1214,31 @@ class Trainer(object):
self.model.eval()
model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
save_dir = os.path.join(output_dir, model_name)
if for_fd:
save_dir = output_dir
save_name = 'inference'
yaml_name = 'inference.yml'
else:
save_dir = os.path.join(output_dir, model_name)
save_name = 'model'
yaml_name = None
if not os.path.exists(save_dir):
os.makedirs(save_dir)
static_model, pruned_input_spec = self._get_infer_cfg_and_input_spec(
save_dir)
save_dir, yaml_name=yaml_name)
# dy2st and save model
if 'slim' not in self.cfg or 'QAT' not in self.cfg['slim_type']:
paddle.jit.save(
static_model,
os.path.join(save_dir, 'model'),
os.path.join(save_dir, save_name),
input_spec=pruned_input_spec)
else:
self.cfg.slim.save_quantized_model(
self.model,
os.path.join(save_dir, 'model'),
os.path.join(save_dir, save_name),
input_spec=pruned_input_spec)
logger.info("Export model and saved in {}".format(save_dir))
......
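Taken together, `for_fd=True` switches three conventions at once: the save directory (no per-config subfolder), the saved model prefix (`inference` instead of `model`), and the config name (`inference.yml` instead of `infer_cfg.yml`). A minimal sketch, assuming a configured `Trainer` instance named `trainer` (directory names are placeholders):

```python
# Classic layout: output_inference/<config_name>/model.pdmodel + infer_cfg.yml
trainer.export('output_inference')

# FD layout: inference_model/inference.pdmodel + inference.pdiparams + inference.yml
trainer.export('inference_model', for_fd=True)
```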
......@@ -300,17 +300,27 @@ def save_model(model,
"""
if paddle.distributed.get_rank() != 0:
return
save_dir = os.path.normpath(save_dir)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
if save_name == "best_model":
best_model_path = os.path.join(save_dir, 'best_model')
if not os.path.exists(best_model_path):
os.makedirs(best_model_path)
save_path = os.path.join(save_dir, save_name)
# save model
if isinstance(model, nn.Layer):
paddle.save(model.state_dict(), save_path + ".pdparams")
best_model = model.state_dict()
else:
assert isinstance(model,
dict), 'model is not an instance of nn.Layer or dict'
if ema_model is None:
paddle.save(model, save_path + ".pdparams")
best_model = model
else:
assert isinstance(ema_model,
dict), ("ema_model is not an instance of dict, "
......@@ -318,6 +328,11 @@ def save_model(model,
# Exchange model and ema_model to save
paddle.save(ema_model, save_path + ".pdparams")
paddle.save(model, save_path + ".pdema")
best_model = ema_model
if save_name == 'best_model':
best_model_path = os.path.join(best_model_path, 'model')
paddle.save(best_model, best_model_path + ".pdparams")
# save optimizer
state_dict = optimizer.state_dict()
state_dict['last_epoch'] = last_epoch
......
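On disk, the change means the best weights are saved twice when `save_name == 'best_model'`: once under the historical flat name, and once under a `best_model/` subdirectory with the fixed `model` prefix that the devkit expects. A hypothetical call and the layout it produces (assuming `model` is an `nn.Layer` and `optimizer` its paddle optimizer):

```python
from ppdet.utils.checkpoint import save_model

save_model(model, optimizer, save_dir='output', save_name='best_model', last_epoch=42)

# Resulting files (an EMA model would add best_model.pdema):
#   output/best_model.pdparams        # flat copy, as before
#   output/best_model.pdopt           # optimizer state
#   output/best_model/model.pdparams  # new devkit-style copy
```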
......@@ -56,6 +56,7 @@ def parse_args():
default=None,
type=str,
help="Configuration file of slim method.")
parser.add_argument("--for_fd", action='store_true')
args = parser.parse_args()
return args
......@@ -76,9 +77,10 @@ def run(FLAGS, cfg):
trainer.load_weights(cfg.weights)
# export model
trainer.export(FLAGS.output_dir)
trainer.export(FLAGS.output_dir, for_fd=FLAGS.for_fd)
if FLAGS.export_serving_model:
assert not FLAGS.for_fd
from paddle_serving_client.io import inference_model_to_serving
model_name = os.path.splitext(os.path.split(cfg.filename)[-1])[0]
......