From 1b5a1e26d194b0ccf21ae6c3b17d86b720cd3ee5 Mon Sep 17 00:00:00 2001 From: jm12138 <2286040843@qq.com> Date: Fri, 4 Nov 2022 18:56:07 +0800 Subject: [PATCH] update modnet_resnet50vd_matting (#2100) * add requirements.txt * add init * update format --- .../modnet_resnet50vd_matting/README.md | 30 +- .../modnet_resnet50vd_matting/README_en.md | 26 +- .../modnet_resnet50vd_matting/__init__.py | 0 .../modnet_resnet50vd_matting/module.py | 274 +++++++----------- .../modnet_resnet50vd_matting/processor.py | 25 +- .../requirements.txt | 1 + .../modnet_resnet50vd_matting/resnet.py | 174 ++++------- 7 files changed, 217 insertions(+), 313 deletions(-) create mode 100644 modules/image/matting/modnet_resnet50vd_matting/__init__.py create mode 100644 modules/image/matting/modnet_resnet50vd_matting/requirements.txt diff --git a/modules/image/matting/modnet_resnet50vd_matting/README.md b/modules/image/matting/modnet_resnet50vd_matting/README.md index 03ad69e6..65f0659e 100644 --- a/modules/image/matting/modnet_resnet50vd_matting/README.md +++ b/modules/image/matting/modnet_resnet50vd_matting/README.md @@ -1,7 +1,7 @@ # modnet_resnet50vd_matting |模型名称|modnet_resnet50vd_matting| -| :--- | :---: | +| :--- | :---: | |类别|图像-抠图| |网络|modnet_resnet50vd| |数据集|百度自建数据集| @@ -17,8 +17,8 @@ - 样例结果示例(左为原图,右为效果图):

- - + +

- ### 模型介绍 @@ -26,9 +26,9 @@ - Matting(精细化分割/影像去背/抠图)是指借由计算前景的颜色和透明度,将前景从影像中撷取出来的技术,可用于替换背景、影像合成、视觉特效,在电影工业中被广泛地使用。影像中的每个像素会有代表其前景透明度的值,称作阿法值(Alpha),一张影像中所有阿法值的集合称作阿法遮罩(Alpha Matte),将影像被遮罩所涵盖的部分取出即可完成前景的分离。modnet_resnet50vd_matting可生成抠图结果。 - + - 更多详情请参考:[modnet_resnet50vd_matting](https://github.com/PaddlePaddle/PaddleSeg/tree/release/2.3/contrib/Matting) - + ## 二、安装 @@ -46,11 +46,11 @@ - ```shell $ hub install modnet_resnet50vd_matting ``` - + - 如您安装时遇到问题,可参考:[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md) | [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md) - + ## 三、模型API预测 - ### 1、命令行预测 @@ -58,9 +58,9 @@ - ```shell $ hub run modnet_resnet50vd_matting --input_path "/PATH/TO/IMAGE" ``` - + - 通过命令行方式实现hub模型的调用,更多请见 [PaddleHub命令行指令](../../../../docs/docs_ch/tutorial/cmd_usage.rst) - + - ### 2、预测代码示例 @@ -73,14 +73,14 @@ result = model.predict(["/PATH/TO/IMAGE"]) print(result) ``` - + - ### 3、API - ```python - def predict(self, - image_list, - trimap_list, - visualization, + def predict(self, + image_list, + trimap_list, + visualization, save_path): ``` @@ -97,7 +97,7 @@ - result (list(numpy.ndarray)):模型分割结果: - + ## 四、服务部署 - PaddleHub Serving可以部署人像matting在线服务。 diff --git a/modules/image/matting/modnet_resnet50vd_matting/README_en.md b/modules/image/matting/modnet_resnet50vd_matting/README_en.md index 2a6d4e46..65ff51db 100644 --- a/modules/image/matting/modnet_resnet50vd_matting/README_en.md +++ b/modules/image/matting/modnet_resnet50vd_matting/README_en.md @@ -1,7 +1,7 @@ # modnet_resnet50vd_matting |Module Name|modnet_resnet50vd_matting| -| :--- | :---: | +| :--- | :---: | |Category|Image Matting| |Network|modnet_resnet50vd| |Dataset|Baidu self-built dataset| @@ -17,8 +17,8 @@ - Sample results:

- - + +

- ### Module Introduction @@ -26,9 +26,9 @@ - Mating is the technique of extracting foreground from an image by calculating its color and transparency. It is widely used in the film industry to replace background, image composition, and visual effects. Each pixel in the image will have a value that represents its foreground transparency, called Alpha. The set of all Alpha values in an image is called Alpha Matte. The part of the image covered by the mask can be extracted to complete foreground separation. - + - For more information, please refer to: [modnet_resnet50vd_matting](https://github.com/PaddlePaddle/PaddleSeg/tree/release/2.3/contrib/Matting) - + ## II. Installation @@ -46,11 +46,11 @@ - ```shell $ hub install modnet_resnet50vd_matting ``` - + - In case of any problems during installation, please refer to:[Windows_Quickstart](../../../../docs/docs_en/get_start/windows_quickstart.md) | [Linux_Quickstart](../../../../docs/docs_en/get_start/linux_quickstart.md) | [Mac_Quickstart](../../../../docs/docs_en/get_start/mac_quickstart.md) - + ## III. Module API Prediction - ### 1、Command line Prediction @@ -58,7 +58,7 @@ - ```shell $ hub run modnet_resnet50vd_matting --input_path "/PATH/TO/IMAGE" ``` - + - If you want to call the Hub module through the command line, please refer to: [PaddleHub Command Line Instruction](../../../../docs/docs_en/tutorial/cmd_usage.rst) @@ -76,10 +76,10 @@ - ### 3、API - ```python - def predict(self, - image_list, - trimap_list, - visualization, + def predict(self, + image_list, + trimap_list, + visualization, save_path): ``` @@ -96,7 +96,7 @@ - result (list(numpy.ndarray)):The list of model results. - + ## IV. Server Deployment - PaddleHub Serving can deploy an online service of matting. diff --git a/modules/image/matting/modnet_resnet50vd_matting/__init__.py b/modules/image/matting/modnet_resnet50vd_matting/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modules/image/matting/modnet_resnet50vd_matting/module.py b/modules/image/matting/modnet_resnet50vd_matting/module.py index b57c170a..c9f1076f 100644 --- a/modules/image/matting/modnet_resnet50vd_matting/module.py +++ b/modules/image/matting/modnet_resnet50vd_matting/module.py @@ -11,33 +11,32 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import argparse import os import time -import argparse -from typing import Callable, Union, List, Tuple +from typing import Callable +from typing import List +from typing import Union -import numpy as np import cv2 -import scipy +import modnet_resnet50vd_matting.processor as P +import numpy as np import paddle import paddle.nn as nn import paddle.nn.functional as F -from paddlehub.module.module import moduleinfo -import paddlehub.vision.segmentation_transforms as T -from paddlehub.module.module import moduleinfo, runnable, serving - +import scipy from modnet_resnet50vd_matting.resnet import ResNet50_vd -import modnet_resnet50vd_matting.processor as P + +from paddlehub.module.module import moduleinfo +from paddlehub.module.module import runnable +from paddlehub.module.module import serving -@moduleinfo( - name="modnet_resnet50vd_matting", - type="CV/matting", - author="paddlepaddle", - summary="modnet_resnet50vd_matting is a matting model", - version="1.0.0" -) +@moduleinfo(name="modnet_resnet50vd_matting", + type="CV/matting", + author="paddlepaddle", + summary="modnet_resnet50vd_matting is a matting model", + version="1.0.0") class MODNetResNet50Vd(nn.Layer): """ The MODNet implementation based on PaddlePaddle. @@ -51,14 +50,13 @@ class MODNetResNet50Vd(nn.Layer): pretrained(str, optional): The path of pretrianed model. Defautl: None. """ - def __init__(self, hr_channels:int = 32, pretrained=None): + def __init__(self, hr_channels: int = 32, pretrained=None): super(MODNetResNet50Vd, self).__init__() self.backbone = ResNet50_vd() self.pretrained = pretrained - self.head = MODNetHead( - hr_channels=hr_channels, backbone_channels=self.backbone.feat_channels) + self.head = MODNetHead(hr_channels=hr_channels, backbone_channels=self.backbone.feat_channels) self.blurer = GaussianBlurLayer(1, 3) self.transforms = P.Compose([P.LoadImages(), P.ResizeByShort(), P.ResizeToIntMult(), P.Normalize()]) @@ -72,32 +70,36 @@ class MODNetResNet50Vd(nn.Layer): model_dict = paddle.load(checkpoint) self.set_dict(model_dict) print("load pretrained parameters success") - - def preprocess(self, img: Union[str, np.ndarray] , transforms: Callable, trimap: Union[str, np.ndarray] = None): + + def preprocess(self, img: Union[str, np.ndarray], transforms: Callable, trimap: Union[str, np.ndarray] = None): data = {} data['img'] = img if trimap is not None: data['trimap'] = trimap data['gt_fields'] = ['trimap'] data['trans_info'] = [] - data = self.transforms(data) + data = transforms(data) data['img'] = paddle.to_tensor(data['img']) data['img'] = data['img'].unsqueeze(0) if trimap is not None: data['trimap'] = paddle.to_tensor(data['trimap']) data['trimap'] = data['trimap'].unsqueeze((0, 1)) - return data - + return data + def forward(self, inputs: dict): x = inputs['img'] feat_list = self.backbone(x) y = self.head(inputs=inputs, feat_list=feat_list) return y - - def predict(self, image_list: list, trimap_list: list = None, visualization: bool =False, save_path: str = "modnet_resnet50vd_matting_output"): + + def predict(self, + image_list: list, + trimap_list: list = None, + visualization: bool = False, + save_path: str = "modnet_resnet50vd_matting_output"): self.eval() - result= [] + result = [] with paddle.no_grad(): for i, im_path in enumerate(image_list): trimap = trimap_list[i] if trimap_list is not None else None @@ -116,9 +118,9 @@ class MODNetResNet50Vd(nn.Layer): cv2.imwrite(image_save_path, alpha_pred) return result - + @serving - def serving_method(self, images: list, trimaps:list = None, **kwargs): + def serving_method(self, images: list, trimaps: list = None, **kwargs): """ Run as a service. """ @@ -127,8 +129,8 @@ class MODNetResNet50Vd(nn.Layer): trimap_decoder = [cv2.cvtColor(P.base64_to_cv2(trimap), cv2.COLOR_BGR2GRAY) for trimap in trimaps] else: trimap_decoder = None - - outputs = self.predict(image_list=images_decode, trimap_list= trimap_decoder, **kwargs) + + outputs = self.predict(image_list=images_decode, trimap_list=trimap_decoder, **kwargs) serving_data = [P.cv2_to_base64(outputs[i]) for i in range(len(outputs))] results = {'data': serving_data} @@ -139,11 +141,10 @@ class MODNetResNet50Vd(nn.Layer): """ Run as a command. """ - self.parser = argparse.ArgumentParser( - description="Run the {} module.".format(self.name), - prog='hub run {}'.format(self.name), - usage='%(prog)s', - add_help=True) + self.parser = argparse.ArgumentParser(description="Run the {} module.".format(self.name), + prog='hub run {}'.format(self.name), + usage='%(prog)s', + add_help=True) self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required") self.arg_config_group = self.parser.add_argument_group( title="Config options", description="Run configuration for controlling module behavior, not required.") @@ -155,7 +156,10 @@ class MODNetResNet50Vd(nn.Layer): else: trimap_list = None - results = self.predict(image_list=[args.input_path], trimap_list=trimap_list, save_path=args.output_dir, visualization=args.visualization) + results = self.predict(image_list=[args.input_path], + trimap_list=trimap_list, + save_path=args.output_dir, + visualization=args.visualization) return results @@ -164,10 +168,14 @@ class MODNetResNet50Vd(nn.Layer): Add the command config options. """ - self.arg_config_group.add_argument( - '--output_dir', type=str, default="modnet_resnet50vd_matting_output", help="The directory to save output images.") - self.arg_config_group.add_argument( - '--visualization', type=bool, default=True, help="whether to save output as images.") + self.arg_config_group.add_argument('--output_dir', + type=str, + default="modnet_resnet50vd_matting_output", + help="The directory to save output images.") + self.arg_config_group.add_argument('--visualization', + type=bool, + default=True, + help="whether to save output as images.") def add_module_input_arg(self): """ @@ -175,13 +183,13 @@ class MODNetResNet50Vd(nn.Layer): """ self.arg_input_group.add_argument('--input_path', type=str, help="path to image.") self.arg_input_group.add_argument('--trimap_path', type=str, default=None, help="path to trimap.") - - - + + class MODNetHead(nn.Layer): """ Segmentation head. """ + def __init__(self, hr_channels: int, backbone_channels: int): super().__init__() @@ -196,37 +204,24 @@ class MODNetHead(nn.Layer): return pred_matte - class FusionBranch(nn.Layer): + def __init__(self, hr_channels: int, enc_channels: int): super().__init__() - self.conv_lr4x = Conv2dIBNormRelu( - enc_channels[2], hr_channels, 5, stride=1, padding=2) + self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2) - self.conv_f2x = Conv2dIBNormRelu( - 2 * hr_channels, hr_channels, 3, stride=1, padding=1) + self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1) self.conv_f = nn.Sequential( - Conv2dIBNormRelu( - hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1), - Conv2dIBNormRelu( - int(hr_channels / 2), - 1, - 1, - stride=1, - padding=0, - with_ibn=False, - with_relu=False)) + Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1), + Conv2dIBNormRelu(int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False)) def forward(self, img: paddle.Tensor, lr8x: paddle.Tensor, hr2x: paddle.Tensor) -> paddle.Tensor: - lr4x = F.interpolate( - lr8x, scale_factor=2, mode='bilinear', align_corners=False) + lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False) lr4x = self.conv_lr4x(lr4x) - lr2x = F.interpolate( - lr4x, scale_factor=2, mode='bilinear', align_corners=False) + lr2x = F.interpolate(lr4x, scale_factor=2, mode='bilinear', align_corners=False) f2x = self.conv_f2x(paddle.concat((lr2x, hr2x), axis=1)) - f = F.interpolate( - f2x, scale_factor=2, mode='bilinear', align_corners=False) + f = F.interpolate(f2x, scale_factor=2, mode='bilinear', align_corners=False) f = self.conv_f(paddle.concat((f, img), axis=1)) pred_matte = F.sigmoid(f) @@ -238,56 +233,33 @@ class HRBranch(nn.Layer): High Resolution Branch of MODNet """ - def __init__(self, hr_channels: int, enc_channels:int): + def __init__(self, hr_channels: int, enc_channels: int): super().__init__() - self.tohr_enc2x = Conv2dIBNormRelu( - enc_channels[0], hr_channels, 1, stride=1, padding=0) - self.conv_enc2x = Conv2dIBNormRelu( - hr_channels + 3, hr_channels, 3, stride=2, padding=1) + self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0) + self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1) - self.tohr_enc4x = Conv2dIBNormRelu( - enc_channels[1], hr_channels, 1, stride=1, padding=0) - self.conv_enc4x = Conv2dIBNormRelu( - 2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1) + self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0) + self.conv_enc4x = Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1) self.conv_hr4x = nn.Sequential( - Conv2dIBNormRelu( - 2 * hr_channels + enc_channels[2] + 3, - 2 * hr_channels, - 3, - stride=1, - padding=1), - Conv2dIBNormRelu( - 2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1), - Conv2dIBNormRelu( - 2 * hr_channels, hr_channels, 3, stride=1, padding=1)) - - self.conv_hr2x = nn.Sequential( - Conv2dIBNormRelu( - 2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1), - Conv2dIBNormRelu( - 2 * hr_channels, hr_channels, 3, stride=1, padding=1), - Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1), - Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1)) + Conv2dIBNormRelu(2 * hr_channels + enc_channels[2] + 3, 2 * hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1)) + + self.conv_hr2x = nn.Sequential(Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1)) self.conv_hr = nn.Sequential( - Conv2dIBNormRelu( - hr_channels + 3, hr_channels, 3, stride=1, padding=1), - Conv2dIBNormRelu( - hr_channels, - 1, - 1, - stride=1, - padding=0, - with_ibn=False, - with_relu=False)) - - def forward(self, img: paddle.Tensor, enc2x: paddle.Tensor, enc4x: paddle.Tensor, lr8x: paddle.Tensor) -> paddle.Tensor: - img2x = F.interpolate( - img, scale_factor=1 / 2, mode='bilinear', align_corners=False) - img4x = F.interpolate( - img, scale_factor=1 / 4, mode='bilinear', align_corners=False) + Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(hr_channels, 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False)) + + def forward(self, img: paddle.Tensor, enc2x: paddle.Tensor, enc4x: paddle.Tensor, + lr8x: paddle.Tensor) -> paddle.Tensor: + img2x = F.interpolate(img, scale_factor=1 / 2, mode='bilinear', align_corners=False) + img4x = F.interpolate(img, scale_factor=1 / 4, mode='bilinear', align_corners=False) enc2x = self.tohr_enc2x(enc2x) hr4x = self.conv_enc2x(paddle.concat((img2x, enc2x), axis=1)) @@ -295,12 +267,10 @@ class HRBranch(nn.Layer): enc4x = self.tohr_enc4x(enc4x) hr4x = self.conv_enc4x(paddle.concat((hr4x, enc4x), axis=1)) - lr4x = F.interpolate( - lr8x, scale_factor=2, mode='bilinear', align_corners=False) + lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False) hr4x = self.conv_hr4x(paddle.concat((hr4x, lr4x, img4x), axis=1)) - hr2x = F.interpolate( - hr4x, scale_factor=2, mode='bilinear', align_corners=False) + hr2x = F.interpolate(hr4x, scale_factor=2, mode='bilinear', align_corners=False) hr2x = self.conv_hr2x(paddle.concat((hr2x, enc2x), axis=1)) pred_detail = None return pred_detail, hr2x @@ -310,31 +280,27 @@ class LRBranch(nn.Layer): """ Low Resolution Branch of MODNet """ + def __init__(self, backbone_channels: int): super().__init__() self.se_block = SEBlock(backbone_channels[4], reduction=4) - self.conv_lr16x = Conv2dIBNormRelu( - backbone_channels[4], backbone_channels[3], 5, stride=1, padding=2) - self.conv_lr8x = Conv2dIBNormRelu( - backbone_channels[3], backbone_channels[2], 5, stride=1, padding=2) - self.conv_lr = Conv2dIBNormRelu( - backbone_channels[2], - 1, - 3, - stride=2, - padding=1, - with_ibn=False, - with_relu=False) + self.conv_lr16x = Conv2dIBNormRelu(backbone_channels[4], backbone_channels[3], 5, stride=1, padding=2) + self.conv_lr8x = Conv2dIBNormRelu(backbone_channels[3], backbone_channels[2], 5, stride=1, padding=2) + self.conv_lr = Conv2dIBNormRelu(backbone_channels[2], + 1, + 3, + stride=2, + padding=1, + with_ibn=False, + with_relu=False) def forward(self, feat_list: list) -> List[paddle.Tensor]: enc2x, enc4x, enc32x = feat_list[0], feat_list[1], feat_list[4] enc32x = self.se_block(enc32x) - lr16x = F.interpolate( - enc32x, scale_factor=2, mode='bilinear', align_corners=False) + lr16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False) lr16x = self.conv_lr16x(lr16x) - lr8x = F.interpolate( - lr16x, scale_factor=2, mode='bilinear', align_corners=False) + lr8x = F.interpolate(lr16x, scale_factor=2, mode='bilinear', align_corners=False) lr8x = self.conv_lr8x(lr8x) pred_semantic = None @@ -376,7 +342,7 @@ class Conv2dIBNormRelu(nn.Layer): kernel_size: int, stride: int = 1, padding: int = 0, - dilation:int = 1, + dilation: int = 1, groups: int = 1, bias_attr: paddle.ParamAttr = None, with_ibn: bool = True, @@ -385,15 +351,14 @@ class Conv2dIBNormRelu(nn.Layer): super().__init__() layers = [ - nn.Conv2D( - in_channels, - out_channels, - kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - bias_attr=bias_attr) + nn.Conv2D(in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias_attr=bias_attr) ] if with_ibn: @@ -413,20 +378,13 @@ class SEBlock(nn.Layer): SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf """ - def __init__(self, num_channels: int, reduction:int = 1): + def __init__(self, num_channels: int, reduction: int = 1): super().__init__() self.pool = nn.AdaptiveAvgPool2D(1) - self.conv = nn.Sequential( - nn.Conv2D( - num_channels, - int(num_channels // reduction), - 1, - bias_attr=False), nn.ReLU(), - nn.Conv2D( - int(num_channels // reduction), - num_channels, - 1, - bias_attr=False), nn.Sigmoid()) + self.conv = nn.Sequential(nn.Conv2D(num_channels, int(num_channels // reduction), 1, + bias_attr=False), nn.ReLU(), + nn.Conv2D(int(num_channels // reduction), num_channels, 1, bias_attr=False), + nn.Sigmoid()) def forward(self, x: paddle.Tensor) -> paddle.Tensor: w = self.pool(x) @@ -454,14 +412,7 @@ class GaussianBlurLayer(nn.Layer): self.op = nn.Sequential( nn.Pad2D(int(self.kernel_size / 2), mode='reflect'), - nn.Conv2D( - channels, - channels, - self.kernel_size, - stride=1, - padding=0, - bias_attr=False, - groups=channels)) + nn.Conv2D(channels, channels, self.kernel_size, stride=1, padding=0, bias_attr=False, groups=channels)) self._init_kernel() self.op[1].weight.stop_gradient = True @@ -479,8 +430,7 @@ class GaussianBlurLayer(nn.Layer): exit() elif not x.shape[1] == self.channels: print('In \'GaussianBlurLayer\', the required channel ({0}) is' - 'not the same as input ({1})\n'.format( - self.channels, x.shape[1])) + 'not the same as input ({1})\n'.format(self.channels, x.shape[1])) exit() return self.op(x) @@ -494,4 +444,4 @@ class GaussianBlurLayer(nn.Layer): kernel = scipy.ndimage.gaussian_filter(n, sigma) kernel = kernel.astype('float32') kernel = kernel[np.newaxis, np.newaxis, :, :] - paddle.assign(kernel, self.op[1].weight) \ No newline at end of file + paddle.assign(kernel, self.op[1].weight) diff --git a/modules/image/matting/modnet_resnet50vd_matting/processor.py b/modules/image/matting/modnet_resnet50vd_matting/processor.py index 3ae79593..10246b6c 100644 --- a/modules/image/matting/modnet_resnet50vd_matting/processor.py +++ b/modules/image/matting/modnet_resnet50vd_matting/processor.py @@ -11,17 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import random import base64 -from typing import Callable, Union, List, Tuple +from typing import Callable +from typing import List +from typing import Tuple +from typing import Union import cv2 import numpy as np import paddle import paddle.nn.functional as F from paddleseg.transforms import functional -from PIL import Image class Compose: @@ -61,6 +61,7 @@ class LoadImages: Args: to_rgb (bool, optional): If converting image to RGB color space. Default: True. """ + def __init__(self, to_rgb: bool = True): self.to_rgb = to_rgb @@ -95,7 +96,7 @@ class ResizeByShort: short_size (int): The target size of short side. """ - def __init__(self, short_size: int =512): + def __init__(self, short_size: int = 512): self.short_size = short_size def __call__(self, data: dict) -> dict: @@ -140,14 +141,13 @@ class Normalize: ValueError: When mean/std is not list or any value in std is 0. """ - def __init__(self, mean: Union[List[float], Tuple[float]] = (0.5, 0.5, 0.5), std: Union[List[float], Tuple[float]] = (0.5, 0.5, 0.5)): + def __init__(self, + mean: Union[List[float], Tuple[float]] = (0.5, 0.5, 0.5), + std: Union[List[float], Tuple[float]] = (0.5, 0.5, 0.5)): self.mean = mean self.std = std - if not (isinstance(self.mean, (list, tuple)) - and isinstance(self.std, (list, tuple))): - raise ValueError( - "{}: input type is invalid. It should be list or tuple".format( - self)) + if not (isinstance(self.mean, (list, tuple)) and isinstance(self.std, (list, tuple))): + raise ValueError("{}: input type is invalid. It should be list or tuple".format(self)) from functools import reduce if reduce(lambda x, y: x * y, self.std) == 0: raise ValueError('{}: std is invalid!'.format(self)) @@ -177,6 +177,7 @@ def reverse_transform(alpha: paddle.Tensor, trans_info: List[str]): raise Exception("Unexpected info '{}' in im_info".format(item[0])) return alpha + def save_alpha_pred(alpha: np.ndarray, trimap: np.ndarray = None): """ The value of alpha is range [0, 1], shape should be [h,w] @@ -204,4 +205,4 @@ def base64_to_cv2(b64str: str): data = base64.b64decode(b64str.encode('utf8')) data = np.fromstring(data, np.uint8) data = cv2.imdecode(data, cv2.IMREAD_COLOR) - return data \ No newline at end of file + return data diff --git a/modules/image/matting/modnet_resnet50vd_matting/requirements.txt b/modules/image/matting/modnet_resnet50vd_matting/requirements.txt new file mode 100644 index 00000000..f1870c21 --- /dev/null +++ b/modules/image/matting/modnet_resnet50vd_matting/requirements.txt @@ -0,0 +1 @@ +paddleseg>=2.3.0 diff --git a/modules/image/matting/modnet_resnet50vd_matting/resnet.py b/modules/image/matting/modnet_resnet50vd_matting/resnet.py index 19abe41c..8faf67bd 100644 --- a/modules/image/matting/modnet_resnet50vd_matting/resnet.py +++ b/modules/image/matting/modnet_resnet50vd_matting/resnet.py @@ -11,45 +11,40 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import paddle import paddle.nn as nn import paddle.nn.functional as F - from paddleseg.models import layers -from paddleseg.utils import utils __all__ = ["ResNet50_vd"] class ConvBNLayer(nn.Layer): """Basic conv bn relu layer.""" - + def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - stride: int = 1, - dilation: int = 1, - groups: int = 1, - is_vd_mode: bool = False, - act: str = None, + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + groups: int = 1, + is_vd_mode: bool = False, + act: str = None, ): super(ConvBNLayer, self).__init__() self.is_vd_mode = is_vd_mode - self._pool2d_avg = nn.AvgPool2D( - kernel_size=2, stride=2, padding=0, ceil_mode=True) - self._conv = nn.Conv2D( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=(kernel_size - 1) // 2 if dilation == 1 else 0, - dilation=dilation, - groups=groups, - bias_attr=False) + self._pool2d_avg = nn.AvgPool2D(kernel_size=2, stride=2, padding=0, ceil_mode=True) + self._conv = nn.Conv2D(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size - 1) // 2 if dilation == 1 else 0, + dilation=dilation, + groups=groups, + bias_attr=False) self._batch_norm = layers.SyncBatchNorm(out_channels) self._act_op = layers.Activation(act=act) @@ -66,7 +61,7 @@ class ConvBNLayer(nn.Layer): class BottleneckBlock(nn.Layer): """Residual bottleneck block""" - + def __init__(self, in_channels: int, out_channels: int, @@ -76,34 +71,24 @@ class BottleneckBlock(nn.Layer): dilation: int = 1): super(BottleneckBlock, self).__init__() - self.conv0 = ConvBNLayer( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=1, - act='relu') + self.conv0 = ConvBNLayer(in_channels=in_channels, out_channels=out_channels, kernel_size=1, act='relu') self.dilation = dilation - self.conv1 = ConvBNLayer( - in_channels=out_channels, - out_channels=out_channels, - kernel_size=3, - stride=stride, - act='relu', - dilation=dilation) - self.conv2 = ConvBNLayer( - in_channels=out_channels, - out_channels=out_channels * 4, - kernel_size=1, - act=None) + self.conv1 = ConvBNLayer(in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + act='relu', + dilation=dilation) + self.conv2 = ConvBNLayer(in_channels=out_channels, out_channels=out_channels * 4, kernel_size=1, act=None) if not shortcut: - self.short = ConvBNLayer( - in_channels=in_channels, - out_channels=out_channels * 4, - kernel_size=1, - stride=1, - is_vd_mode=False if if_first or stride == 1 else True) + self.short = ConvBNLayer(in_channels=in_channels, + out_channels=out_channels * 4, + kernel_size=1, + stride=1, + is_vd_mode=False if if_first or stride == 1 else True) self.shortcut = shortcut @@ -133,33 +118,23 @@ class BottleneckBlock(nn.Layer): class BasicBlock(nn.Layer): """Basic residual block""" - def __init__(self, - in_channels: int, - out_channels: int, - stride: int, - shortcut: bool = True, - if_first: bool = False): + + def __init__(self, in_channels: int, out_channels: int, stride: int, shortcut: bool = True, if_first: bool = False): super(BasicBlock, self).__init__() self.stride = stride - self.conv0 = ConvBNLayer( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - stride=stride, - act='relu') - self.conv1 = ConvBNLayer( - in_channels=out_channels, - out_channels=out_channels, - kernel_size=3, - act=None) + self.conv0 = ConvBNLayer(in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + act='relu') + self.conv1 = ConvBNLayer(in_channels=out_channels, out_channels=out_channels, kernel_size=3, act=None) if not shortcut: - self.short = ConvBNLayer( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=1, - stride=1, - is_vd_mode=False if if_first else True) + self.short = ConvBNLayer(in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + is_vd_mode=False if if_first else True) self.shortcut = shortcut @@ -212,13 +187,11 @@ class ResNet_vd(nn.Layer): depth = [3, 8, 36, 3] elif layers == 200: depth = [3, 12, 48, 3] - num_channels = [64, 256, 512, 1024 - ] if layers >= 50 else [64, 64, 128, 256] + num_channels = [64, 256, 512, 1024] if layers >= 50 else [64, 64, 128, 256] num_filters = [64, 128, 256, 512] # for channels of four returned stages - self.feat_channels = [c * 4 for c in num_filters - ] if layers >= 50 else num_filters + self.feat_channels = [c * 4 for c in num_filters] if layers >= 50 else num_filters self.feat_channels = [64] + self.feat_channels dilation_dict = None @@ -227,24 +200,9 @@ class ResNet_vd(nn.Layer): elif output_stride == 16: dilation_dict = {3: 2} - self.conv1_1 = ConvBNLayer( - in_channels=input_channels, - out_channels=32, - kernel_size=3, - stride=2, - act='relu') - self.conv1_2 = ConvBNLayer( - in_channels=32, - out_channels=32, - kernel_size=3, - stride=1, - act='relu') - self.conv1_3 = ConvBNLayer( - in_channels=32, - out_channels=64, - kernel_size=3, - stride=1, - act='relu') + self.conv1_1 = ConvBNLayer(in_channels=input_channels, out_channels=32, kernel_size=3, stride=2, act='relu') + self.conv1_2 = ConvBNLayer(in_channels=32, out_channels=32, kernel_size=3, stride=1, act='relu') + self.conv1_3 = ConvBNLayer(in_channels=32, out_channels=64, kernel_size=3, stride=1, act='relu') self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1) # self.block_list = [] @@ -264,8 +222,7 @@ class ResNet_vd(nn.Layer): ############################################################################### # Add dilation rate for some segmentation tasks, if dilation_dict is not None. - dilation_rate = dilation_dict[ - block] if dilation_dict and block in dilation_dict else 1 + dilation_rate = dilation_dict[block] if dilation_dict and block in dilation_dict else 1 # Actually block here is 'stage', and i is 'block' in 'stage' # At the stage 4, expand the the dilation_rate if given multi_grid @@ -275,15 +232,12 @@ class ResNet_vd(nn.Layer): bottleneck_block = self.add_sublayer( 'bb_%d_%d' % (block, i), - BottleneckBlock( - in_channels=num_channels[block] - if i == 0 else num_filters[block] * 4, - out_channels=num_filters[block], - stride=2 if i == 0 and block != 0 - and dilation_rate == 1 else 1, - shortcut=shortcut, - if_first=block == i == 0, - dilation=dilation_rate)) + BottleneckBlock(in_channels=num_channels[block] if i == 0 else num_filters[block] * 4, + out_channels=num_filters[block], + stride=2 if i == 0 and block != 0 and dilation_rate == 1 else 1, + shortcut=shortcut, + if_first=block == i == 0, + dilation=dilation_rate)) block_list.append(bottleneck_block) shortcut = True @@ -296,13 +250,11 @@ class ResNet_vd(nn.Layer): conv_name = "res" + str(block + 2) + chr(97 + i) basic_block = self.add_sublayer( 'bb_%d_%d' % (block, i), - BasicBlock( - in_channels=num_channels[block] - if i == 0 else num_filters[block], - out_channels=num_filters[block], - stride=2 if i == 0 and block != 0 else 1, - shortcut=shortcut, - if_first=block == i == 0)) + BasicBlock(in_channels=num_channels[block] if i == 0 else num_filters[block], + out_channels=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + shortcut=shortcut, + if_first=block == i == 0)) block_list.append(basic_block) shortcut = True self.stage_list.append(block_list) -- GitLab