From 47895a938004a84fa794de918cc8b63cc4de49e8 Mon Sep 17 00:00:00 2001
From: jm12138 <2286040843@qq.com>
Date: Tue, 22 Dec 2020 19:24:10 +0800
Subject: [PATCH] Add some modules from the Photo2Cartoon project (#1117)

* Add some modules from Photo2Cartoon

Add some modules from Photo2Cartoon

* Add a face seg module

Add a face seg module

* Update README.md
---
 .../style_transfer/ID_Photo_GEN/README.md    |  48 ++
 .../style_transfer/ID_Photo_GEN/module.py    | 197 +++++
 .../style_transfer/Photo2Cartoon/README.md   |  55 ++
 .../Photo2Cartoon/model/__init__.py          |   2 +
 .../Photo2Cartoon/model/networks.py          | 365 ++++++++
 .../style_transfer/Photo2Cartoon/module.py   | 231 +++++
 .../FCN_HRNet_W18_Face_Seg/README.md         |  49 ++
 .../FCN_HRNet_W18_Face_Seg/model/__init__.py |   2 +
 .../FCN_HRNet_W18_Face_Seg/model/fcn.py      | 118 +++
 .../FCN_HRNet_W18_Face_Seg/model/hrnet.py    | 797 ++++++++++++++++++
 .../FCN_HRNet_W18_Face_Seg/model/layers.py   |  65 ++
 .../FCN_HRNet_W18_Face_Seg/module.py         | 156 ++++
 12 files changed, 2085 insertions(+)
 create mode 100644 modules/image/Image_gan/style_transfer/ID_Photo_GEN/README.md
 create mode 100644 modules/image/Image_gan/style_transfer/ID_Photo_GEN/module.py
 create mode 100644 modules/image/Image_gan/style_transfer/Photo2Cartoon/README.md
 create mode 100644 modules/image/Image_gan/style_transfer/Photo2Cartoon/model/__init__.py
 create mode 100644 modules/image/Image_gan/style_transfer/Photo2Cartoon/model/networks.py
 create mode 100644 modules/image/Image_gan/style_transfer/Photo2Cartoon/module.py
 create mode 100644 modules/image/semantic_segmentation/FCN_HRNet_W18_Face_Seg/README.md
 create mode 100644 modules/image/semantic_segmentation/FCN_HRNet_W18_Face_Seg/model/__init__.py
 create mode 100644 modules/image/semantic_segmentation/FCN_HRNet_W18_Face_Seg/model/fcn.py
 create mode 100644 modules/image/semantic_segmentation/FCN_HRNet_W18_Face_Seg/model/hrnet.py
 create mode 100644 modules/image/semantic_segmentation/FCN_HRNet_W18_Face_Seg/model/layers.py
 create mode 100644 modules/image/semantic_segmentation/FCN_HRNet_W18_Face_Seg/module.py

diff --git a/modules/image/Image_gan/style_transfer/ID_Photo_GEN/README.md b/modules/image/Image_gan/style_transfer/ID_Photo_GEN/README.md
new file mode 100644
index 00000000..aec6e100
--- /dev/null
+++ b/modules/image/Image_gan/style_transfer/ID_Photo_GEN/README.md
@@ -0,0 +1,48 @@
+## Overview
+* An ID-photo generator built on the face_landmark_localization and FCN_HRNet_W18_Face_Seg models; it produces portrait photos with white, red, and blue backgrounds in one pass.
+
+## Results
+![](https://ai-studio-static-online.cdn.bcebos.com/cb82156ad0d44938896c3679534cc50e8e9c170cdd1b4ab8bf34304f69229fda)
+
+## API
+```python
+def Photo_GEN(
+        images=None,
+        paths=None,
+        batch_size=1,
+        output_dir='output',
+        visualization=False,
+        use_gpu=False):
+```
+ID photo generation API.
+
+**Parameters**
+* images (list[np.ndarray]): list of input image arrays (BGR)
+* paths (list[str]): list of input image paths
+* batch_size (int): batch size
+* output_dir (str): directory for saving visualized results
+* visualization (bool): whether to save visualized results
+* use_gpu (bool): whether to run inference on GPU
+
+**Returns**
+* results (list[dict{"white":np.ndarray,"blue":np.ndarray,"red":np.ndarray}]): list of output image dicts
+
+**Code Example**
+```python
+import cv2
+import paddlehub as hub
+
+model = hub.Module(name='ID_Photo_GEN')
+
+result = model.Photo_GEN(
+    images=[cv2.imread('/PATH/TO/IMAGE')],
+    paths=None,
+    batch_size=1,
+    output_dir='output',
+    visualization=True,
+    use_gpu=False)
+```
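+
+Each dict in the returned list holds one generated photo per background color as a BGR array. A minimal follow-up sketch, assuming the call above detected at least one face:
+```python
+for color in ('white', 'blue', 'red'):
+    # result[0][color] is a BGR np.ndarray; cast to uint8 before saving
+    cv2.imwrite('id_photo_%s.jpg' % color, result[0][color].astype('uint8'))
+```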
+
+## Dependencies
+paddlepaddle >= 2.0.0rc0
+paddlehub >= 2.0.0b1
diff --git a/modules/image/Image_gan/style_transfer/ID_Photo_GEN/module.py b/modules/image/Image_gan/style_transfer/ID_Photo_GEN/module.py
new file mode 100644
index 00000000..74535b4a
--- /dev/null
+++ b/modules/image/Image_gan/style_transfer/ID_Photo_GEN/module.py
@@ -0,0 +1,197 @@
+import os
+import cv2
+import math
+import paddle
+import numpy as np
+import paddle.nn as nn
+import paddlehub as hub
+from paddlehub.module.module import moduleinfo
+
+@moduleinfo(
+    name="ID_Photo_GEN",  # module name
+    type="CV",  # module type
+    author="jm12138",  # author
+    author_email="jm12138@qq.com",  # author email
+    summary="ID_Photo_GEN",  # module summary
+    version="1.0.0"  # version
+)
+class ID_Photo_GEN(nn.Layer):
+    def __init__(self):
+        super(ID_Photo_GEN, self).__init__()
+        # Load the face landmark detection model
+        self.face_detector = hub.Module(name="face_landmark_localization")
+
+        # Load the face segmentation model
+        self.seg = hub.Module(name='FCN_HRNet_W18_Face_Seg')
+
+    # Data loading helper
+    @staticmethod
+    def load_datas(paths, images):
+        datas = []
+
+        # Read images from the path list
+        if paths is not None:
+            for im_path in paths:
+                assert os.path.isfile(im_path), "The {} isn't a valid file path.".format(im_path)
+                im = cv2.imread(im_path)
+                datas.append(im)
+
+        if images is not None:
+            datas = images
+
+        # Return the data list
+        return datas
+
+    # Preprocessing
+    def preprocess(self, images, batch_size, use_gpu):
+        # Detect face landmarks
+        outputs = self.face_detector.keypoint_detection(
+            images=images,
+            batch_size=batch_size,
+            use_gpu=use_gpu)
+
+        crops = []
+        for output, image in zip(outputs, images):
+            for landmarks in output['data']:
+                landmarks = np.array(landmarks)
+
+                # rotation angle
+                left_eye_corner = landmarks[36]
+                right_eye_corner = landmarks[45]
+                radian = np.arctan((left_eye_corner[1] - right_eye_corner[1]) / (left_eye_corner[0] - right_eye_corner[0]))
+
+                # image size after rotating
+                height, width, _ = image.shape
+                cos = math.cos(radian)
+                sin = math.sin(radian)
+                new_w = int(width * abs(cos) + height * abs(sin))
+                new_h = int(width * abs(sin) + height * abs(cos))
+
+                # translation
+                Tx = new_w // 2 - width // 2
+                Ty = new_h // 2 - height // 2
+
+                # affine matrix
+                M = np.array([[cos, sin, (1 - cos) * width / 2. - sin * height / 2. + Tx],
+                              [-sin, cos, sin * width / 2. + (1 - cos) * height / 2. + Ty]])
+
+                image = cv2.warpAffine(image, M, (new_w, new_h), borderValue=(255, 255, 255))
+
+                landmarks = np.concatenate([landmarks, np.ones((landmarks.shape[0], 1))], axis=1)
+                landmarks = np.dot(M, landmarks.T).T
+                landmarks_top = np.min(landmarks[:, 1])
+                landmarks_bottom = np.max(landmarks[:, 1])
+                landmarks_left = np.min(landmarks[:, 0])
+                landmarks_right = np.max(landmarks[:, 0])
+
+                # expand bbox
+                top = int(landmarks_top - 0.8 * (landmarks_bottom - landmarks_top))
+                bottom = int(landmarks_bottom + 0.3 * (landmarks_bottom - landmarks_top))
+                left = int(landmarks_left - 0.3 * (landmarks_right - landmarks_left))
+                right = int(landmarks_right + 0.3 * (landmarks_right - landmarks_left))
+
+                # crop to a square
+                if bottom - top > right - left:
+                    left -= ((bottom - top) - (right - left)) // 2
+                    right = left + (bottom - top)
+                else:
+                    top -= ((right - left) - (bottom - top)) // 2
+                    bottom = top + (right - left)
+
+                image_crop = np.ones((bottom - top + 1, right - left + 1, 3), np.uint8) * 255
+
+                h, w = image.shape[:2]
+                left_white = max(0, -left)
+                left = max(0, left)
+                right = min(right, w - 1)
+                right_white = left_white + (right - left)
+                top_white = max(0, -top)
+                top = max(0, top)
+                bottom = min(bottom, h - 1)
+                bottom_white = top_white + (bottom - top)
+
+                image_crop[top_white:bottom_white + 1, left_white:right_white + 1] = image[top:bottom + 1, left:right + 1].copy()
+                crops.append(image_crop)
+
+        # Run face segmentation on the crops
+        results = self.seg.Segmentation(images=crops, batch_size=batch_size)
+
+        faces = []
+        masks = []
+
+        for result in results:
+            # Collect the mask and the segmented face image
+            face = result['face']
+            mask = result['mask']
+
+            faces.append(face)
+            masks.append(mask)
+
+        return faces, masks
+
+    # Postprocessing
+    @staticmethod
+    def postprocess(faces, masks, visualization, output_dir):
+        # Create the output directory if needed
+        if visualization:
+            if not os.path.exists(output_dir):
+                os.mkdir(output_dir)
+
+        results = []
+
+        for face, mask, i in zip(faces, masks, range(len(masks))):
+            mask = mask[..., np.newaxis] / 255
+            white = face * mask + (1 - mask) * 255
+            blue = face * mask + (1 - mask) * [255, 0, 0]
+            red = face * mask + (1 - mask) * [0, 0, 255]
+
+            # Save visualized results
+            if visualization:
+                cv2.imwrite(os.path.join(output_dir, 'white_%d.jpg' % i), white)
+                cv2.imwrite(os.path.join(output_dir, 'blue_%d.jpg' % i), blue)
+                cv2.imwrite(os.path.join(output_dir, 'red_%d.jpg' % i), red)
+
+            results.append({
+                'white': white,
+                'blue': blue,
+                'red': red
+            })
+
+        return results
+
+    def Photo_GEN(
+            self,
+            images=None,
+            paths=None,
+            batch_size=1,
+            output_dir='output',
+            visualization=False,
+            use_gpu=False):
+
+        # Load the input data
+        images = self.load_datas(paths, images)
+
+        # Preprocess: align, crop, and segment the faces
+        faces, masks = self.preprocess(images, batch_size, use_gpu)
+
+        # Postprocess: composite the background colors
+        results = self.postprocess(faces, masks, visualization, output_dir)
+
+        return results
\ No newline at end of file
diff --git a/modules/image/Image_gan/style_transfer/Photo2Cartoon/README.md b/modules/image/Image_gan/style_transfer/Photo2Cartoon/README.md
new file mode 100644
index 00000000..75401e7e
--- /dev/null
+++ b/modules/image/Image_gan/style_transfer/Photo2Cartoon/README.md
@@ -0,0 +1,55 @@
+## Overview
+* This module wraps the [PaddlePaddle version of Minivision's photo2cartoon project](https://github.com/minivision-ai/photo2cartoon-paddle).
+* The goal of portrait cartoonization is to convert a real photo into a cartoon-style, non-photorealistic image while preserving the identity and texture details of the original. Our approach is to learn the photo-to-cartoon mapping from a large set of photo/cartoon data. In general, the pix2pix method, which relies on paired data, achieves good translation quality, but in this task the input and output contours do not correspond one-to-one (for example, cartoon eyes are larger and chins slimmer), and paired data is difficult and costly to draw, so we use an unpaired image translation method instead. Architecturally, we build on U-GAT-IT and add two hourglass modules before the encoder and two after the decoder, progressively strengthening the model's feature abstraction and reconstruction abilities. Because training data is relatively scarce, we also reduce the learning difficulty by normalizing the inputs into a fixed layout: first detect the face and its landmarks in the image, rotate the image upright according to the landmarks and crop it to a standard size, then feed the cropped head into a portrait segmentation model (trained with the PaddleSeg framework) to remove the background.
+
+![](https://ai-studio-static-online.cdn.bcebos.com/8eff9a95bd6741beb3895f38eca39265f22c358c7d114c11b400bbbcd9c4cfc0)
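+
+The rotation-correction step described above can be sketched standalone. The snippet below mirrors the affine-matrix construction in this module's `preprocess`; the eye-corner coordinates and image size are hypothetical values for illustration:
+```python
+import math
+import numpy as np
+
+# Hypothetical outer eye corners (x, y) from the 68-point landmark set
+# (index 36: left eye outer corner, index 45: right eye outer corner).
+left_eye_corner = np.array([120., 152.])
+right_eye_corner = np.array([200., 146.])
+height, width = 384, 384  # assumed input image size
+
+# Angle of the eye line; warping with M below levels the eyes.
+radian = np.arctan((left_eye_corner[1] - right_eye_corner[1]) /
+                   (left_eye_corner[0] - right_eye_corner[0]))
+cos, sin = math.cos(radian), math.sin(radian)
+
+# Canvas size that fits the rotated image, plus a centering translation.
+new_w = int(width * abs(cos) + height * abs(sin))
+new_h = int(width * abs(sin) + height * abs(cos))
+Tx, Ty = new_w // 2 - width // 2, new_h // 2 - height // 2
+
+M = np.array([[cos, sin, (1 - cos) * width / 2. - sin * height / 2. + Tx],
+              [-sin, cos, sin * width / 2. + (1 - cos) * height / 2. + Ty]])
+# cv2.warpAffine(image, M, (new_w, new_h)) would then apply the correction.
+```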
+
+## Results
+![](https://ai-studio-static-online.cdn.bcebos.com/a4aaedc5ede449e282f0a1c1df05566b62737ddec98246a9b2d5cfeb0f005563)
+
+## API
+```python
+def Cartoon_GEN(
+        images=None,
+        paths=None,
+        batch_size=1,
+        output_dir='output',
+        visualization=False,
+        use_gpu=False):
+```
+Portrait cartoonization API.
+
+**Parameters**
+* images (list[np.ndarray]): list of input image arrays (BGR)
+* paths (list[str]): list of input image paths
+* batch_size (int): batch size
+* output_dir (str): directory for saving visualized results
+* visualization (bool): whether to save visualized results
+* use_gpu (bool): whether to run inference on GPU
+
+**Returns**
+* results (list[np.ndarray]): list of output image arrays
+
+**Code Example**
+```python
+import cv2
+import paddlehub as hub
+
+model = hub.Module(name='Photo2Cartoon')
+
+result = model.Cartoon_GEN(
+    images=[cv2.imread('/PATH/TO/IMAGE')],
+    paths=None,
+    batch_size=1,
+    output_dir='output',
+    visualization=True,
+    use_gpu=False)
+```
+
+## Source Code
+https://github.com/PaddlePaddle/PaddleSeg
+https://github.com/minivision-ai/photo2cartoon-paddle
+
+## Dependencies
+paddlepaddle >= 2.0.0rc0
+paddlehub >= 2.0.0b1
diff --git a/modules/image/Image_gan/style_transfer/Photo2Cartoon/model/__init__.py b/modules/image/Image_gan/style_transfer/Photo2Cartoon/model/__init__.py
new file mode 100644
index 00000000..04de5522
--- /dev/null
+++ b/modules/image/Image_gan/style_transfer/Photo2Cartoon/model/__init__.py
@@ -0,0 +1,2 @@
+from .networks import ResnetGenerator
+
diff --git a/modules/image/Image_gan/style_transfer/Photo2Cartoon/model/networks.py b/modules/image/Image_gan/style_transfer/Photo2Cartoon/model/networks.py
new file mode 100644
index 00000000..e6c73fcc
--- /dev/null
+++ b/modules/image/Image_gan/style_transfer/Photo2Cartoon/model/networks.py
@@ -0,0 +1,365 @@
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class ResnetGenerator(nn.Layer):
+    def __init__(self, ngf=32, img_size=256, n_blocks=4, light=True):
+        super(ResnetGenerator, self).__init__()
+        self.light = light
+        self.n_blocks = n_blocks
+
+        DownBlock = []
+        DownBlock += [
+            nn.Pad2D([3, 3, 3, 3], 'reflect'),
+            nn.Conv2D(3, ngf, kernel_size=7, stride=1, bias_attr=False),
+            nn.InstanceNorm2D(ngf, weight_attr=False, bias_attr=False),
+            nn.ReLU()
+        ]
+
+        DownBlock += [
+            HourGlass(ngf, ngf),
+            HourGlass(ngf, ngf)
+        ]
+
+        # Down-Sampling
+        n_downsampling = 2
+        for i in range(n_downsampling):
+            mult = 2 ** i
+            DownBlock += [
+                nn.Pad2D([1, 1, 1, 1], 'reflect'),
+                nn.Conv2D(ngf*mult, ngf*mult*2, kernel_size=3, stride=2, bias_attr=False),
+                nn.InstanceNorm2D(ngf*mult*2, weight_attr=False, bias_attr=False),
+                nn.ReLU()
+            ]
+
+        # Encoder Bottleneck
+        mult = 2 ** n_downsampling
+        for i in range(n_blocks):
+            setattr(self, 'EncodeBlock'+str(i+1), ResnetBlock(ngf*mult))
+
+        # Class Activation Map
+        self.gap_fc = nn.Linear(ngf*mult, 1, bias_attr=False)
+        self.gmp_fc = nn.Linear(ngf*mult, 1, bias_attr=False)
+        self.conv1x1 = nn.Conv2D(ngf*mult*2, ngf*mult, kernel_size=1, stride=1)
+        self.relu = nn.ReLU()
+
+        # Gamma, Beta block
+        FC = []
+        if self.light:
+            FC += [
+                nn.Linear(ngf*mult, ngf*mult, bias_attr=False),
+                nn.ReLU(),
+                nn.Linear(ngf*mult,
ngf*mult, bias_attr=False), + nn.ReLU() + ] + + else: + FC += [ + nn.Linear(img_size//mult*img_size//mult*ngf*mult, ngf*mult, bias_attr=False), + nn.ReLU(), + nn.Linear(ngf*mult, ngf*mult, bias_attr=False), + nn.ReLU() + ] + + # Decoder Bottleneck + mult = 2 ** n_downsampling + for i in range(n_blocks): + setattr(self, 'DecodeBlock'+str(i + 1), ResnetSoftAdaLINBlock(ngf*mult)) + + # Up-Sampling + UpBlock = [] + for i in range(n_downsampling): + mult = 2 ** (n_downsampling - i) + UpBlock += [ + nn.Upsample(scale_factor=2), + nn.Pad2D([1, 1, 1, 1], 'reflect'), + nn.Conv2D(ngf*mult, ngf*mult//2, kernel_size=3, stride=1, bias_attr=False), + LIN(ngf*mult//2), + nn.ReLU() + ] + + UpBlock += [ + HourGlass(ngf, ngf), + HourGlass(ngf, ngf, False) + ] + + UpBlock += [ + nn.Pad2D([3, 3, 3, 3], 'reflect'), + nn.Conv2D(3, 3, kernel_size=7, stride=1, bias_attr=False), + nn.Tanh() + ] + + self.DownBlock = nn.Sequential(*DownBlock) + self.FC = nn.Sequential(*FC) + self.UpBlock = nn.Sequential(*UpBlock) + + def forward(self, x): + bs = x.shape[0] + + x = self.DownBlock(x) + + content_features = [] + for i in range(self.n_blocks): + x = getattr(self, 'EncodeBlock'+str(i+1))(x) + content_features.append(F.adaptive_avg_pool2d(x, 1).reshape([bs, -1])) + + gap = F.adaptive_avg_pool2d(x, 1) + gap_logit = self.gap_fc(gap.reshape([bs, -1])) + gap_weight = list(self.gap_fc.parameters())[0].transpose([1, 0]) + gap = x * gap_weight.unsqueeze(2).unsqueeze(3) + + gmp = F.adaptive_max_pool2d(x, 1) + gmp_logit = self.gmp_fc(gmp.reshape([bs, -1])) + gmp_weight = list(self.gmp_fc.parameters())[0].transpose([1, 0]) + gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3) + + cam_logit = paddle.concat([gap_logit, gmp_logit], 1) + x = paddle.concat([gap, gmp], 1) + x = self.relu(self.conv1x1(x)) + + heatmap = paddle.sum(x, axis=1, keepdim=True) + + if self.light: + x_ = F.adaptive_avg_pool2d(x, 1) + style_features = self.FC(x_.reshape([bs, -1])) + else: + style_features = self.FC(x.reshape([bs, -1])) + + for i in range(self.n_blocks): + x = getattr(self, 'DecodeBlock'+str(i+1))(x, content_features[4-i-1], style_features) + + out = self.UpBlock(x) + + return out, cam_logit, heatmap + + +class ConvBlock(nn.Layer): + def __init__(self, dim_in, dim_out): + super(ConvBlock, self).__init__() + self.dim_in = dim_in + self.dim_out = dim_out + + self.conv_block1 = self.__convblock(dim_in, dim_out//2) + self.conv_block2 = self.__convblock(dim_out//2, dim_out//4) + self.conv_block3 = self.__convblock(dim_out//4, dim_out//4) + + if self.dim_in != self.dim_out: + self.conv_skip = nn.Sequential( + nn.InstanceNorm2D(dim_in, weight_attr=False, bias_attr=False), + nn.ReLU(), + nn.Conv2D(dim_in, dim_out, kernel_size=1, stride=1, bias_attr=False) + ) + + @staticmethod + def __convblock(dim_in, dim_out): + return nn.Sequential( + nn.InstanceNorm2D(dim_in, weight_attr=False, bias_attr=False), + nn.ReLU(), + nn.Pad2D([1, 1, 1, 1], 'reflect'), + nn.Conv2D(dim_in, dim_out, kernel_size=3, stride=1, bias_attr=False) + ) + + def forward(self, x): + residual = x + + x1 = self.conv_block1(x) + x2 = self.conv_block2(x1) + x3 = self.conv_block3(x2) + out = paddle.concat([x1, x2, x3], 1) + + if self.dim_in != self.dim_out: + residual = self.conv_skip(residual) + + return residual + out + + +class HourGlassBlock(nn.Layer): + def __init__(self, dim_in): + super(HourGlassBlock, self).__init__() + + self.n_skip = 4 + self.n_block = 9 + + for i in range(self.n_skip): + setattr(self, 'ConvBlockskip'+str(i+1), ConvBlock(dim_in, dim_in)) + + for i in range(self.n_block): + 
setattr(self, 'ConvBlock'+str(i+1), ConvBlock(dim_in, dim_in)) + + def forward(self, x): + skips = [] + for i in range(self.n_skip): + skips.append(getattr(self, 'ConvBlockskip'+str(i+1))(x)) + x = F.avg_pool2d(x, 2) + x = getattr(self, 'ConvBlock'+str(i+1))(x) + + x = self.ConvBlock5(x) + + for i in range(self.n_skip): + x = getattr(self, 'ConvBlock'+str(i+6))(x) + x = F.upsample(x, scale_factor=2) + x = skips[self.n_skip-i-1] + x + + return x + + +class HourGlass(nn.Layer): + def __init__(self, dim_in, dim_out, use_res=True): + super(HourGlass, self).__init__() + self.use_res = use_res + + self.HG = nn.Sequential( + HourGlassBlock(dim_in), + ConvBlock(dim_out, dim_out), + nn.Conv2D(dim_out, dim_out, kernel_size=1, stride=1, bias_attr=False), + nn.InstanceNorm2D(dim_out, weight_attr=False, bias_attr=False), + nn.ReLU() + ) + + self.Conv1 = nn.Conv2D(dim_out, 3, kernel_size=1, stride=1) + + if self.use_res: + self.Conv2 = nn.Conv2D(dim_out, dim_out, kernel_size=1, stride=1) + self.Conv3 = nn.Conv2D(3, dim_out, kernel_size=1, stride=1) + + def forward(self, x): + ll = self.HG(x) + tmp_out = self.Conv1(ll) + + if self.use_res: + ll = self.Conv2(ll) + tmp_out_ = self.Conv3(tmp_out) + return x + ll + tmp_out_ + + else: + return tmp_out + + +class ResnetBlock(nn.Layer): + def __init__(self, dim, use_bias=False): + super(ResnetBlock, self).__init__() + conv_block = [] + conv_block += [ + nn.Pad2D([1, 1, 1, 1], 'reflect'), + nn.Conv2D(dim, dim, kernel_size=3, stride=1, bias_attr=use_bias), + nn.InstanceNorm2D(dim, weight_attr=False, bias_attr=False), + nn.ReLU() + ] + + conv_block += [ + nn.Pad2D([1, 1, 1, 1], 'reflect'), + nn.Conv2D(dim, dim, kernel_size=3, stride=1, bias_attr=use_bias), + nn.InstanceNorm2D(dim, weight_attr=False, bias_attr=False) + ] + + self.conv_block = nn.Sequential(*conv_block) + + def forward(self, x): + out = x + self.conv_block(x) + return out + + +class ResnetSoftAdaLINBlock(nn.Layer): + def __init__(self, dim, use_bias=False): + super(ResnetSoftAdaLINBlock, self).__init__() + self.pad1 = nn.Pad2D([1, 1, 1, 1], 'reflect') + self.conv1 = nn.Conv2D(dim, dim, kernel_size=3, stride=1, bias_attr=use_bias) + self.norm1 = SoftAdaLIN(dim) + self.relu1 = nn.ReLU() + + self.pad2 = nn.Pad2D([1, 1, 1, 1], 'reflect') + self.conv2 = nn.Conv2D(dim, dim, kernel_size=3, stride=1, bias_attr=use_bias) + self.norm2 = SoftAdaLIN(dim) + + def forward(self, x, content_features, style_features): + out = self.pad1(x) + out = self.conv1(out) + out = self.norm1(out, content_features, style_features) + out = self.relu1(out) + + out = self.pad2(out) + out = self.conv2(out) + out = self.norm2(out, content_features, style_features) + return out + x + + +class SoftAdaLIN(nn.Layer): + def __init__(self, num_features, eps=1e-5): + super(SoftAdaLIN, self).__init__() + self.norm = AdaLIN(num_features, eps) + + self.w_gamma = self.create_parameter([1, num_features], default_initializer=nn.initializer.Constant(0.)) + self.w_beta = self.create_parameter([1, num_features], default_initializer=nn.initializer.Constant(0.)) + + self.c_gamma = nn.Sequential(nn.Linear(num_features, num_features, bias_attr=False), + nn.ReLU(), + nn.Linear(num_features, num_features, bias_attr=False)) + self.c_beta = nn.Sequential(nn.Linear(num_features, num_features, bias_attr=False), + nn.ReLU(), + nn.Linear(num_features, num_features, bias_attr=False)) + self.s_gamma = nn.Linear(num_features, num_features, bias_attr=False) + self.s_beta = nn.Linear(num_features, num_features, bias_attr=False) + + def forward(self, x, 
content_features, style_features):
+        content_gamma, content_beta = self.c_gamma(content_features), self.c_beta(content_features)
+        style_gamma, style_beta = self.s_gamma(style_features), self.s_beta(style_features)
+
+        w_gamma_, w_beta_ = self.w_gamma.expand([x.shape[0], -1]), self.w_beta.expand([x.shape[0], -1])
+        soft_gamma = (1. - w_gamma_) * style_gamma + w_gamma_ * content_gamma
+        soft_beta = (1. - w_beta_) * style_beta + w_beta_ * content_beta
+
+        out = self.norm(x, soft_gamma, soft_beta)
+        return out
+
+
+class AdaLIN(nn.Layer):
+    def __init__(self, num_features, eps=1e-5):
+        super(AdaLIN, self).__init__()
+        self.eps = eps
+        self.rho = self.create_parameter([1, num_features, 1, 1], default_initializer=nn.initializer.Constant(0.9))
+
+    def forward(self, x, gamma, beta):
+        in_mean, in_var = paddle.mean(x, axis=[2, 3], keepdim=True), paddle.var(x, axis=[2, 3], keepdim=True)
+        out_in = (x - in_mean) / paddle.sqrt(in_var + self.eps)
+        ln_mean, ln_var = paddle.mean(x, axis=[1, 2, 3], keepdim=True), paddle.var(x, axis=[1, 2, 3], keepdim=True)
+        out_ln = (x - ln_mean) / paddle.sqrt(ln_var + self.eps)
+        out = self.rho.expand([x.shape[0], -1, -1, -1]) * out_in + \
+              (1-self.rho.expand([x.shape[0], -1, -1, -1])) * out_ln
+        out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)
+
+        return out
+
+
+class LIN(nn.Layer):
+    def __init__(self, num_features, eps=1e-5):
+        super(LIN, self).__init__()
+        self.eps = eps
+        self.rho = self.create_parameter([1, num_features, 1, 1], default_initializer=nn.initializer.Constant(0.))
+        self.gamma = self.create_parameter([1, num_features, 1, 1], default_initializer=nn.initializer.Constant(1.))
+        self.beta = self.create_parameter([1, num_features, 1, 1], default_initializer=nn.initializer.Constant(0.))
+
+    def forward(self, x):
+        in_mean, in_var = paddle.mean(x, axis=[2, 3], keepdim=True), paddle.var(x, axis=[2, 3], keepdim=True)
+        out_in = (x - in_mean) / paddle.sqrt(in_var + self.eps)
+        ln_mean, ln_var = paddle.mean(x, axis=[1, 2, 3], keepdim=True), paddle.var(x, axis=[1, 2, 3], keepdim=True)
+        out_ln = (x - ln_mean) / paddle.sqrt(ln_var + self.eps)
+        out = self.rho.expand([x.shape[0], -1, -1, -1]) * out_in + \
+              (1-self.rho.expand([x.shape[0], -1, -1, -1])) * out_ln
+        out = out * self.gamma.expand([x.shape[0], -1, -1, -1]) + self.beta.expand([x.shape[0], -1, -1, -1])
+
+        return out
+
+
+if __name__ == '__main__':
+    g = ResnetGenerator(ngf=32, img_size=256, light=True)
+    out, cam_logit, heatmap = g(paddle.ones([4, 3, 256, 256]))
+    print(out.shape, cam_logit.shape, heatmap.shape)
diff --git a/modules/image/Image_gan/style_transfer/Photo2Cartoon/module.py b/modules/image/Image_gan/style_transfer/Photo2Cartoon/module.py
new file mode 100644
index 00000000..feabf1ca
--- /dev/null
+++ b/modules/image/Image_gan/style_transfer/Photo2Cartoon/module.py
@@ -0,0 +1,231 @@
+import os
+import cv2
+import math
+import paddle
+import numpy as np
+import paddle.nn as nn
+import paddlehub as hub
+from Photo2Cartoon.model import ResnetGenerator
+from paddlehub.module.module import moduleinfo
+
+@moduleinfo(
+    name="Photo2Cartoon",  # module name
+    type="CV",  # module type
+    author="jm12138",  # author
+    author_email="jm12138@qq.com",  # author email
+    summary="Photo2Cartoon",  # module summary
+    version="1.0.0"  # version
+)
+class Photo2Cartoon(nn.Layer):
+    def __init__(self):
+        super(Photo2Cartoon, self).__init__()
+        # Load the face landmark detection model
+        self.face_detector = hub.Module(name="face_landmark_localization")
+
+        # Load the face segmentation model
+        self.seg = hub.Module(name='FCN_HRNet_W18_Face_Seg')
+
+        # Build the photo-to-cartoon generator
+        self.net = ResnetGenerator(ngf=32, img_size=256, light=True)
+
+        # Load the generator weights
+        state_dict = paddle.load(os.path.join(self.directory, 'photo2cartoon_weights.pdparams'))
+        self.net.set_state_dict(state_dict['genA2B'])
+
+        # Switch the generator to evaluation mode
+        self.net.eval()
+
+    # Data loading helper
+    @staticmethod
+    def load_datas(paths, images):
+        datas = []
+
+        # Read images from the path list
+        if paths is not None:
+            for im_path in paths:
+                assert os.path.isfile(im_path), "The {} isn't a valid file path.".format(im_path)
+                im = cv2.imread(im_path)
+                datas.append(im)
+
+        if images is not None:
+            datas = images
+
+        # Return the data list
+        return datas
+
+    # Preprocessing
+    def preprocess(self, images, batch_size, use_gpu):
+        # Detect face landmarks
+        outputs = self.face_detector.keypoint_detection(
+            images=images,
+            batch_size=batch_size,
+            use_gpu=use_gpu)
+
+        crops = []
+        for output, image in zip(outputs, images):
+            for landmarks in output['data']:
+                landmarks = np.array(landmarks)
+
+                # rotation angle
+                left_eye_corner = landmarks[36]
+                right_eye_corner = landmarks[45]
+                radian = np.arctan((left_eye_corner[1] - right_eye_corner[1]) / (left_eye_corner[0] - right_eye_corner[0]))
+
+                # image size after rotating
+                height, width, _ = image.shape
+                cos = math.cos(radian)
+                sin = math.sin(radian)
+                new_w = int(width * abs(cos) + height * abs(sin))
+                new_h = int(width * abs(sin) + height * abs(cos))
+
+                # translation
+                Tx = new_w // 2 - width // 2
+                Ty = new_h // 2 - height // 2
+
+                # affine matrix
+                M = np.array([[cos, sin, (1 - cos) * width / 2. - sin * height / 2. + Tx],
+                              [-sin, cos, sin * width / 2. + (1 - cos) * height / 2. + Ty]])
+
+                image = cv2.warpAffine(image, M, (new_w, new_h), borderValue=(255, 255, 255))
+
+                landmarks = np.concatenate([landmarks, np.ones((landmarks.shape[0], 1))], axis=1)
+                landmarks = np.dot(M, landmarks.T).T
+                landmarks_top = np.min(landmarks[:, 1])
+                landmarks_bottom = np.max(landmarks[:, 1])
+                landmarks_left = np.min(landmarks[:, 0])
+                landmarks_right = np.max(landmarks[:, 0])
+
+                # expand bbox
+                top = int(landmarks_top - 0.8 * (landmarks_bottom - landmarks_top))
+                bottom = int(landmarks_bottom + 0.3 * (landmarks_bottom - landmarks_top))
+                left = int(landmarks_left - 0.3 * (landmarks_right - landmarks_left))
+                right = int(landmarks_right + 0.3 * (landmarks_right - landmarks_left))
+
+                # crop to a square
+                if bottom - top > right - left:
+                    left -= ((bottom - top) - (right - left)) // 2
+                    right = left + (bottom - top)
+                else:
+                    top -= ((right - left) - (bottom - top)) // 2
+                    bottom = top + (right - left)
+
+                image_crop = np.ones((bottom - top + 1, right - left + 1, 3), np.uint8) * 255
+
+                h, w = image.shape[:2]
+                left_white = max(0, -left)
+                left = max(0, left)
+                right = min(right, w - 1)
+                right_white = left_white + (right - left)
+                top_white = max(0, -top)
+                top = max(0, top)
+                bottom = min(bottom, h - 1)
+                bottom_white = top_white + (bottom - top)
+
+                image_crop[top_white:bottom_white + 1, left_white:right_white + 1] = image[top:bottom + 1, left:right + 1].copy()
+                crops.append(image_crop)
+
+        # Run face segmentation on the crops
+        results = self.seg.Segmentation(images=crops, batch_size=batch_size)
+
+        faces = []
+        masks = []
+
+        for result in results:
+            # Take the mask and the segmented face image
+            face = result['face']
+            mask = result['mask']
+
+            # Convert BGR to RGB
+            face = cv2.cvtColor(face, cv2.COLOR_BGR2RGB)
+
+            # Attach the mask as an alpha channel
+            face_rgba = np.dstack((face, mask))
+
+            # Resize to the generator input size
+            face_rgba = cv2.resize(face_rgba, (256, 256), interpolation=cv2.INTER_AREA)
+
+            # Split the channels again
+            face = face_rgba[:, :, :3].copy()
+            mask = face_rgba[:, :, 3][:, :, np.newaxis].copy() / 255.
+
+            # To NCHW float32
+            face = np.transpose(face[np.newaxis, :, :, :], (0, 3, 1, 2)).astype(np.float32)
+
+            faces.append(face)
+            masks.append(mask)
+
+        input_datas = np.concatenate(faces, 0)
+
+        # Split into batches
+        datas_num = input_datas.shape[0]
+        split_num = datas_num//batch_size+1 if datas_num%batch_size!=0 else datas_num//batch_size
+        input_datas = np.array_split(input_datas, split_num)
+
+        return input_datas, masks
+
+    # Inference
+    def predict(self, input_datas):
+        outputs = []
+
+        for data in input_datas:
+            # Convert the batch to a Tensor
+            data = paddle.to_tensor(data)
+
+            # Forward pass
+            cartoon = self.net(data)
+
+            outputs.append(cartoon[0].numpy())
+
+        outputs = np.concatenate(outputs, 0)
+
+        return outputs
+
+    # Postprocessing
+    @staticmethod
+    def postprocess(outputs, masks, visualization, output_dir):
+        # Create the output directory if needed
+        if visualization:
+            if not os.path.exists(output_dir):
+                os.mkdir(output_dir)
+
+        cartoons = []
+
+        for cartoon, mask, i in zip(outputs, masks, range(len(masks))):
+            # CHW -> HWC, then map from [-1, 1] back to [0, 255]
+            cartoon = np.transpose(cartoon, (1, 2, 0))
+            cartoon = (cartoon + 1) * 127.5
+
+            # Composite the cartoon onto a white background
+            cartoon = (cartoon * mask + 255 * (1 - mask)).astype(np.uint8)
+            cartoon = cv2.cvtColor(cartoon, cv2.COLOR_RGB2BGR)
+
+            # Save visualized results
+            if visualization:
+                cv2.imwrite(os.path.join(output_dir, 'result_%d.png' % i), cartoon)
+
+            cartoons.append(cartoon)
+
+        return cartoons
+
+    def Cartoon_GEN(
+            self,
+            images=None,
+            paths=None,
+            batch_size=1,
+            output_dir='output',
+            visualization=False,
+            use_gpu=False):
+
+        # Load the input data
+        images = self.load_datas(paths, images)
+
+        # Preprocess: align, crop, and segment the faces
+        input_datas, masks = self.preprocess(images, batch_size, use_gpu)
+
+        # Run the generator
+        outputs = self.predict(input_datas)
+
+        # Postprocess: composite and save the cartoons
+        cartoons = self.postprocess(outputs, masks, visualization, output_dir)
+
+        return cartoons
\ No newline at end of file
diff --git a/modules/image/semantic_segmentation/FCN_HRNet_W18_Face_Seg/README.md b/modules/image/semantic_segmentation/FCN_HRNet_W18_Face_Seg/README.md
new file mode 100644
index 00000000..cac1fd7a
--- /dev/null
+++ b/modules/image/semantic_segmentation/FCN_HRNet_W18_Face_Seg/README.md
@@ -0,0 +1,49 @@
+## Overview
+* A face segmentation model built on FCN_HRNet_W18.
+
+## Results
+![](https://ai-studio-static-online.cdn.bcebos.com/88155299a7534f1084f8467a4d6db7871dc4729627d3471c9129d316dc4ff9bc)
+
+## API
+```python
+def Segmentation(
+        images=None,
+        paths=None,
+        batch_size=1,
+        output_dir='output',
+        visualization=False):
+```
+Face segmentation API.
+
+**Parameters**
+* images (list[np.ndarray]): list of input image arrays (BGR)
+* paths (list[str]): list of input image paths
+* batch_size (int): batch size
+* output_dir (str): directory for saving visualized results
+* visualization (bool): whether to save visualized results
+
+**Returns**
+* results (list[dict{"mask":np.ndarray,"face":np.ndarray}]): list of output dicts
+
+**Code Example**
+```python
+import cv2
+import paddlehub as hub
+
+model = hub.Module(name='FCN_HRNet_W18_Face_Seg')
+
+result = model.Segmentation(
+    images=[cv2.imread('/PATH/TO/IMAGE')],
+    paths=None,
+    batch_size=1,
+    output_dir='output',
+    visualization=True)
+```
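+
+The returned mask can also drive custom compositing. A small sketch, assuming the call above returned at least one result, that places the segmented face on a plain gray background:
+```python
+import numpy as np
+
+mask = result[0]['mask']  # uint8 in [0, 255]
+face = result[0]['face']  # BGR face image on a white background
+
+alpha = (mask / 255.)[..., np.newaxis]
+gray = face * alpha + (1 - alpha) * 128  # composite onto mid-gray
+cv2.imwrite('face_on_gray.png', gray.astype(np.uint8))
+```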
+
+## Source Code
+https://github.com/PaddlePaddle/PaddleSeg
+https://github.com/minivision-ai/photo2cartoon-paddle
+
+## Dependencies
+paddlepaddle >= 2.0.0rc0
+paddlehub >= 2.0.0b1
diff --git a/modules/image/semantic_segmentation/FCN_HRNet_W18_Face_Seg/model/__init__.py b/modules/image/semantic_segmentation/FCN_HRNet_W18_Face_Seg/model/__init__.py
new file mode 100644
index 00000000..0af5841d
--- /dev/null
+++ b/modules/image/semantic_segmentation/FCN_HRNet_W18_Face_Seg/model/__init__.py
@@ -0,0 +1,2 @@
+from .hrnet import HRNet_W18
+from .fcn import FCN
diff --git 
a/modules/image/semantic_segmentation/FCN_HRNet_W18_Face_Seg/model/fcn.py b/modules/image/semantic_segmentation/FCN_HRNet_W18_Face_Seg/model/fcn.py new file mode 100644 index 00000000..c2af584c --- /dev/null +++ b/modules/image/semantic_segmentation/FCN_HRNet_W18_Face_Seg/model/fcn.py @@ -0,0 +1,118 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.nn as nn +import paddle.nn.functional as F + +from .layers import ConvBNReLU + + +class FCN(nn.Layer): + """ + A simple implementation for FCN based on PaddlePaddle. + + The original article refers to + Evan Shelhamer, et, al. "Fully Convolutional Networks for Semantic Segmentation" + (https://arxiv.org/abs/1411.4038). + + Args: + num_classes (int): The unique number of target classes. + backbone (paddle.nn.Layer): Backbone networks. + backbone_indices (tuple, optional): The values in the tuple indicate the indices of output of backbone. + Default: (-1, ). + channels (int, optional): The channels between conv layer and the last layer of FCNHead. + If None, it will be the number of channels of input features. Default: None. + align_corners (bool): An argument of F.interpolate. It should be set to False when the output size of feature + is even, e.g. 1024x512, otherwise it is True, e.g. 769x769. Default: False. + pretrained (str, optional): The path or url of pretrained model. Default: None + """ + + def __init__(self, + num_classes, + backbone, + backbone_indices=(-1, ), + channels=None, + align_corners=False, + pretrained=None): + super(FCN, self).__init__() + + self.backbone = backbone + backbone_channels = [ + backbone.feat_channels[i] for i in backbone_indices + ] + + self.head = FCNHead(num_classes, backbone_indices, backbone_channels, + channels) + + self.align_corners = align_corners + self.pretrained = pretrained + + def forward(self, x): + feat_list = self.backbone(x) + logit_list = self.head(feat_list) + return [ + F.interpolate( + logit, + x.shape[2:], + mode='bilinear', + align_corners=self.align_corners) for logit in logit_list + ] + + +class FCNHead(nn.Layer): + """ + A simple implementation for FCNHead based on PaddlePaddle + + Args: + num_classes (int): The unique number of target classes. + backbone_indices (tuple, optional): The values in the tuple indicate the indices of output of backbone. + Default: (-1, ). + channels (int, optional): The channels between conv layer and the last layer of FCNHead. + If None, it will be the number of channels of input features. Default: None. + pretrained (str, optional): The path of pretrained model. 
Default: None
+    """
+
+    def __init__(self,
+                 num_classes,
+                 backbone_indices=(-1, ),
+                 backbone_channels=(270, ),
+                 channels=None):
+        super(FCNHead, self).__init__()
+
+        self.num_classes = num_classes
+        self.backbone_indices = backbone_indices
+        if channels is None:
+            channels = backbone_channels[0]
+
+        self.conv_1 = ConvBNReLU(
+            in_channels=backbone_channels[0],
+            out_channels=channels,
+            kernel_size=1,
+            padding='same',
+            stride=1)
+        self.cls = nn.Conv2D(
+            in_channels=channels,
+            out_channels=self.num_classes,
+            kernel_size=1,
+            stride=1,
+            padding=0)
+
+    def forward(self, feat_list):
+        logit_list = []
+        x = feat_list[self.backbone_indices[0]]
+        x = self.conv_1(x)
+        logit = self.cls(x)
+        logit_list.append(logit)
+        return logit_list
+
diff --git a/modules/image/semantic_segmentation/FCN_HRNet_W18_Face_Seg/model/hrnet.py b/modules/image/semantic_segmentation/FCN_HRNet_W18_Face_Seg/model/hrnet.py
new file mode 100644
index 00000000..9973b1ad
--- /dev/null
+++ b/modules/image/semantic_segmentation/FCN_HRNet_W18_Face_Seg/model/hrnet.py
@@ -0,0 +1,797 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from .layers import ConvBNReLU, ConvBN
+
+__all__ = [
+    "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", "HRNet_W18", "HRNet_W30",
+    "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", "HRNet_W60", "HRNet_W64"
+]
+
+
+class HRNet(nn.Layer):
+    """
+    The HRNet implementation based on PaddlePaddle.
+
+    The original article refers to
+    Jingdong Wang, et, al. "HRNet: Deep High-Resolution Representation Learning for Visual Recognition"
+    (https://arxiv.org/pdf/1908.07919.pdf).
+
+    Args:
+        pretrained (str): The path of pretrained model.
+        stage1_num_modules (int): Number of modules for stage1. Default 1.
+        stage1_num_blocks (list): Number of blocks per module for stage1. Default [4].
+        stage1_num_channels (list): Number of channels per branch for stage1. Default [64].
+        stage2_num_modules (int): Number of modules for stage2. Default 1.
+        stage2_num_blocks (list): Number of blocks per module for stage2. Default [4, 4].
+        stage2_num_channels (list): Number of channels per branch for stage2. Default [18, 36].
+        stage3_num_modules (int): Number of modules for stage3. Default 4.
+        stage3_num_blocks (list): Number of blocks per module for stage3. Default [4, 4, 4].
+        stage3_num_channels (list): Number of channels per branch for stage3. Default [18, 36, 72].
+        stage4_num_modules (int): Number of modules for stage4. Default 3.
+        stage4_num_blocks (list): Number of blocks per module for stage4. Default [4, 4, 4, 4].
+        stage4_num_channels (list): Number of channels per branch for stage4. Default [18, 36, 72, 144].
+        has_se (bool): Whether to use Squeeze-and-Excitation module. Default False.
+        align_corners (bool, optional): An argument of F.interpolate. It should be set to False when the feature size is even,
+            e.g. 1024x512, otherwise it is True, e.g. 769x769.
Default: False. + """ + + def __init__(self, + pretrained=None, + stage1_num_modules=1, + stage1_num_blocks=[4], + stage1_num_channels=[64], + stage2_num_modules=1, + stage2_num_blocks=[4, 4], + stage2_num_channels=[18, 36], + stage3_num_modules=4, + stage3_num_blocks=[4, 4, 4], + stage3_num_channels=[18, 36, 72], + stage4_num_modules=3, + stage4_num_blocks=[4, 4, 4, 4], + stage4_num_channels=[18, 36, 72, 144], + has_se=False, + align_corners=False): + super(HRNet, self).__init__() + self.pretrained = pretrained + self.stage1_num_modules = stage1_num_modules + self.stage1_num_blocks = stage1_num_blocks + self.stage1_num_channels = stage1_num_channels + self.stage2_num_modules = stage2_num_modules + self.stage2_num_blocks = stage2_num_blocks + self.stage2_num_channels = stage2_num_channels + self.stage3_num_modules = stage3_num_modules + self.stage3_num_blocks = stage3_num_blocks + self.stage3_num_channels = stage3_num_channels + self.stage4_num_modules = stage4_num_modules + self.stage4_num_blocks = stage4_num_blocks + self.stage4_num_channels = stage4_num_channels + self.has_se = has_se + self.align_corners = align_corners + self.feat_channels = [sum(stage4_num_channels)] + + self.conv_layer1_1 = ConvBNReLU( + in_channels=3, + out_channels=64, + kernel_size=3, + stride=2, + padding='same', + bias_attr=False) + + self.conv_layer1_2 = ConvBNReLU( + in_channels=64, + out_channels=64, + kernel_size=3, + stride=2, + padding='same', + bias_attr=False) + + self.la1 = Layer1( + num_channels=64, + num_blocks=self.stage1_num_blocks[0], + num_filters=self.stage1_num_channels[0], + has_se=has_se, + name="layer2") + + self.tr1 = TransitionLayer( + in_channels=[self.stage1_num_channels[0] * 4], + out_channels=self.stage2_num_channels, + name="tr1") + + self.st2 = Stage( + num_channels=self.stage2_num_channels, + num_modules=self.stage2_num_modules, + num_blocks=self.stage2_num_blocks, + num_filters=self.stage2_num_channels, + has_se=self.has_se, + name="st2", + align_corners=align_corners) + + self.tr2 = TransitionLayer( + in_channels=self.stage2_num_channels, + out_channels=self.stage3_num_channels, + name="tr2") + self.st3 = Stage( + num_channels=self.stage3_num_channels, + num_modules=self.stage3_num_modules, + num_blocks=self.stage3_num_blocks, + num_filters=self.stage3_num_channels, + has_se=self.has_se, + name="st3", + align_corners=align_corners) + + self.tr3 = TransitionLayer( + in_channels=self.stage3_num_channels, + out_channels=self.stage4_num_channels, + name="tr3") + self.st4 = Stage( + num_channels=self.stage4_num_channels, + num_modules=self.stage4_num_modules, + num_blocks=self.stage4_num_blocks, + num_filters=self.stage4_num_channels, + has_se=self.has_se, + name="st4", + align_corners=align_corners) + + def forward(self, x): + conv1 = self.conv_layer1_1(x) + conv2 = self.conv_layer1_2(conv1) + + la1 = self.la1(conv2) + + tr1 = self.tr1([la1]) + st2 = self.st2(tr1) + + tr2 = self.tr2(st2) + st3 = self.st3(tr2) + + tr3 = self.tr3(st3) + st4 = self.st4(tr3) + + x0_h, x0_w = st4[0].shape[2:] + x1 = F.interpolate( + st4[1], (x0_h, x0_w), + mode='bilinear', + align_corners=self.align_corners) + x2 = F.interpolate( + st4[2], (x0_h, x0_w), + mode='bilinear', + align_corners=self.align_corners) + x3 = F.interpolate( + st4[3], (x0_h, x0_w), + mode='bilinear', + align_corners=self.align_corners) + x = paddle.concat([st4[0], x1, x2, x3], axis=1) + + return [x] + + +class Layer1(nn.Layer): + def __init__(self, + num_channels, + num_filters, + num_blocks, + has_se=False, + name=None): + 
super(Layer1, self).__init__() + + self.bottleneck_block_list = [] + + for i in range(num_blocks): + bottleneck_block = self.add_sublayer( + "bb_{}_{}".format(name, i + 1), + BottleneckBlock( + num_channels=num_channels if i == 0 else num_filters * 4, + num_filters=num_filters, + has_se=has_se, + stride=1, + downsample=True if i == 0 else False, + name=name + '_' + str(i + 1))) + self.bottleneck_block_list.append(bottleneck_block) + + def forward(self, x): + conv = x + for block_func in self.bottleneck_block_list: + conv = block_func(conv) + return conv + + +class TransitionLayer(nn.Layer): + def __init__(self, in_channels, out_channels, name=None): + super(TransitionLayer, self).__init__() + + num_in = len(in_channels) + num_out = len(out_channels) + self.conv_bn_func_list = [] + for i in range(num_out): + residual = None + if i < num_in: + if in_channels[i] != out_channels[i]: + residual = self.add_sublayer( + "transition_{}_layer_{}".format(name, i + 1), + ConvBNReLU( + in_channels=in_channels[i], + out_channels=out_channels[i], + kernel_size=3, + padding='same', + bias_attr=False)) + else: + residual = self.add_sublayer( + "transition_{}_layer_{}".format(name, i + 1), + ConvBNReLU( + in_channels=in_channels[-1], + out_channels=out_channels[i], + kernel_size=3, + stride=2, + padding='same', + bias_attr=False)) + self.conv_bn_func_list.append(residual) + + def forward(self, x): + outs = [] + for idx, conv_bn_func in enumerate(self.conv_bn_func_list): + if conv_bn_func is None: + outs.append(x[idx]) + else: + if idx < len(x): + outs.append(conv_bn_func(x[idx])) + else: + outs.append(conv_bn_func(x[-1])) + return outs + + +class Branches(nn.Layer): + def __init__(self, + num_blocks, + in_channels, + out_channels, + has_se=False, + name=None): + super(Branches, self).__init__() + + self.basic_block_list = [] + + for i in range(len(out_channels)): + self.basic_block_list.append([]) + for j in range(num_blocks[i]): + in_ch = in_channels[i] if j == 0 else out_channels[i] + basic_block_func = self.add_sublayer( + "bb_{}_branch_layer_{}_{}".format(name, i + 1, j + 1), + BasicBlock( + num_channels=in_ch, + num_filters=out_channels[i], + has_se=has_se, + name=name + '_branch_layer_' + str(i + 1) + '_' + + str(j + 1))) + self.basic_block_list[i].append(basic_block_func) + + def forward(self, x): + outs = [] + for idx, input in enumerate(x): + conv = input + for basic_block_func in self.basic_block_list[idx]: + conv = basic_block_func(conv) + outs.append(conv) + return outs + + +class BottleneckBlock(nn.Layer): + def __init__(self, + num_channels, + num_filters, + has_se, + stride=1, + downsample=False, + name=None): + super(BottleneckBlock, self).__init__() + + self.has_se = has_se + self.downsample = downsample + + self.conv1 = ConvBNReLU( + in_channels=num_channels, + out_channels=num_filters, + kernel_size=1, + padding='same', + bias_attr=False) + + self.conv2 = ConvBNReLU( + in_channels=num_filters, + out_channels=num_filters, + kernel_size=3, + stride=stride, + padding='same', + bias_attr=False) + + self.conv3 = ConvBN( + in_channels=num_filters, + out_channels=num_filters * 4, + kernel_size=1, + padding='same', + bias_attr=False) + + if self.downsample: + self.conv_down = ConvBN( + in_channels=num_channels, + out_channels=num_filters * 4, + kernel_size=1, + padding='same', + bias_attr=False) + + if self.has_se: + self.se = SELayer( + num_channels=num_filters * 4, + num_filters=num_filters * 4, + reduction_ratio=16, + name=name + '_fc') + + def forward(self, x): + residual = x + conv1 = 
self.conv1(x) + conv2 = self.conv2(conv1) + conv3 = self.conv3(conv2) + + if self.downsample: + residual = self.conv_down(x) + + if self.has_se: + conv3 = self.se(conv3) + + y = conv3 + residual + y = F.relu(y) + return y + + +class BasicBlock(nn.Layer): + def __init__(self, + num_channels, + num_filters, + stride=1, + has_se=False, + downsample=False, + name=None): + super(BasicBlock, self).__init__() + + self.has_se = has_se + self.downsample = downsample + + self.conv1 = ConvBNReLU( + in_channels=num_channels, + out_channels=num_filters, + kernel_size=3, + stride=stride, + padding='same', + bias_attr=False) + self.conv2 = ConvBN( + in_channels=num_filters, + out_channels=num_filters, + kernel_size=3, + padding='same', + bias_attr=False) + + if self.downsample: + self.conv_down = ConvBNReLU( + in_channels=num_channels, + out_channels=num_filters, + kernel_size=1, + padding='same', + bias_attr=False) + + if self.has_se: + self.se = SELayer( + num_channels=num_filters, + num_filters=num_filters, + reduction_ratio=16, + name=name + '_fc') + + def forward(self, x): + residual = x + conv1 = self.conv1(x) + conv2 = self.conv2(conv1) + + if self.downsample: + residual = self.conv_down(x) + + if self.has_se: + conv2 = self.se(conv2) + + y = conv2 + residual + y = F.relu(y) + return y + + +class SELayer(nn.Layer): + def __init__(self, num_channels, num_filters, reduction_ratio, name=None): + super(SELayer, self).__init__() + + self.pool2d_gap = nn.AdaptiveAvgPool2D(1) + + self._num_channels = num_channels + + med_ch = int(num_channels / reduction_ratio) + stdv = 1.0 / math.sqrt(num_channels * 1.0) + self.squeeze = nn.Linear( + num_channels, + med_ch, + weight_attr=paddle.ParamAttr( + initializer=nn.initializer.Uniform(-stdv, stdv))) + + stdv = 1.0 / math.sqrt(med_ch * 1.0) + self.excitation = nn.Linear( + med_ch, + num_filters, + weight_attr=paddle.ParamAttr( + initializer=nn.initializer.Uniform(-stdv, stdv))) + + def forward(self, x): + pool = self.pool2d_gap(x) + pool = paddle.reshape(pool, shape=[-1, self._num_channels]) + squeeze = self.squeeze(pool) + squeeze = F.relu(squeeze) + excitation = self.excitation(squeeze) + excitation = F.sigmoid(excitation) + excitation = paddle.reshape( + excitation, shape=[-1, self._num_channels, 1, 1]) + out = x * excitation + return out + + +class Stage(nn.Layer): + def __init__(self, + num_channels, + num_modules, + num_blocks, + num_filters, + has_se=False, + multi_scale_output=True, + name=None, + align_corners=False): + super(Stage, self).__init__() + + self._num_modules = num_modules + + self.stage_func_list = [] + for i in range(num_modules): + if i == num_modules - 1 and not multi_scale_output: + stage_func = self.add_sublayer( + "stage_{}_{}".format(name, i + 1), + HighResolutionModule( + num_channels=num_channels, + num_blocks=num_blocks, + num_filters=num_filters, + has_se=has_se, + multi_scale_output=False, + name=name + '_' + str(i + 1), + align_corners=align_corners)) + else: + stage_func = self.add_sublayer( + "stage_{}_{}".format(name, i + 1), + HighResolutionModule( + num_channels=num_channels, + num_blocks=num_blocks, + num_filters=num_filters, + has_se=has_se, + name=name + '_' + str(i + 1), + align_corners=align_corners)) + + self.stage_func_list.append(stage_func) + + def forward(self, x): + out = x + for idx in range(self._num_modules): + out = self.stage_func_list[idx](out) + return out + + +class HighResolutionModule(nn.Layer): + def __init__(self, + num_channels, + num_blocks, + num_filters, + has_se=False, + multi_scale_output=True, + 
name=None, + align_corners=False): + super(HighResolutionModule, self).__init__() + + self.branches_func = Branches( + num_blocks=num_blocks, + in_channels=num_channels, + out_channels=num_filters, + has_se=has_se, + name=name) + + self.fuse_func = FuseLayers( + in_channels=num_filters, + out_channels=num_filters, + multi_scale_output=multi_scale_output, + name=name, + align_corners=align_corners) + + def forward(self, x): + out = self.branches_func(x) + out = self.fuse_func(out) + return out + + +class FuseLayers(nn.Layer): + def __init__(self, + in_channels, + out_channels, + multi_scale_output=True, + name=None, + align_corners=False): + super(FuseLayers, self).__init__() + + self._actual_ch = len(in_channels) if multi_scale_output else 1 + self._in_channels = in_channels + self.align_corners = align_corners + + self.residual_func_list = [] + for i in range(self._actual_ch): + for j in range(len(in_channels)): + if j > i: + residual_func = self.add_sublayer( + "residual_{}_layer_{}_{}".format(name, i + 1, j + 1), + ConvBN( + in_channels=in_channels[j], + out_channels=out_channels[i], + kernel_size=1, + padding='same', + bias_attr=False)) + self.residual_func_list.append(residual_func) + elif j < i: + pre_num_filters = in_channels[j] + for k in range(i - j): + if k == i - j - 1: + residual_func = self.add_sublayer( + "residual_{}_layer_{}_{}_{}".format( + name, i + 1, j + 1, k + 1), + ConvBN( + in_channels=pre_num_filters, + out_channels=out_channels[i], + kernel_size=3, + stride=2, + padding='same', + bias_attr=False)) + pre_num_filters = out_channels[i] + else: + residual_func = self.add_sublayer( + "residual_{}_layer_{}_{}_{}".format( + name, i + 1, j + 1, k + 1), + ConvBNReLU( + in_channels=pre_num_filters, + out_channels=out_channels[j], + kernel_size=3, + stride=2, + padding='same', + bias_attr=False)) + pre_num_filters = out_channels[j] + self.residual_func_list.append(residual_func) + + def forward(self, x): + outs = [] + residual_func_idx = 0 + for i in range(self._actual_ch): + residual = x[i] + residual_shape = residual.shape[-2:] + for j in range(len(self._in_channels)): + if j > i: + y = self.residual_func_list[residual_func_idx](x[j]) + residual_func_idx += 1 + + y = F.interpolate( + y, + residual_shape, + mode='bilinear', + align_corners=self.align_corners) + residual = residual + y + elif j < i: + y = x[j] + for k in range(i - j): + y = self.residual_func_list[residual_func_idx](y) + residual_func_idx += 1 + + residual = residual + y + + residual = F.relu(residual) + outs.append(residual) + + return outs + + +def HRNet_W18_Small_V1(**kwargs): + model = HRNet( + stage1_num_modules=1, + stage1_num_blocks=[1], + stage1_num_channels=[32], + stage2_num_modules=1, + stage2_num_blocks=[2, 2], + stage2_num_channels=[16, 32], + stage3_num_modules=1, + stage3_num_blocks=[2, 2, 2], + stage3_num_channels=[16, 32, 64], + stage4_num_modules=1, + stage4_num_blocks=[2, 2, 2, 2], + stage4_num_channels=[16, 32, 64, 128], + **kwargs) + return model + + +def HRNet_W18_Small_V2(**kwargs): + model = HRNet( + stage1_num_modules=1, + stage1_num_blocks=[2], + stage1_num_channels=[64], + stage2_num_modules=1, + stage2_num_blocks=[2, 2], + stage2_num_channels=[18, 36], + stage3_num_modules=1, + stage3_num_blocks=[2, 2, 2], + stage3_num_channels=[18, 36, 72], + stage4_num_modules=1, + stage4_num_blocks=[2, 2, 2, 2], + stage4_num_channels=[18, 36, 72, 144], + **kwargs) + return model + + +def HRNet_W18(**kwargs): + model = HRNet( + stage1_num_modules=1, + stage1_num_blocks=[4], + 
stage1_num_channels=[64], + stage2_num_modules=1, + stage2_num_blocks=[4, 4], + stage2_num_channels=[18, 36], + stage3_num_modules=4, + stage3_num_blocks=[4, 4, 4], + stage3_num_channels=[18, 36, 72], + stage4_num_modules=3, + stage4_num_blocks=[4, 4, 4, 4], + stage4_num_channels=[18, 36, 72, 144], + **kwargs) + return model + + +def HRNet_W30(**kwargs): + model = HRNet( + stage1_num_modules=1, + stage1_num_blocks=[4], + stage1_num_channels=[64], + stage2_num_modules=1, + stage2_num_blocks=[4, 4], + stage2_num_channels=[30, 60], + stage3_num_modules=4, + stage3_num_blocks=[4, 4, 4], + stage3_num_channels=[30, 60, 120], + stage4_num_modules=3, + stage4_num_blocks=[4, 4, 4, 4], + stage4_num_channels=[30, 60, 120, 240], + **kwargs) + return model + + +def HRNet_W32(**kwargs): + model = HRNet( + stage1_num_modules=1, + stage1_num_blocks=[4], + stage1_num_channels=[64], + stage2_num_modules=1, + stage2_num_blocks=[4, 4], + stage2_num_channels=[32, 64], + stage3_num_modules=4, + stage3_num_blocks=[4, 4, 4], + stage3_num_channels=[32, 64, 128], + stage4_num_modules=3, + stage4_num_blocks=[4, 4, 4, 4], + stage4_num_channels=[32, 64, 128, 256], + **kwargs) + return model + + +def HRNet_W40(**kwargs): + model = HRNet( + stage1_num_modules=1, + stage1_num_blocks=[4], + stage1_num_channels=[64], + stage2_num_modules=1, + stage2_num_blocks=[4, 4], + stage2_num_channels=[40, 80], + stage3_num_modules=4, + stage3_num_blocks=[4, 4, 4], + stage3_num_channels=[40, 80, 160], + stage4_num_modules=3, + stage4_num_blocks=[4, 4, 4, 4], + stage4_num_channels=[40, 80, 160, 320], + **kwargs) + return model + + +def HRNet_W44(**kwargs): + model = HRNet( + stage1_num_modules=1, + stage1_num_blocks=[4], + stage1_num_channels=[64], + stage2_num_modules=1, + stage2_num_blocks=[4, 4], + stage2_num_channels=[44, 88], + stage3_num_modules=4, + stage3_num_blocks=[4, 4, 4], + stage3_num_channels=[44, 88, 176], + stage4_num_modules=3, + stage4_num_blocks=[4, 4, 4, 4], + stage4_num_channels=[44, 88, 176, 352], + **kwargs) + return model + + +def HRNet_W48(**kwargs): + model = HRNet( + stage1_num_modules=1, + stage1_num_blocks=[4], + stage1_num_channels=[64], + stage2_num_modules=1, + stage2_num_blocks=[4, 4], + stage2_num_channels=[48, 96], + stage3_num_modules=4, + stage3_num_blocks=[4, 4, 4], + stage3_num_channels=[48, 96, 192], + stage4_num_modules=3, + stage4_num_blocks=[4, 4, 4, 4], + stage4_num_channels=[48, 96, 192, 384], + **kwargs) + return model + + +def HRNet_W60(**kwargs): + model = HRNet( + stage1_num_modules=1, + stage1_num_blocks=[4], + stage1_num_channels=[64], + stage2_num_modules=1, + stage2_num_blocks=[4, 4], + stage2_num_channels=[60, 120], + stage3_num_modules=4, + stage3_num_blocks=[4, 4, 4], + stage3_num_channels=[60, 120, 240], + stage4_num_modules=3, + stage4_num_blocks=[4, 4, 4, 4], + stage4_num_channels=[60, 120, 240, 480], + **kwargs) + return model + + +def HRNet_W64(**kwargs): + model = HRNet( + stage1_num_modules=1, + stage1_num_blocks=[4], + stage1_num_channels=[64], + stage2_num_modules=1, + stage2_num_blocks=[4, 4], + stage2_num_channels=[64, 128], + stage3_num_modules=4, + stage3_num_blocks=[4, 4, 4], + stage3_num_channels=[64, 128, 256], + stage4_num_modules=3, + stage4_num_blocks=[4, 4, 4, 4], + stage4_num_channels=[64, 128, 256, 512], + **kwargs) + return model diff --git a/modules/image/semantic_segmentation/FCN_HRNet_W18_Face_Seg/model/layers.py b/modules/image/semantic_segmentation/FCN_HRNet_W18_Face_Seg/model/layers.py new file mode 100644 index 00000000..be83f19f --- /dev/null +++ 
b/modules/image/semantic_segmentation/FCN_HRNet_W18_Face_Seg/model/layers.py
@@ -0,0 +1,65 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+def SyncBatchNorm(*args, **kwargs):
+    """In cpu environment nn.SyncBatchNorm does not have kernel so use nn.BatchNorm instead"""
+    if paddle.get_device() == 'cpu':
+        return nn.BatchNorm(*args, **kwargs)
+    else:
+        return nn.SyncBatchNorm(*args, **kwargs)
+
+
+class ConvBNReLU(nn.Layer):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 padding='same',
+                 **kwargs):
+        super().__init__()
+
+        self._conv = nn.Conv2D(
+            in_channels, out_channels, kernel_size, padding=padding, **kwargs)
+
+        self._batch_norm = SyncBatchNorm(out_channels)
+
+    def forward(self, x):
+        x = self._conv(x)
+        x = self._batch_norm(x)
+        x = F.relu(x)
+        return x
+
+
+class ConvBN(nn.Layer):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 padding='same',
+                 **kwargs):
+        super().__init__()
+        self._conv = nn.Conv2D(
+            in_channels, out_channels, kernel_size, padding=padding, **kwargs)
+        self._batch_norm = SyncBatchNorm(out_channels)
+
+    def forward(self, x):
+        x = self._conv(x)
+        x = self._batch_norm(x)
+        return x
+
diff --git a/modules/image/semantic_segmentation/FCN_HRNet_W18_Face_Seg/module.py b/modules/image/semantic_segmentation/FCN_HRNet_W18_Face_Seg/module.py
new file mode 100644
index 00000000..ffe691a4
--- /dev/null
+++ b/modules/image/semantic_segmentation/FCN_HRNet_W18_Face_Seg/module.py
@@ -0,0 +1,156 @@
+import os
+import cv2
+import paddle
+import paddle.nn as nn
+import numpy as np
+from FCN_HRNet_W18_Face_Seg.model import FCN, HRNet_W18
+from paddlehub.module.module import moduleinfo
+
+@moduleinfo(
+    name="FCN_HRNet_W18_Face_Seg",  # module name
+    type="CV",  # module type
+    author="jm12138",  # author
+    author_email="jm12138@qq.com",  # author email
+    summary="FCN_HRNet_W18_Face_Seg",  # module summary
+    version="1.0.0"  # version
+)
+class FCN_HRNet_W18_Face_Seg(nn.Layer):
+    def __init__(self):
+        super(FCN_HRNet_W18_Face_Seg, self).__init__()
+        # Build the segmentation model
+        self.seg = FCN(num_classes=2, backbone=HRNet_W18())
+
+        # Load the model weights
+        state_dict = paddle.load(os.path.join(self.directory, 'seg_model_384.pdparams'))
+        self.seg.set_state_dict(state_dict)
+
+        # Switch the model to evaluation mode
+        self.seg.eval()
+
+    # Data loading helper
+    @staticmethod
+    def load_datas(paths, images):
+        datas = []
+
+        # Read images from the path list
+        if paths is not None:
+            for im_path in paths:
+                assert os.path.isfile(im_path), "The {} isn't a valid file path.".format(im_path)
+                im = cv2.imread(im_path)
+                datas.append(im)
+
+        if images is not None:
+            datas = images
+
+        # Return the data list
+        return datas
+
+    # Preprocessing
+    @staticmethod
+    def preprocess(images, batch_size):
+        input_datas = []
+
+        for image in images:
+            # Resize to the model input size
+            image = cv2.resize(image, (384, 384), interpolation=cv2.INTER_AREA)
+
+            # To NCHW float32 in [0, 1]
+            image = (image / 255.)[np.newaxis, :, :, :]
+            image = np.transpose(image, (0, 3, 1, 2)).astype(np.float32)
+
+            input_datas.append(image)
+
+        input_datas = np.concatenate(input_datas, 0)
+
+        # Split into batches
+        datas_num = input_datas.shape[0]
+        split_num = datas_num//batch_size+1 if datas_num%batch_size!=0 else datas_num//batch_size
+        input_datas = np.array_split(input_datas, split_num)
+
+        return input_datas
+
+    # Min-max normalization of the prediction map
+    @staticmethod
+    def normPRED(d):
+        ma = np.max(d)
+        mi = np.min(d)
+
+        dn = (d - mi) / (ma - mi)
+
+        return dn
+
+    # Postprocessing
+    def postprocess(self, outputs, datas, output_dir, visualization):
+        # Create the output directory if needed
+        if visualization:
+            if not os.path.exists(output_dir):
+                os.mkdir(output_dir)
+
+        results = []
+
+        for output, image, i in zip(outputs, datas, range(len(datas))):
+            # Build the mask from the foreground channel
+            pred = self.normPRED(output[1])
+
+            # Resize the mask back to the input size
+            h, w = image.shape[:2]
+            mask = cv2.resize(pred, (w, h))
+            mask[mask < 0.5] = 0.
+            mask[mask > 0.55] = 1.
+
+            # Composite the face onto a white background
+            face = (image * mask[..., np.newaxis] + (1 - mask[..., np.newaxis]) * 255).astype(np.uint8)
+
+            # Back to uint8
+            mask = (mask * 255).astype(np.uint8)
+
+            # Save visualized results
+            if visualization:
+                cv2.imwrite(os.path.join(output_dir, 'result_mask_%d.png' % i), mask)
+                cv2.imwrite(os.path.join(output_dir, 'result_%d.png' % i), face)
+
+            results.append({
+                'mask': mask,
+                'face': face
+            })
+
+        return results
+
+    # Inference
+    def predict(self, input_datas):
+        outputs = []
+
+        for data in input_datas:
+            # Convert the batch to a Tensor
+            data = paddle.to_tensor(data)
+
+            # Forward pass
+            logits = self.seg(data)
+
+            outputs.append(logits[0].numpy())
+
+        outputs = np.concatenate(outputs, 0)
+
+        return outputs
+
+    def Segmentation(
+            self,
+            images=None,
+            paths=None,
+            batch_size=1,
+            output_dir='output',
+            visualization=False):
+        # Load the input data
+        datas = self.load_datas(paths, images)
+
+        # Preprocess
+        input_datas = self.preprocess(datas, batch_size)
+
+        # Run the model
+        outputs = self.predict(input_datas)
+
+        # Postprocess
+        results = self.postprocess(outputs, datas, output_dir, visualization)
+
+        return results
\ No newline at end of file
--
GitLab