diff --git a/deploy/python/infer.py b/deploy/python/infer.py index 7ce83c25a3656a6b9f27cc20885039d28be2263d..ee635c718ae3e3c3ba9f2857850ea20d40f7fe1b 100644 --- a/deploy/python/infer.py +++ b/deploy/python/infer.py @@ -553,6 +553,10 @@ class PredictConfig(): self.nms = yml_conf['NMS'] if 'fpn_stride' in yml_conf: self.fpn_stride = yml_conf['fpn_stride'] + if self.arch == 'RCNN' and yml_conf.get('export_onnx', False): + print( + 'The RCNN export model is used for ONNX and it only supports batch_size = 1' + ) self.print_config() def check_model(self, yml_conf): diff --git a/ppdet/engine/export_utils.py b/ppdet/engine/export_utils.py index 08871ed3a1db09a9a86667fbd83fa5728d9c21c8..8f55f029396f6a5a8d35cc1eb172b5710a4350e4 100644 --- a/ppdet/engine/export_utils.py +++ b/ppdet/engine/export_utils.py @@ -126,7 +126,13 @@ def _dump_infer_config(config, path, image_shape, model): 'metric': config['metric'], 'use_dynamic_shape': use_dynamic_shape }) + export_onnx = config.get('export_onnx', False) + infer_arch = config['architecture'] + if 'RCNN' in infer_arch and export_onnx: + logger.warning( + "Exporting RCNN model to ONNX only support batch_size = 1") + infer_cfg['export_onnx'] = True if infer_arch in MOT_ARCH: if infer_arch == 'DeepSORT':