From 8dc0677136fd84c2420765d1a18c475df9d87e89 Mon Sep 17 00:00:00 2001
From: jm12138 <2286040843@qq.com>
Date: Mon, 14 Dec 2020 17:44:08 +0800
Subject: [PATCH] Update the U2Net and U2Netp

---
 .../semantic_segmentation/U2Net/README.md     |  3 ++-
 .../semantic_segmentation/U2Net/processor.py  | 25 +++++++++--------
 .../semantic_segmentation/U2Netp/README.md    |  3 ++-
 .../semantic_segmentation/U2Netp/processor.py | 27 +++++++++----------
 4 files changed, 29 insertions(+), 29 deletions(-)

diff --git a/modules/image/semantic_segmentation/U2Net/README.md b/modules/image/semantic_segmentation/U2Net/README.md
index dd9123c0..0bd5cd94 100644
--- a/modules/image/semantic_segmentation/U2Net/README.md
+++ b/modules/image/semantic_segmentation/U2Net/README.md
@@ -6,7 +6,8 @@
 
 ## 效果展示
 ![](https://ai-studio-static-online.cdn.bcebos.com/4d77bc3a05cf48bba6f67b797978f4cdf10f38288b9645d59393dd85cef58eff)
-![](https://ai-studio-static-online.cdn.bcebos.com/d7839c7207024747b32e42e49f7881bd2554d8408ab44e669fb340b50d4e38de)
+![](https://ai-studio-static-online.cdn.bcebos.com/865b7b6a262b4ce3bbba4a5c0d973d9eea428bc3e8af4f76a1cdab0c04e3dd33)
+![](https://ai-studio-static-online.cdn.bcebos.com/11c9eba8de6d4316b672f10b285245061821f0a744e441f3b80c223881256ca0)
 
 ## API
 ```python
diff --git a/modules/image/semantic_segmentation/U2Net/processor.py b/modules/image/semantic_segmentation/U2Net/processor.py
index 30891e0a..a6cef1ae 100644
--- a/modules/image/semantic_segmentation/U2Net/processor.py
+++ b/modules/image/semantic_segmentation/U2Net/processor.py
@@ -33,13 +33,6 @@ class Processor():
     def preprocess(self, imgs, batch_size=1, input_size=320):
         input_datas = []
         for image in imgs:
-            # h, w = image.shape[:2]
-            # if h > w:
-            #     new_h, new_w = input_size*h/w,input_size
-            # else:
-            #     new_h, new_w = input_size,input_size*w/h
-            # image = cv2.resize(image, (int(new_w), int(new_h)))
-
             image = cv2.resize(image, (input_size, input_size))
             tmpImg = np.zeros((image.shape[0],image.shape[1],3))
             image = image/np.max(image)
@@ -78,19 +71,25 @@ class Processor():
 
         results = []
 
         for i, image in enumerate(self.imgs):
             # normalization
-            pred = 1.0 - outputs[i,0,:,:]
+            pred = outputs[i,0,:,:]
             pred = self.normPRED(pred)
 
             # convert torch tensor to numpy array
-            pred = pred.squeeze()
-            pred = (pred*255).astype(np.uint8)
             h, w = image.shape[:2]
-            pred = cv2.resize(pred, (w, h))
+            mask = cv2.resize(pred, (w, h))
 
-            results.append(pred)
+            output_img = (image*mask[..., np.newaxis] + (1-mask[..., np.newaxis])*255).astype(np.uint8)
+            mask = (mask*255).astype(np.uint8)
+
             if visualization:
-                cv2.imwrite(os.path.join(output_dir, 'result_%d.png' % i), pred)
+                cv2.imwrite(os.path.join(output_dir, 'result_mask_%d.png' % i), mask)
+                cv2.imwrite(os.path.join(output_dir, 'result_%d.png' % i), output_img)
 
+            results.append({
+                'mask': mask,
+                'front': output_img
+            })
+
         return results
\ No newline at end of file
diff --git a/modules/image/semantic_segmentation/U2Netp/README.md b/modules/image/semantic_segmentation/U2Netp/README.md
index bed12a13..c0a9be70 100644
--- a/modules/image/semantic_segmentation/U2Netp/README.md
+++ b/modules/image/semantic_segmentation/U2Netp/README.md
@@ -7,7 +7,8 @@
 
 ## 效果展示
 ![](https://ai-studio-static-online.cdn.bcebos.com/4d77bc3a05cf48bba6f67b797978f4cdf10f38288b9645d59393dd85cef58eff)
-![](https://ai-studio-static-online.cdn.bcebos.com/d7839c7207024747b32e42e49f7881bd2554d8408ab44e669fb340b50d4e38de)
+![](https://ai-studio-static-online.cdn.bcebos.com/865b7b6a262b4ce3bbba4a5c0d973d9eea428bc3e8af4f76a1cdab0c04e3dd33)
+![](https://ai-studio-static-online.cdn.bcebos.com/11c9eba8de6d4316b672f10b285245061821f0a744e441f3b80c223881256ca0)
 
 ## API
 ```python
diff --git a/modules/image/semantic_segmentation/U2Netp/processor.py b/modules/image/semantic_segmentation/U2Netp/processor.py
index d3f11fca..a6cef1ae 100644
--- a/modules/image/semantic_segmentation/U2Netp/processor.py
+++ b/modules/image/semantic_segmentation/U2Netp/processor.py
@@ -33,13 +33,6 @@ class Processor():
     def preprocess(self, imgs, batch_size=1, input_size=320):
         input_datas = []
         for image in imgs:
-            # h, w = image.shape[:2]
-            # if h > w:
-            #     new_h, new_w = input_size*h/w,input_size
-            # else:
-            #     new_h, new_w = input_size,input_size*w/h
-            # image = cv2.resize(image, (int(new_w), int(new_h)))
-
             image = cv2.resize(image, (input_size, input_size))
             tmpImg = np.zeros((image.shape[0],image.shape[1],3))
             image = image/np.max(image)
@@ -78,19 +71,25 @@ class Processor():
 
         results = []
 
         for i, image in enumerate(self.imgs):
             # normalization
-            pred = 1.0 - outputs[i,0,:,:]
+            pred = outputs[i,0,:,:]
             pred = self.normPRED(pred)
 
             # convert torch tensor to numpy array
-            pred = pred.squeeze()
-            pred = (pred*255).astype(np.uint8)
             h, w = image.shape[:2]
-            pred = cv2.resize(pred, (w, h))
-
-            results.append(pred)
+            mask = cv2.resize(pred, (w, h))
+
+            output_img = (image*mask[..., np.newaxis] + (1-mask[..., np.newaxis])*255).astype(np.uint8)
+            mask = (mask*255).astype(np.uint8)
+
             if visualization:
-                cv2.imwrite(os.path.join(output_dir, 'result_%d.png' % i), pred)
+                cv2.imwrite(os.path.join(output_dir, 'result_mask_%d.png' % i), mask)
+                cv2.imwrite(os.path.join(output_dir, 'result_%d.png' % i), output_img)
 
+            results.append({
+                'mask': mask,
+                'front': output_img
+            })
+
         return results
\ No newline at end of file
--
GitLab
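
Note for reviewers: the sketch below mirrors what the patched `postprocess` now does per image in both modules: normalize the raw saliency map, resize it back to the source resolution, and composite the foreground over a white background. It is a minimal standalone version, not module code — the function names `norm_pred` and `postprocess_one` are illustrative, and `norm_pred` assumes the module's `normPRED` helper is the usual min-max rescaling.

```python
import cv2
import numpy as np

def norm_pred(pred):
    # Min-max rescale the raw saliency map to [0, 1]
    # (assumed equivalent to the module's normPRED helper).
    return (pred - pred.min()) / (pred.max() - pred.min())

def postprocess_one(pred, image):
    # pred:  float saliency map from the network, e.g. shape (320, 320)
    # image: original BGR image, shape (h, w, 3), dtype uint8
    pred = norm_pred(pred)

    # Resize the mask back to the original image size;
    # note cv2.resize takes (width, height).
    h, w = image.shape[:2]
    mask = cv2.resize(pred, (w, h))

    # Composite the foreground over a white background, per pixel:
    # front = image * mask + 255 * (1 - mask)
    alpha = mask[..., np.newaxis]
    front = (image * alpha + (1 - alpha) * 255).astype(np.uint8)

    # Scale the mask to 8-bit, as the module saves and returns it.
    mask = (mask * 255).astype(np.uint8)
    return {'mask': mask, 'front': front}
```

Before this patch each module returned only the 8-bit mask; afterwards `postprocess` returns a dict per image with both the `mask` and the white-background `front` composite, matching the two files written under `output_dir` when `visualization=True`.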