From efd2b7d75dd17ad235f3dd70e27131d1ee9e4a84 Mon Sep 17 00:00:00 2001 From: Steffy-zxf Date: Tue, 23 Jun 2020 09:39:00 +0800 Subject: [PATCH] fix deeplabv3p_xception65_humanseg postprocess bug --- .../deeplabv3p_xception65_humanseg/module.py | 10 +++++++++- .../deeplabv3p_xception65_humanseg/processor.py | 5 +++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/hub_module/modules/image/semantic_segmentation/deeplabv3p_xception65_humanseg/module.py b/hub_module/modules/image/semantic_segmentation/deeplabv3p_xception65_humanseg/module.py index b6f0f227..f2d59772 100644 --- a/hub_module/modules/image/semantic_segmentation/deeplabv3p_xception65_humanseg/module.py +++ b/hub_module/modules/image/semantic_segmentation/deeplabv3p_xception65_humanseg/module.py @@ -22,7 +22,7 @@ from deeplabv3p_xception65_humanseg.data_feed import reader author="baidu-vis", author_email="", summary="DeepLabv3+ is a semantic segmentation model.", - version="1.1.0") + version="1.1.1") class DeeplabV3pXception65HumanSeg(hub.Module): def _initialize(self): self.default_pretrained_model_path = os.path.join( @@ -220,3 +220,11 @@ class DeeplabV3pXception65HumanSeg(hub.Module): """ self.arg_input_group.add_argument( '--input_path', type=str, help="path to image.") + + +if __name__ == "__main__": + m = DeeplabV3pXception65HumanSeg() + import cv2 + img = cv2.imread('./meditation.jpg') + res = m.segmentation(images=[img]) + print(res[0]['data']) diff --git a/hub_module/modules/image/semantic_segmentation/deeplabv3p_xception65_humanseg/processor.py b/hub_module/modules/image/semantic_segmentation/deeplabv3p_xception65_humanseg/processor.py index 21566ae7..ce070412 100644 --- a/hub_module/modules/image/semantic_segmentation/deeplabv3p_xception65_humanseg/processor.py +++ b/hub_module/modules/image/semantic_segmentation/deeplabv3p_xception65_humanseg/processor.py @@ -52,8 +52,9 @@ def postprocess(data_out, for logit in data_out: logit = logit[1] * 255 logit = cv2.resize(logit, (org_im_shape[1], org_im_shape[0])) - ret, logit = cv2.threshold(logit, thresh, 0, cv2.THRESH_TOZERO) - logit = 255 * (logit - thresh) / (255 - thresh) + logit -= thresh + logit[logit < 0] = 0 + logit = 255 * logit / (255 - thresh) rgba = np.concatenate((org_im, np.expand_dims(logit, axis=2)), axis=2) if visualization: -- GitLab