From 1930a23507e96c1e68b5b3c657cd21b765c7ee95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=99=BD=E9=B9=AD=E5=85=88=E7=94=9F?= <766529835@qq.com> Date: Wed, 8 Feb 2023 12:07:23 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=98=BE=E7=A4=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- utils/callbacks.py | 4 ++-- yolo.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/utils/callbacks.py b/utils/callbacks.py index a703585..b866781 100644 --- a/utils/callbacks.py +++ b/utils/callbacks.py @@ -148,9 +148,9 @@ class EvalCallback(): top_label = np.array(results[0][:, 7], dtype = 'int32') top_conf = results[0][:, 5] * results[0][:, 6] top_rboxes = results[0][:, :5] - top_rboxes[:, [0, 2]] *= image_shape[1] - top_rboxes[:, [1, 3]] *= image_shape[0] top_polys = rbox2poly(top_rboxes) + top_polys[:, 0::2] *= image_shape[1] + top_polys[:, 1::2] *= image_shape[0] top_hbbs = poly2hbb(top_polys) top_100 = np.argsort(top_conf)[::-1][:self.max_boxes] top_hbbs = top_hbbs[top_100] diff --git a/yolo.py b/yolo.py index 28bb9dd..d538f58 100644 --- a/yolo.py +++ b/yolo.py @@ -157,9 +157,9 @@ class YOLO(object): top_label = np.array(results[0][:, 7], dtype = 'int32') top_conf = results[0][:, 5] * results[0][:, 6] top_rboxes = results[0][:, :5] - top_rboxes[:, [0, 2]] *= image_shape[1] - top_rboxes[:, [1, 3]] *= image_shape[0] top_polys = rbox2poly(top_rboxes) + top_polys[:, 0::2] *= image_shape[1] + top_polys[:, 1::2] *= image_shape[0] #---------------------------------------------------------# # 设置字体与边框厚度 #---------------------------------------------------------# @@ -378,9 +378,9 @@ class YOLO(object): top_label = np.array(results[0][:, 7], dtype = 'int32') top_conf = results[0][:, 5] * results[0][:, 6] top_rboxes = results[0][:, :5] - top_rboxes[:, [0, 2]] *= image_shape[1] - top_rboxes[:, [1, 3]] *= image_shape[0] top_polys = rbox2poly(top_rboxes) + top_polys[:, 0::2] *= image_shape[1] + top_polys[:, 1::2] *= image_shape[0] top_hbbs = poly2hbb(top_polys) for i, c in list(enumerate(top_label)): predicted_class = self.class_names[int(c)] -- GitLab