From 88cf22f21a3e1fdab06fe8617065ca23b14308fa Mon Sep 17 00:00:00 2001 From: FlyingQianMM <245467267@qq.com> Date: Tue, 28 Jul 2020 13:31:10 +0000 Subject: [PATCH] fix bug in segmentation evaluation --- paddlex/cv/models/deeplabv3p.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/paddlex/cv/models/deeplabv3p.py b/paddlex/cv/models/deeplabv3p.py index d5afd46..c3c6ff8 100644 --- a/paddlex/cv/models/deeplabv3p.py +++ b/paddlex/cv/models/deeplabv3p.py @@ -360,18 +360,19 @@ class DeepLabv3p(BaseAPI): pred = pred[0:num_samples] for i in range(num_samples): - one_pred = pred[i].astype('uint8') + one_pred = np.squeeze(pred[i]).astype('uint8') one_label = labels[i] for info in im_info[i][::-1]: if info[0] == 'resize': w, h = info[1][1], info[1][0] - one_pred = cv2.resize(one_pred, (w, h), cv2.INTER_NEAREST) + one_pred = cv2.resize(one_pred, (w, h), + cv2.INTER_NEAREST) elif info[0] == 'padding': w, h = info[1][1], info[1][0] one_pred = one_pred[0:h, 0:w] else: - raise Exception("Unexpected info '{}' in im_info".format( - info[0])) + raise Exception( + "Unexpected info '{}' in im_info".format(info[0])) one_pred = one_pred.astype('int64') one_pred = one_pred[np.newaxis, :, :, np.newaxis] one_label = one_label[np.newaxis, np.newaxis, :, :] -- GitLab