未验证 提交 6b74c762 编写于 作者: W wangguanzhong 提交者: GitHub

fix solov2 deploy (#3200)

上级 75e3def5
...@@ -242,17 +242,20 @@ class DetectorSOLOv2(Detector): ...@@ -242,17 +242,20 @@ class DetectorSOLOv2(Detector):
for i in range(warmup): for i in range(warmup):
self.predictor.run() self.predictor.run()
output_names = self.predictor.get_output_names() output_names = self.predictor.get_output_names()
np_boxes_num = self.predictor.get_output_handle(output_names[
0]).copy_to_cpu()
np_label = self.predictor.get_output_handle(output_names[ np_label = self.predictor.get_output_handle(output_names[
1]).copy_to_cpu() 1]).copy_to_cpu()
np_score = self.predictor.get_output_handle(output_names[ np_score = self.predictor.get_output_handle(output_names[
2]).copy_to_cpu() 2]).copy_to_cpu()
np_segms = self.predictor.get_output_handle(output_names[ np_segms = self.predictor.get_output_handle(output_names[
3]).copy_to_cpu() 3]).copy_to_cpu()
self.det_times.inference_time_s.start() self.det_times.inference_time_s.start()
for i in range(repeats): for i in range(repeats):
self.predictor.run() self.predictor.run()
output_names = self.predictor.get_output_names() output_names = self.predictor.get_output_names()
np_boxes_num = self.predictor.get_output_handle(output_names[
0]).copy_to_cpu()
np_label = self.predictor.get_output_handle(output_names[ np_label = self.predictor.get_output_handle(output_names[
1]).copy_to_cpu() 1]).copy_to_cpu()
np_score = self.predictor.get_output_handle(output_names[ np_score = self.predictor.get_output_handle(output_names[
...@@ -262,7 +265,11 @@ class DetectorSOLOv2(Detector): ...@@ -262,7 +265,11 @@ class DetectorSOLOv2(Detector):
self.det_times.inference_time_s.end(repeats=repeats) self.det_times.inference_time_s.end(repeats=repeats)
self.det_times.img_num += 1 self.det_times.img_num += 1
return dict(segm=np_segms, label=np_label, score=np_score) return dict(
segm=np_segms,
label=np_label,
score=np_score,
boxes_num=np_boxes_num)
def create_inputs(imgs, im_info): def create_inputs(imgs, im_info):
...@@ -481,6 +488,13 @@ def visualize(image_list, results, labels, output_dir='output/', threshold=0.5): ...@@ -481,6 +488,13 @@ def visualize(image_list, results, labels, output_dir='output/', threshold=0.5):
if 'segm' in results: if 'segm' in results:
im_results['segm'] = results['segm'][start_idx:start_idx + im_results['segm'] = results['segm'][start_idx:start_idx +
im_bboxes_num, :] im_bboxes_num, :]
if 'label' in results:
im_results['label'] = results['label'][start_idx:start_idx +
im_bboxes_num]
if 'score' in results:
im_results['score'] = results['score'][start_idx:start_idx +
im_bboxes_num]
start_idx += im_bboxes_num start_idx += im_bboxes_num
im = visualize_box_mask( im = visualize_box_mask(
image_file, im_results, labels, threshold=threshold) image_file, im_results, labels, threshold=threshold)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册