From 6b74c7622cdc6d13c40c657affd9346a221e33c2 Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Fri, 28 May 2021 19:25:37 +0800 Subject: [PATCH] fix solov2 deploy (#3200) --- deploy/python/infer.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/deploy/python/infer.py b/deploy/python/infer.py index 9552e1bba..c60aae4b4 100644 --- a/deploy/python/infer.py +++ b/deploy/python/infer.py @@ -242,17 +242,20 @@ class DetectorSOLOv2(Detector): for i in range(warmup): self.predictor.run() output_names = self.predictor.get_output_names() + np_boxes_num = self.predictor.get_output_handle(output_names[ + 0]).copy_to_cpu() np_label = self.predictor.get_output_handle(output_names[ 1]).copy_to_cpu() np_score = self.predictor.get_output_handle(output_names[ 2]).copy_to_cpu() np_segms = self.predictor.get_output_handle(output_names[ 3]).copy_to_cpu() - self.det_times.inference_time_s.start() for i in range(repeats): self.predictor.run() output_names = self.predictor.get_output_names() + np_boxes_num = self.predictor.get_output_handle(output_names[ + 0]).copy_to_cpu() np_label = self.predictor.get_output_handle(output_names[ 1]).copy_to_cpu() np_score = self.predictor.get_output_handle(output_names[ @@ -262,7 +265,11 @@ class DetectorSOLOv2(Detector): self.det_times.inference_time_s.end(repeats=repeats) self.det_times.img_num += 1 - return dict(segm=np_segms, label=np_label, score=np_score) + return dict( + segm=np_segms, + label=np_label, + score=np_score, + boxes_num=np_boxes_num) def create_inputs(imgs, im_info): @@ -481,6 +488,13 @@ def visualize(image_list, results, labels, output_dir='output/', threshold=0.5): if 'segm' in results: im_results['segm'] = results['segm'][start_idx:start_idx + im_bboxes_num, :] + if 'label' in results: + im_results['label'] = results['label'][start_idx:start_idx + + im_bboxes_num] + if 'score' in results: + im_results['score'] = results['score'][start_idx:start_idx + + im_bboxes_num] + start_idx += im_bboxes_num im = visualize_box_mask( image_file, im_results, labels, threshold=threshold) -- GitLab