From ab1c367ff303ecf90f4d8885813c634872b07943 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Fri, 16 Sep 2022 11:43:45 +0800 Subject: [PATCH] fix act some bug (#1421) --- example/auto_compression/image_classification/eval.py | 3 ++- example/auto_compression/pytorch_yolo_series/eval.py | 2 +- .../auto_compression/pytorch_yolo_series/post_process.py | 6 +++--- example/auto_compression/pytorch_yolo_series/run.py | 5 ++++- example/full_quantization/image_classification/eval.py | 3 ++- 5 files changed, 12 insertions(+), 7 deletions(-) diff --git a/example/auto_compression/image_classification/eval.py b/example/auto_compression/image_classification/eval.py index 9cd9b4a3..790f3543 100644 --- a/example/auto_compression/image_classification/eval.py +++ b/example/auto_compression/image_classification/eval.py @@ -98,7 +98,8 @@ def eval(): def main(args): global global_config - global_config = load_slim_config(args.config_path) + all_config = load_slim_config(args.config_path) + global_config = all_config["Global"] global data_dir data_dir = global_config['data_dir'] diff --git a/example/auto_compression/pytorch_yolo_series/eval.py b/example/auto_compression/pytorch_yolo_series/eval.py index 946787d8..8620db34 100644 --- a/example/auto_compression/pytorch_yolo_series/eval.py +++ b/example/auto_compression/pytorch_yolo_series/eval.py @@ -55,7 +55,7 @@ def eval(): score_threshold=0.001, nms_threshold=0.65, multi_label=True, - num_top_k=global_config.get('nms_num_top_k', 30000)) + nms_top_k=global_config.get('nms_num_top_k', 30000)) bboxes_list, bbox_nums_list, image_id_list = [], [], [] with tqdm( diff --git a/example/auto_compression/pytorch_yolo_series/post_process.py b/example/auto_compression/pytorch_yolo_series/post_process.py index 5a3b1ef8..775024ba 100644 --- a/example/auto_compression/pytorch_yolo_series/post_process.py +++ b/example/auto_compression/pytorch_yolo_series/post_process.py @@ -80,7 +80,7 @@ class YOLOPostProcess(object): multi_label(bool): Whether keep multi label in boxes. keep_top_k(int): Number of total bboxes to be kept per image after NMS step. -1 means keeping all bboxes after NMS step. - num_top_k(int): Maximum number of boxes put into torchvision.ops.nums() + nms_top_k(int): Maximum number of boxes put into nums. """ def __init__(self, @@ -88,12 +88,12 @@ class YOLOPostProcess(object): nms_threshold=0.5, multi_label=False, keep_top_k=300, - num_top_k=30000): + nms_top_k=30000): self.score_threshold = score_threshold self.nms_threshold = nms_threshold self.multi_label = multi_label self.keep_top_k = keep_top_k - self.num_top_k = num_top_k + self.nms_top_k = nms_top_k def _xywh2xyxy(self, x): # Convert from [x, y, w, h] to [x1, y1, x2, y2] diff --git a/example/auto_compression/pytorch_yolo_series/run.py b/example/auto_compression/pytorch_yolo_series/run.py index 94e2c22b..172e3328 100644 --- a/example/auto_compression/pytorch_yolo_series/run.py +++ b/example/auto_compression/pytorch_yolo_series/run.py @@ -71,7 +71,10 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list): feed={test_feed_names[0]: data_all['image']}, fetch_list=test_fetch_list, return_numpy=False) - res = postprocess(np.array(outs[0]), data_all['scale_factor']) + res = postprocess( + np.array(outs[0]), + data_all['scale_factor'], + nms_top_k=global_config.get('nms_num_top_k', 30000)) bboxes_list.append(res['bbox']) bbox_nums_list.append(res['bbox_num']) image_id_list.append(np.array(data_all['im_id'])) diff --git a/example/full_quantization/image_classification/eval.py b/example/full_quantization/image_classification/eval.py index 9cd9b4a3..790f3543 100644 --- a/example/full_quantization/image_classification/eval.py +++ b/example/full_quantization/image_classification/eval.py @@ -98,7 +98,8 @@ def eval(): def main(args): global global_config - global_config = load_slim_config(args.config_path) + all_config = load_slim_config(args.config_path) + global_config = all_config["Global"] global data_dir data_dir = global_config['data_dir'] -- GitLab