From bd77dc75fa2406564ef023b43f155c3c5b4d5993 Mon Sep 17 00:00:00 2001 From: Lin Manhui Date: Sun, 25 Jun 2023 15:13:39 +0800 Subject: [PATCH] Toward Devkit Consistency (#8360) * Accommodate UAPI * Fix bugs * Set defaults for use_fd_format * Restore visualize.py * Rename variable * Optimize save_dir * Fix mistakenly update * Add format check --- .../slim/quant/mask_rcnn_r50_fpn_1x_qat.yml | 4 +- configs/slim/quant/yolov3_darknet_qat.yml | 3 ++ deploy/python/infer.py | 54 ++++++++++++++----- deploy/python/keypoint_infer.py | 32 ++++++++--- deploy/python/utils.py | 1 + ppdet/engine/callbacks.py | 6 +-- ppdet/engine/trainer.py | 25 ++++++--- ppdet/utils/checkpoint.py | 15 ++++++ tools/export_model.py | 4 +- 9 files changed, 110 insertions(+), 34 deletions(-) diff --git a/configs/slim/quant/mask_rcnn_r50_fpn_1x_qat.yml b/configs/slim/quant/mask_rcnn_r50_fpn_1x_qat.yml index 7363b4e55..11c6e0663 100644 --- a/configs/slim/quant/mask_rcnn_r50_fpn_1x_qat.yml +++ b/configs/slim/quant/mask_rcnn_r50_fpn_1x_qat.yml @@ -8,9 +8,11 @@ QAT: 'quantizable_layer_type': ['Conv2D', 'Linear']} print_model: True - epoch: 5 +TrainReader: + batch_size: 1 + LearningRate: base_lr: 0.001 schedulers: diff --git a/configs/slim/quant/yolov3_darknet_qat.yml b/configs/slim/quant/yolov3_darknet_qat.yml index 281b53418..42ac4463a 100644 --- a/configs/slim/quant/yolov3_darknet_qat.yml +++ b/configs/slim/quant/yolov3_darknet_qat.yml @@ -10,6 +10,9 @@ QAT: epoch: 50 +TrainReader: + batch_size: 8 + LearningRate: base_lr: 0.0001 schedulers: diff --git a/deploy/python/infer.py b/deploy/python/infer.py index dc0922bb3..703d16223 100644 --- a/deploy/python/infer.py +++ b/deploy/python/infer.py @@ -101,8 +101,9 @@ class Detector(object): enable_mkldnn_bfloat16=False, output_dir='output', threshold=0.5, - delete_shuffle_pass=False): - self.pred_config = self.set_config(model_dir) + delete_shuffle_pass=False, + use_fd_format=False): + self.pred_config = self.set_config(model_dir, use_fd_format=use_fd_format) self.predictor, self.config = load_predictor( model_dir, self.pred_config.arch, @@ -125,8 +126,8 @@ class Detector(object): self.output_dir = output_dir self.threshold = threshold - def set_config(self, model_dir): - return PredictConfig(model_dir) + def set_config(self, model_dir, use_fd_format): + return PredictConfig(model_dir, use_fd_format=use_fd_format) def preprocess(self, image_list): preprocess_ops = [] @@ -560,7 +561,8 @@ class DetectorSOLOv2(Detector): enable_mkldnn=False, enable_mkldnn_bfloat16=False, output_dir='./', - threshold=0.5, ): + threshold=0.5, + use_fd_format=False): super(DetectorSOLOv2, self).__init__( model_dir=model_dir, device=device, @@ -574,7 +576,8 @@ class DetectorSOLOv2(Detector): enable_mkldnn=enable_mkldnn, enable_mkldnn_bfloat16=enable_mkldnn_bfloat16, output_dir=output_dir, - threshold=threshold, ) + threshold=threshold, + use_fd_format=use_fd_format) def predict(self, repeats=1, run_benchmark=False): ''' @@ -650,7 +653,8 @@ class DetectorPicoDet(Detector): enable_mkldnn=False, enable_mkldnn_bfloat16=False, output_dir='./', - threshold=0.5, ): + threshold=0.5, + use_fd_format=False): super(DetectorPicoDet, self).__init__( model_dir=model_dir, device=device, @@ -664,7 +668,8 @@ class DetectorPicoDet(Detector): enable_mkldnn=enable_mkldnn, enable_mkldnn_bfloat16=enable_mkldnn_bfloat16, output_dir=output_dir, - threshold=threshold, ) + threshold=threshold, + use_fd_format=use_fd_format) def postprocess(self, inputs, result): # postprocess output of predictor @@ -745,7 +750,8 @@ class DetectorCLRNet(Detector): enable_mkldnn=False, enable_mkldnn_bfloat16=False, output_dir='./', - threshold=0.5, ): + threshold=0.5, + use_fd_format=False): super(DetectorCLRNet, self).__init__( model_dir=model_dir, device=device, @@ -759,7 +765,8 @@ class DetectorCLRNet(Detector): enable_mkldnn=enable_mkldnn, enable_mkldnn_bfloat16=enable_mkldnn_bfloat16, output_dir=output_dir, - threshold=threshold, ) + threshold=threshold, + use_fd_format=use_fd_format) deploy_file = os.path.join(model_dir, 'infer_cfg.yml') with open(deploy_file) as f: @@ -867,9 +874,24 @@ class PredictConfig(): model_dir (str): root path of model.yml """ - def __init__(self, model_dir): + def __init__(self, model_dir, use_fd_format=False): # parsing Yaml config for Preprocess - deploy_file = os.path.join(model_dir, 'infer_cfg.yml') + fd_deploy_file = os.path.join(model_dir, 'inference.yml') + ppdet_deploy_file = os.path.join(model_dir, 'infer_cfg.yml') + if use_fd_format: + if not os.path.exists(fd_deploy_file) and os.path.exists( + ppdet_deploy_file): + raise RuntimeError( + "Non-FD format model detected. Please set `use_fd_format` to False." + ) + deploy_file = fd_deploy_file + else: + if not os.path.exists(ppdet_deploy_file) and os.path.exists( + fd_deploy_file): + raise RuntimeError( + "FD format model detected. Please set `use_fd_format` to False." + ) + deploy_file = ppdet_deploy_file with open(deploy_file) as f: yml_conf = yaml.safe_load(f) self.check_model(yml_conf) @@ -1121,7 +1143,10 @@ def print_arguments(args): def main(): - deploy_file = os.path.join(FLAGS.model_dir, 'infer_cfg.yml') + if FLAGS.use_fd_format: + deploy_file = os.path.join(FLAGS.model_dir, 'inference.yml') + else: + deploy_file = os.path.join(FLAGS.model_dir, 'infer_cfg.yml') with open(deploy_file) as f: yml_conf = yaml.safe_load(f) arch = yml_conf['arch'] @@ -1146,7 +1171,8 @@ def main(): enable_mkldnn=FLAGS.enable_mkldnn, enable_mkldnn_bfloat16=FLAGS.enable_mkldnn_bfloat16, threshold=FLAGS.threshold, - output_dir=FLAGS.output_dir) + output_dir=FLAGS.output_dir, + use_fd_format=FLAGS.use_fd_format) # predict from video file or camera video stream if FLAGS.video_file is not None or FLAGS.camera_id != -1: diff --git a/deploy/python/keypoint_infer.py b/deploy/python/keypoint_infer.py index 03695f10e..fc912bf53 100644 --- a/deploy/python/keypoint_infer.py +++ b/deploy/python/keypoint_infer.py @@ -76,7 +76,8 @@ class KeyPointDetector(Detector): enable_mkldnn=False, output_dir='output', threshold=0.5, - use_dark=True): + use_dark=True, + use_fd_format=False): super(KeyPointDetector, self).__init__( model_dir=model_dir, device=device, @@ -89,11 +90,12 @@ class KeyPointDetector(Detector): cpu_threads=cpu_threads, enable_mkldnn=enable_mkldnn, output_dir=output_dir, - threshold=threshold, ) + threshold=threshold, + use_fd_format=use_fd_format) self.use_dark = use_dark - def set_config(self, model_dir): - return PredictConfig_KeyPoint(model_dir) + def set_config(self, model_dir, use_fd_format): + return PredictConfig_KeyPoint(model_dir, use_fd_format=use_fd_format) def get_person_from_rect(self, image, results): # crop the person result from image @@ -302,9 +304,24 @@ class PredictConfig_KeyPoint(): model_dir (str): root path of model.yml """ - def __init__(self, model_dir): + def __init__(self, model_dir, use_fd_format=False): # parsing Yaml config for Preprocess - deploy_file = os.path.join(model_dir, 'infer_cfg.yml') + fd_deploy_file = os.path.join(model_dir, 'inference.yml') + ppdet_deploy_file = os.path.join(model_dir, 'infer_cfg.yml') + if use_fd_format: + if not os.path.exists(fd_deploy_file) and os.path.exists( + ppdet_deploy_file): + raise RuntimeError( + "Non-FD format model detected. Please set `use_fd_format` to False." + ) + deploy_file = fd_deploy_file + else: + if not os.path.exists(ppdet_deploy_file) and os.path.exists( + fd_deploy_file): + raise RuntimeError( + "FD format model detected. Please set `use_fd_format` to False." + ) + deploy_file = ppdet_deploy_file with open(deploy_file) as f: yml_conf = yaml.safe_load(f) self.check_model(yml_conf) @@ -368,7 +385,8 @@ def main(): enable_mkldnn=FLAGS.enable_mkldnn, threshold=FLAGS.threshold, output_dir=FLAGS.output_dir, - use_dark=FLAGS.use_dark) + use_dark=FLAGS.use_dark, + use_fd_format=FLAGS.use_fd_format) # predict from video file or camera video stream if FLAGS.video_file is not None or FLAGS.camera_id != -1: diff --git a/deploy/python/utils.py b/deploy/python/utils.py index b05a5d03d..5fc55352f 100644 --- a/deploy/python/utils.py +++ b/deploy/python/utils.py @@ -211,6 +211,7 @@ def argsparser(): type=str, default="shape_range_info.pbtxt", help="Path of a dynamic shape file for tensorrt.") + parser.add_argument("--use_fd_format", action="store_true") return parser diff --git a/ppdet/engine/callbacks.py b/ppdet/engine/callbacks.py index eeb2f06de..87dcd61b2 100644 --- a/ppdet/engine/callbacks.py +++ b/ppdet/engine/callbacks.py @@ -160,8 +160,7 @@ class Checkpointer(Callback): def __init__(self, model): super(Checkpointer, self).__init__(model) self.best_ap = -1000. - self.save_dir = os.path.join(self.model.cfg.save_dir, - self.model.cfg.filename) + self.save_dir = self.model.cfg.save_dir if hasattr(self.model.model, 'student_model'): self.weight = self.model.model.student_model else: @@ -323,8 +322,7 @@ class WandbCallback(Callback): raise e self.wandb_params = model.cfg.get('wandb', None) - self.save_dir = os.path.join(self.model.cfg.save_dir, - self.model.cfg.filename) + self.save_dir = self.model.cfg.save_dir if self.wandb_params is None: self.wandb_params = {} for k, v in model.cfg.items(): diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index 260dbc9b7..7fcce230d 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -1096,7 +1096,10 @@ class Trainer(object): def _get_infer_cfg_and_input_spec(self, save_dir, prune_input=True, - kl_quant=False): + kl_quant=False, + yaml_name=None): + if yaml_name is None: + yaml_name = 'infer_cfg.yml' image_shape = None im_shape = [None, 2] scale_factor = [None, 2] @@ -1147,7 +1150,7 @@ class Trainer(object): # Save infer cfg _dump_infer_config(self.cfg, - os.path.join(save_dir, 'infer_cfg.yml'), image_shape, + os.path.join(save_dir, yaml_name), image_shape, self.model) input_spec = [{ @@ -1203,7 +1206,7 @@ class Trainer(object): return static_model, pruned_input_spec - def export(self, output_dir='output_inference'): + def export(self, output_dir='output_inference', for_fd=False): if hasattr(self.model, 'aux_neck'): self.model.__delattr__('aux_neck') if hasattr(self.model, 'aux_head'): @@ -1211,23 +1214,31 @@ class Trainer(object): self.model.eval() model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0] - save_dir = os.path.join(output_dir, model_name) + if for_fd: + save_dir = output_dir + save_name = 'inference' + yaml_name = 'inference.yml' + else: + save_dir = os.path.join(output_dir, model_name) + save_name = 'model' + yaml_name = None + if not os.path.exists(save_dir): os.makedirs(save_dir) static_model, pruned_input_spec = self._get_infer_cfg_and_input_spec( - save_dir) + save_dir, yaml_name=yaml_name) # dy2st and save model if 'slim' not in self.cfg or 'QAT' not in self.cfg['slim_type']: paddle.jit.save( static_model, - os.path.join(save_dir, 'model'), + os.path.join(save_dir, save_name), input_spec=pruned_input_spec) else: self.cfg.slim.save_quantized_model( self.model, - os.path.join(save_dir, 'model'), + os.path.join(save_dir, save_name), input_spec=pruned_input_spec) logger.info("Export model and saved in {}".format(save_dir)) diff --git a/ppdet/utils/checkpoint.py b/ppdet/utils/checkpoint.py index 101e46b32..8672c988d 100644 --- a/ppdet/utils/checkpoint.py +++ b/ppdet/utils/checkpoint.py @@ -300,17 +300,27 @@ def save_model(model, """ if paddle.distributed.get_rank() != 0: return + + save_dir = os.path.normpath(save_dir) if not os.path.exists(save_dir): os.makedirs(save_dir) + + if save_name == "best_model": + best_model_path = os.path.join(save_dir, 'best_model') + if not os.path.exists(best_model_path): + os.makedirs(best_model_path) + save_path = os.path.join(save_dir, save_name) # save model if isinstance(model, nn.Layer): paddle.save(model.state_dict(), save_path + ".pdparams") + best_model = model.state_dict() else: assert isinstance(model, dict), 'model is not a instance of nn.layer or dict' if ema_model is None: paddle.save(model, save_path + ".pdparams") + best_model = model else: assert isinstance(ema_model, dict), ("ema_model is not a instance of dict, " @@ -318,6 +328,11 @@ def save_model(model, # Exchange model and ema_model to save paddle.save(ema_model, save_path + ".pdparams") paddle.save(model, save_path + ".pdema") + best_model = ema_model + + if save_name == 'best_model': + best_model_path = os.path.join(best_model_path, 'model') + paddle.save(best_model, best_model_path + ".pdparams") # save optimizer state_dict = optimizer.state_dict() state_dict['last_epoch'] = last_epoch diff --git a/tools/export_model.py b/tools/export_model.py index f4ffcb500..2a09f7a3d 100644 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -56,6 +56,7 @@ def parse_args(): default=None, type=str, help="Configuration file of slim method.") + parser.add_argument("--for_fd", action='store_true') args = parser.parse_args() return args @@ -76,9 +77,10 @@ def run(FLAGS, cfg): trainer.load_weights(cfg.weights) # export model - trainer.export(FLAGS.output_dir) + trainer.export(FLAGS.output_dir, for_fd=FLAGS.for_fd) if FLAGS.export_serving_model: + assert not FLAGS.for_fd from paddle_serving_client.io import inference_model_to_serving model_name = os.path.splitext(os.path.split(cfg.filename)[-1])[0] -- GitLab