Unverified commit 792b38a0, authored by haoqiang, committed by GitHub

Add photo2cartoon model (#117)

* Add photo2cartoon model
* Resolve conflicts
* Remove comments
* Add photo2cartoon tutorial
* update p2c tutorials
Parent 5519d095
......@@ -86,6 +86,14 @@ GAN-Generative Adversarial Network, was praised by "the Father of Convolutional
<img src='./docs/imgs/ugatit.png' width='700' height='250'/>
</div>
### Realistic face cartoonization
<div align='center'>
<img src='./docs/imgs/photo2cartoon.png' width='700' height='250'/>
</div>
### Photo animation
<div align='center'>
......
......@@ -98,6 +98,14 @@ GAN (Generative Adversarial Network) was praised by "the Father of Convolutional Networks" **Yann LeCun**
<img src='./docs/imgs/ugatit.png' width='700' height='250'/>
</div>
### Realistic face cartoonization
<div align='center'>
<img src='./docs/imgs/photo2cartoon.png' width='700' height='250'/>
</div>
### Photo animation
<div align='center'>
......
epochs: 300
output_dir: output_dir

adv_weight: 1.0
cycle_weight: 50.0
identity_weight: 10.0
cam_weight: 1000.0

model:
  name: UGATITModel
  generator:
    name: ResnetUGATITP2CGenerator
    input_nc: 3
    output_nc: 3
    ngf: 32
    n_blocks: 4
    img_size: 256
    light: True
  discriminator_g:
    name: UGATITDiscriminator
    input_nc: 3
    ndf: 32
    n_layers: 7
  discriminator_l:
    name: UGATITDiscriminator
    input_nc: 3
    ndf: 32
    n_layers: 5

dataset:
  train:
    name: UnpairedDataset
    dataroot: data/photo2cartoon
    num_workers: 0
    phase: train
    max_dataset_size: inf
    direction: AtoB
    input_nc: 3
    output_nc: 3
    serial_batches: False
    transforms:
      - name: Resize
        size: [286, 286]
        interpolation: 'bilinear' #'bicubic' #cv2.INTER_CUBIC
      - name: RandomCrop
        size: [256, 256]
      - name: RandomHorizontalFlip
        prob: 0.5
      - name: Transpose
      - name: Normalize
        mean: [127.5, 127.5, 127.5]
        std: [127.5, 127.5, 127.5]
  test:
    name: SingleDataset
    dataroot: data/photo2cartoon/testA
    max_dataset_size: inf
    direction: AtoB
    input_nc: 3
    output_nc: 3
    serial_batches: False
    transforms:
      - name: Resize
        size: [256, 256]
        interpolation: 'bilinear' #cv2.INTER_CUBIC
      - name: Transpose
      - name: Normalize
        mean: [127.5, 127.5, 127.5]
        std: [127.5, 127.5, 127.5]

optimizer:
  name: Adam
  beta1: 0.5
  weight_decay: 0.0001

lr_scheduler:
  name: linear
  learning_rate: 0.0001
  start_epoch: 150
  decay_epochs: 150

log_config:
  interval: 10
  visiual_interval: 500

snapshot_config:
  interval: 30
# Photo2cartoon
## 1 Principle
The aim of portrait cartoonization is to transform real photos into cartoon images while preserving the identity information and texture details of the portrait. We use a generative adversarial network to learn the photo-to-cartoon mapping. Because paired data are difficult to obtain and the shapes of the input and output do not correspond exactly, we adopt an unpaired image translation approach.

Recently, Kim et al. proposed a novel normalization function (AdaLIN) and an attention module in the paper "U-GAT-IT" and achieved exquisite selfie2anime results. Unlike the exaggerated anime style, our cartoon style is more realistic and preserves unambiguous identity information.

We propose Soft Adaptive Layer-Instance Normalization (Soft-AdaLIN), which fuses the statistics of the encoder features and the decoder features during de-normalization.
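The fusion itself is a convex combination of the two sets of statistics with learnable per-channel weights. A minimal sketch following the `SoftAdaLIN` layer added in this PR (tensor shapes simplified for illustration):

```
# w_gamma / w_beta are learnable per-channel weights, clipped to [0, 1] by RhoClipper
soft_gamma = (1. - w_gamma) * style_gamma + w_gamma * content_gamma
soft_beta  = (1. - w_beta)  * style_beta  + w_beta  * content_beta
# the fused statistics are then applied by AdaLIN, which mixes Instance Norm and Layer Norm
out = adalin(x, soft_gamma, soft_beta)
```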
Based on U-GAT-IT, two hourglass modules are introduced before the encoder and after the decoder to improve the performance progressively.

In the original [project](https://github.com/minivision-ai/photo2cartoon), we also add a Face ID Loss (the cosine distance between the ID features of the input photo and the generated cartoon) to enforce identity invariance. (Face ID Loss is not included in this repo; please refer to the original photo2cartoon project.)
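For reference, the Face ID Loss in the original project is essentially a cosine-distance term between face-recognition embeddings. A hedged sketch, not part of this PR (`id_net` is a placeholder for a pre-trained face recognition model):

```
import paddle.nn.functional as F

id_photo = id_net(photo)          # ID embedding of the input photo
id_cartoon = id_net(cartoon)      # ID embedding of the generated cartoon
faceid_loss = 1. - F.cosine_similarity(id_photo, id_cartoon)  # smaller distance -> same identity
```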
![](../../imgs/photo2cartoon_pipeline.png)
We also pre-process the data into a fixed pattern to reduce the difficulty of optimization. See the figure below for details.
![](../../imgs/photo2cartoon_data_process.jpg)
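The same preprocessing is implemented by `Photo2CartoonPredictor` in this PR. Roughly, it does the following (a sketch; `test_img.jpg` is a placeholder path):

```
import cv2
import numpy as np
from PIL import Image
from ppgan.faceutils.dlibutils import align_crop
from ppgan.faceutils.face_segmentation import FaceSeg

image = Image.open('test_img.jpg')       # RGB portrait photo
face = align_crop(image)                 # detect the largest face, rotate-align and crop the head
mask = FaceSeg()(face)                   # portrait segmentation (FCN + HRNet_W18)
face = cv2.resize(face, (256, 256), interpolation=cv2.INTER_AREA)
mask = cv2.resize(mask, (256, 256))[:, :, np.newaxis] / 255.
face = (face * mask + (1 - mask) * 255) / 127.5 - 1   # white background, scaled to [-1, 1]
```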
## 2 How to use
### 2.1 Test
```
from ppgan.apps import Photo2CartoonPredictor
p2c = Photo2CartoonPredictor()
p2c.run('test_img.jpg')
```
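`Photo2CartoonPredictor` also accepts an output directory and a custom weight file (both optional; by default the released weights are downloaded automatically). For example, with the placeholder paths below:

```
from ppgan.apps import Photo2CartoonPredictor

# 'my_output' and 'my_weights.pdparams' are placeholder paths
p2c = Photo2CartoonPredictor(output_path='my_output', weight_path='my_weights.pdparams')
p2c.run('test_img.jpg')  # saves p2c_photo.png and p2c_cartoon.png under output_path
```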
### 2.2 Train
Prepare the dataset:

Training data consist of portrait photos (domain A) and cartoon images (domain B), and can be downloaded from [Baidu Drive](https://pan.baidu.com/s/1RqB4MNMAY_yyXAIS3KBXqw) (password: fo8u).

The dataset is organized as follows:
```
├── data
└── photo2cartoon
├── trainA
├── trainB
├── testA
└── testB
```
Train:
```
python -u tools/main.py --config-file configs/ugatit_photo2cartoon.yaml
```
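Multi-GPU training can usually be launched with Paddle's distributed launcher. A sketch, assuming the standard `paddle.distributed.launch` entry point (single-GPU training works as shown above):

```
python -m paddle.distributed.launch tools/main.py --config-file configs/ugatit_photo2cartoon.yaml
```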
## 3 Results
![](../../imgs/photo2cartoon.png)
## 4 Download
| model | link |
|---|---|
| photo2cartoon_genA2B | [photo2cartoon_genA2B](https://paddlegan.bj.bcebos.com/models/photo2cartoon_genA2B_weight.pdparams) |
# References
- [U-GAT-IT: Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation](https://arxiv.org/abs/1907.10830)
```
@inproceedings{Kim2020U-GAT-IT:,
title={U-GAT-IT: Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation},
author={Junho Kim and Minjae Kim and Hyeonwoo Kang and Kwang Hee Lee},
booktitle={International Conference on Learning Representations},
year={2020}
}
```
# Authors
[minivision-ai](https://github.com/minivision-ai), [haoqiang](https://github.com/hao-qiang)
# Photo2cartoon
## 1 Principle

The goal of portrait cartoonization is to transform a real photo into a non-photorealistic, cartoon-style image while preserving the identity information and texture details of the original. In general, pix2pix methods trained on paired data achieve good translation quality, but in this task the contours of the input and output do not correspond one-to-one (for example, cartoon eyes are larger and chins are slimmer), and paired data are difficult and costly to draw, so we adopt an unpaired image translation approach.

The recent paper U-GAT-IT proposes a normalization method, AdaLIN, which automatically adjusts the ratio between Instance Norm and Layer Norm; combined with an attention mechanism, it achieves exquisite selfie-to-anime translation. To obtain a realistic portrait cartoon style, we customize U-GAT-IT as follows.

We propose Soft-AdaLIN (Soft Adaptive Layer-Instance Normalization), which fuses the mean and variance of the encoder features (photo statistics) with those of the decoder features (cartoon statistics) during de-normalization.

On the architecture side, based on U-GAT-IT, we add two hourglass modules before the encoder and after the decoder to progressively strengthen the model's feature abstraction and reconstruction capability.

In the [original project](https://github.com/minivision-ai/photo2cartoon) we also add a Face ID Loss: a pre-trained face recognition model extracts ID features of the photo and the cartoon, and their cosine distance constrains the generated cartoon to preserve the subject's identity. (Face ID Loss is not yet included in the Paddle version; please refer to the original project.)

![](../../imgs/photo2cartoon_pipeline.png)

Because experimental data are relatively scarce, we pre-process the data into a fixed pattern to reduce the training difficulty: we first detect the face and its landmarks, rotate and align the image according to the landmarks, crop it to a uniform standard, and then feed the cropped head image into a portrait segmentation model (trained with the PaddleSeg framework) to remove the background.
![](../../imgs/photo2cartoon_data_process.jpg)
## 2 How to use

### 2.1 Test
```
from ppgan.apps import Photo2CartoonPredictor
p2c = Photo2CartoonPredictor()
p2c.run('test_img.jpg')
```
### 2.2 Train

Prepare the dataset:

The model is trained on unpaired data, which can be downloaded from [Baidu Drive](https://pan.baidu.com/s/1RqB4MNMAY_yyXAIS3KBXqw) (password: fo8u).

The dataset is organized as follows:
```
├── data
└── photo2cartoon
├── trainA
├── trainB
├── testA
└── testB
```
Train the model:
```
python -u tools/main.py --config-file configs/ugatit_photo2cartoon.yaml
```
## 3 Results
![](../../imgs/photo2cartoon.png)
## 4 Model Download

| Model | Download link |
|---|---|
| photo2cartoon_genA2B | [download link](https://paddlegan.bj.bcebos.com/models/photo2cartoon_genA2B_weight.pdparams) |
# References
- [U-GAT-IT: Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation](https://arxiv.org/abs/1907.10830)
```
@inproceedings{Kim2020U-GAT-IT:,
title={U-GAT-IT: Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation},
author={Junho Kim and Minjae Kim and Hyeonwoo Kang and Kwang Hee Lee},
booktitle={International Conference on Learning Representations},
year={2020}
}
```
# Authors

[minivision-ai](https://github.com/minivision-ai), [haoqiang](https://github.com/hao-qiang)
......@@ -21,5 +21,6 @@ from .first_order_predictor import FirstOrderPredictor
from .face_parse_predictor import FaceParsePredictor
from .animegan_predictor import AnimeGANPredictor
from .midas_predictor import MiDaSPredictor
from .photo2cartoon_predictor import Photo2CartoonPredictor
from .styleganv2_predictor import StyleGANv2Predictor
from .pixel2style2pixel_predictor import Pixel2Style2PixelPredictor
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
from PIL import Image
import numpy as np
import paddle
from paddle.utils.download import get_path_from_url
from ppgan.faceutils.dlibutils import align_crop
from ppgan.faceutils.face_segmentation import FaceSeg
from ppgan.models.generators import ResnetUGATITP2CGenerator
from .base_predictor import BasePredictor
P2C_WEIGHT_URL = "https://paddlegan.bj.bcebos.com/models/photo2cartoon_genA2B_weight.pdparams"
class Photo2CartoonPredictor(BasePredictor):
def __init__(self, output_path='output', weight_path=None):
self.output_path = output_path
if not os.path.exists(self.output_path):
os.makedirs(self.output_path)
if weight_path is None:
cur_path = os.path.abspath(os.path.dirname(__file__))
weight_path = get_path_from_url(P2C_WEIGHT_URL, cur_path)
self.genA2B = ResnetUGATITP2CGenerator()
params = paddle.load(weight_path)
self.genA2B.set_state_dict(params)
self.genA2B.eval()
self.faceseg = FaceSeg()
def run(self, image_path):
image = Image.open(image_path)
face_image = align_crop(image)
face_mask = self.faceseg(face_image)
face_image = cv2.resize(face_image, (256, 256), interpolation=cv2.INTER_AREA)
face_mask = cv2.resize(face_mask, (256, 256))[:, :, np.newaxis] / 255.
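        # composite the segmented face onto a white background and scale pixel values to [-1, 1]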
face = (face_image * face_mask + (1 - face_mask) * 255) / 127.5 - 1
face = np.transpose(face[np.newaxis, :, :, :], (0, 3, 1, 2)).astype(np.float32)
face = paddle.to_tensor(face)
# inference
with paddle.no_grad():
cartoon = self.genA2B(face)[0][0]
# post-process
cartoon = np.transpose(cartoon.numpy(), (1, 2, 0))
cartoon = (cartoon + 1) * 127.5
cartoon = (cartoon * face_mask + (1 - face_mask) * 255).astype(np.uint8)
        photo_save_path = os.path.join(self.output_path, 'p2c_photo.png')
        cv2.imwrite(photo_save_path, cv2.cvtColor(face_image, cv2.COLOR_RGB2BGR))
cartoon_save_path = os.path.join(self.output_path, 'p2c_cartoon.png')
cv2.imwrite(cartoon_save_path, cv2.cvtColor(cartoon, cv2.COLOR_RGB2BGR))
print("Cartoon image has been saved at '{}'.".format(cartoon_save_path))
return cartoon
......@@ -15,3 +15,4 @@
from . import dlibutils as dlib
from . import mask
from . import image
from . import face_segmentation
......@@ -13,3 +13,4 @@
# limitations under the License.
from .dlib_utils import detect, crop, landmarks, crop_from_array
from .face_align import align_crop
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import math
import numpy as np
from PIL import Image
from .dlib_utils import detect, landmarks
def align_crop(image: Image):
faces = detect(image)
    assert len(faces) > 0, 'No face detected!'
face = get_max_face(faces)
lms = landmarks(image, face)
lms = lms[:, ::-1]
image = np.array(image)
image_align, landmarks_align = align(image, lms)
image_crop = crop(image_align, landmarks_align)
return image_crop
def get_max_face(faces):
if len(faces) == 1:
return faces[0]
else:
# find max face
areas = []
for face in faces:
left = face.rect.left()
top = face.rect.top()
right = face.rect.right()
bottom = face.rect.bottom()
areas.append((bottom - top) * (right - left))
max_face_index = np.argmax(areas)
return faces[max_face_index]
def align(image, lms):
# rotation angle
left_eye_corner = lms[36]
right_eye_corner = lms[45]
radian = np.arctan((left_eye_corner[1] - right_eye_corner[1]) / (left_eye_corner[0] - right_eye_corner[0]))
# image size after rotating
height, width, _ = image.shape
cos = math.cos(radian)
sin = math.sin(radian)
new_w = int(width * abs(cos) + height * abs(sin))
new_h = int(width * abs(sin) + height * abs(cos))
# translation
Tx = new_w // 2 - width // 2
Ty = new_h // 2 - height // 2
# affine matrix
M = np.array([[cos, sin, (1 - cos) * width / 2. - sin * height / 2. + Tx],
[-sin, cos, sin * width / 2. + (1 - cos) * height / 2. + Ty]])
image_rotate = cv2.warpAffine(image, M, (new_w, new_h), borderValue=(255, 255, 255))
landmarks = np.concatenate([lms, np.ones((lms.shape[0], 1))], axis=1)
landmarks_rotate = np.dot(M, landmarks.T).T
return image_rotate, landmarks_rotate
def crop(image, lms):
lms_top = np.min(lms[:, 1])
lms_bottom = np.max(lms[:, 1])
lms_left = np.min(lms[:, 0])
lms_right = np.max(lms[:, 0])
# expand bbox
top = int(lms_top - 0.8 * (lms_bottom - lms_top))
bottom = int(lms_bottom + 0.3 * (lms_bottom - lms_top))
left = int(lms_left - 0.3 * (lms_right - lms_left))
right = int(lms_right + 0.3 * (lms_right - lms_left))
if bottom - top > right - left:
left -= ((bottom - top) - (right - left)) // 2
right = left + (bottom - top)
else:
top -= ((right - left) - (bottom - top)) // 2
bottom = top + (right - left)
image_crop = np.ones((bottom - top + 1, right - left + 1, 3), np.uint8) * 255
h, w = image.shape[:2]
left_white = max(0, -left)
left = max(0, left)
right = min(right, w - 1)
right_white = left_white + (right - left)
top_white = max(0, -top)
top = max(0, top)
bottom = min(bottom, h - 1)
bottom_white = top_white + (bottom - top)
image_crop[top_white:bottom_white+1, left_white:right_white+1] = image[top:bottom+1, left:right+1].copy()
return image_crop
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .face_seg import FaceSeg
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp
import cv2
import numpy as np
import paddle
from paddle.utils.download import get_path_from_url
from .fcn import FCN
from .hrnet import HRNet_W18
BISENET_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/faceseg_FCN-HRNetW18.pdparams'
class FaceSeg:
def __init__(self):
save_pth = get_path_from_url(BISENET_WEIGHT_URL, osp.split(osp.realpath(__file__))[0])
self.net = FCN(num_classes=2, backbone=HRNet_W18())
state_dict = paddle.load(save_pth)
self.net.set_state_dict(state_dict)
self.net.eval()
def __call__(self, image):
image_input = self.input_transform(image) # RGB image
with paddle.no_grad():
logits = self.net(image_input)
pred = paddle.argmax(logits[0], axis=1)
pred = pred.numpy()
mask = np.squeeze(pred).astype(np.uint8)
mask = self.output_transform(mask, shape=image.shape[:2])
return mask
def input_transform(self, image):
image_input = cv2.resize(image, (384, 384), interpolation=cv2.INTER_AREA)
image_input = (image_input / 255.)[np.newaxis, :, :, :]
image_input = np.transpose(image_input, (0, 3, 1, 2)).astype(np.float32)
image_input = paddle.to_tensor(image_input)
return image_input
@staticmethod
def output_transform(output, shape):
output = cv2.resize(output, (shape[1], shape[0]))
image_output = np.clip((output * 255), 0, 255).astype(np.uint8)
return image_output
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.nn as nn
import paddle.nn.functional as F
from .layers import ConvBNReLU
class FCN(nn.Layer):
"""
A simple implementation for FCN based on PaddlePaddle.
The original article refers to
Evan Shelhamer, et, al. "Fully Convolutional Networks for Semantic Segmentation"
(https://arxiv.org/abs/1411.4038).
Args:
num_classes (int): The unique number of target classes.
backbone (paddle.nn.Layer): Backbone networks.
backbone_indices (tuple, optional): The values in the tuple indicate the indices of output of backbone.
Default: (-1, ).
channels (int, optional): The channels between conv layer and the last layer of FCNHead.
If None, it will be the number of channels of input features. Default: None.
align_corners (bool): An argument of F.interpolate. It should be set to False when the output size of feature
is even, e.g. 1024x512, otherwise it is True, e.g. 769x769. Default: False.
pretrained (str, optional): The path or url of pretrained model. Default: None
"""
def __init__(self,
num_classes,
backbone,
backbone_indices=(-1, ),
channels=None,
align_corners=False,
pretrained=None):
super(FCN, self).__init__()
self.backbone = backbone
backbone_channels = [
backbone.feat_channels[i] for i in backbone_indices
]
self.head = FCNHead(num_classes, backbone_indices, backbone_channels,
channels)
self.align_corners = align_corners
self.pretrained = pretrained
def forward(self, x):
feat_list = self.backbone(x)
logit_list = self.head(feat_list)
return [
F.interpolate(
logit,
x.shape[2:],
mode='bilinear',
align_corners=self.align_corners) for logit in logit_list
]
class FCNHead(nn.Layer):
"""
A simple implementation for FCNHead based on PaddlePaddle
Args:
num_classes (int): The unique number of target classes.
backbone_indices (tuple, optional): The values in the tuple indicate the indices of output of backbone.
Default: (-1, ).
channels (int, optional): The channels between conv layer and the last layer of FCNHead.
If None, it will be the number of channels of input features. Default: None.
pretrained (str, optional): The path of pretrained model. Default: None
"""
def __init__(self,
num_classes,
backbone_indices=(-1, ),
backbone_channels=(270, ),
channels=None):
super(FCNHead, self).__init__()
self.num_classes = num_classes
self.backbone_indices = backbone_indices
if channels is None:
channels = backbone_channels[0]
self.conv_1 = ConvBNReLU(
in_channels=backbone_channels[0],
out_channels=channels,
kernel_size=1,
padding='same',
stride=1)
self.cls = nn.Conv2D(
in_channels=channels,
out_channels=self.num_classes,
kernel_size=1,
stride=1,
padding=0)
def forward(self, feat_list):
logit_list = []
x = feat_list[self.backbone_indices[0]]
x = self.conv_1(x)
logit = self.cls(x)
logit_list.append(logit)
return logit_list
This diff is collapsed.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
def SyncBatchNorm(*args, **kwargs):
"""In cpu environment nn.SyncBatchNorm does not have kernel so use nn.BatchNorm instead"""
if paddle.get_device() == 'cpu':
return nn.BatchNorm(*args, **kwargs)
else:
return nn.SyncBatchNorm(*args, **kwargs)
class ConvBNReLU(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
padding='same',
**kwargs):
super().__init__()
self._conv = nn.Conv2D(
in_channels, out_channels, kernel_size, padding=padding, **kwargs)
self._batch_norm = SyncBatchNorm(out_channels)
def forward(self, x):
x = self._conv(x)
x = self._batch_norm(x)
x = F.relu(x)
return x
class ConvBN(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
padding='same',
**kwargs):
super().__init__()
self._conv = nn.Conv2D(
in_channels, out_channels, kernel_size, padding=padding, **kwargs)
self._batch_norm = SyncBatchNorm(out_channels)
def forward(self, x):
x = self._conv(x)
x = self._batch_norm(x)
return x
......@@ -21,5 +21,6 @@ from .resnet_ugatit import ResnetUGATITGenerator
from .dcgenerator import DCGenerator
from .generater_animegan import AnimeGenerator, AnimeGeneratorLite
from .wav2lip import Wav2Lip
from .resnet_ugatit_p2c import ResnetUGATITP2CGenerator
from .generator_styleganv2 import StyleGANv2Generator
from .generator_pixel2style2pixel import Pixel2Style2Pixel
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .builder import GENERATORS
@GENERATORS.register()
class ResnetUGATITP2CGenerator(nn.Layer):
def __init__(self,
input_nc=3,
output_nc=3,
ngf=32,
img_size=256,
n_blocks=4,
light=True):
super(ResnetUGATITP2CGenerator, self).__init__()
self.input_nc = input_nc
self.output_nc = output_nc
self.ngf = ngf
self.n_blocks = n_blocks
self.img_size = img_size
self.light = light
DownBlock = []
DownBlock += [
nn.Pad2D([3, 3, 3, 3], 'reflect'),
nn.Conv2D(input_nc, ngf, kernel_size=7, stride=1, bias_attr=False),
nn.InstanceNorm2D(ngf, weight_attr=False, bias_attr=False),
nn.ReLU()
]
DownBlock += [
HourGlass(ngf, ngf),
HourGlass(ngf, ngf)
]
# Down-Sampling
n_downsampling = 2
for i in range(n_downsampling):
mult = 2 ** i
DownBlock += [
nn.Pad2D([1, 1, 1, 1], 'reflect'),
nn.Conv2D(ngf*mult, ngf*mult*2, kernel_size=3, stride=2, bias_attr=False),
nn.InstanceNorm2D(ngf*mult*2, weight_attr=False, bias_attr=False),
nn.ReLU()
]
# Encoder Bottleneck
mult = 2 ** n_downsampling
for i in range(n_blocks):
setattr(self, 'EncodeBlock'+str(i+1), ResnetBlock(ngf*mult))
# Class Activation Map
self.gap_fc = nn.Linear(ngf*mult, 1, bias_attr=False)
self.gmp_fc = nn.Linear(ngf*mult, 1, bias_attr=False)
self.conv1x1 = nn.Conv2D(ngf*mult*2, ngf*mult, kernel_size=1, stride=1)
self.relu = nn.ReLU()
# Gamma, Beta block
FC = []
if self.light:
FC += [
nn.Linear(ngf*mult, ngf*mult, bias_attr=False),
nn.ReLU(),
nn.Linear(ngf*mult, ngf*mult, bias_attr=False),
nn.ReLU()
]
else:
FC += [
nn.Linear(img_size//mult*img_size//mult*ngf*mult, ngf*mult, bias_attr=False),
nn.ReLU(),
nn.Linear(ngf*mult, ngf*mult, bias_attr=False),
nn.ReLU()
]
# Decoder Bottleneck
mult = 2 ** n_downsampling
for i in range(n_blocks):
setattr(self, 'DecodeBlock'+str(i + 1), ResnetSoftAdaLINBlock(ngf*mult))
# Up-Sampling
UpBlock = []
for i in range(n_downsampling):
mult = 2 ** (n_downsampling - i)
UpBlock += [
nn.Upsample(scale_factor=2),
nn.Pad2D([1, 1, 1, 1], 'reflect'),
nn.Conv2D(ngf*mult, ngf*mult//2, kernel_size=3, stride=1, bias_attr=False),
LIN(ngf*mult//2),
nn.ReLU()
]
UpBlock += [
HourGlass(ngf, ngf),
HourGlass(ngf, ngf, False)
]
UpBlock += [
nn.Pad2D([3, 3, 3, 3], 'reflect'),
nn.Conv2D(3, output_nc, kernel_size=7, stride=1, bias_attr=False),
nn.Tanh()
]
self.DownBlock = nn.Sequential(*DownBlock)
self.FC = nn.Sequential(*FC)
self.UpBlock = nn.Sequential(*UpBlock)
def forward(self, x):
bs = x.shape[0]
x = self.DownBlock(x)
content_features = []
for i in range(self.n_blocks):
x = getattr(self, 'EncodeBlock'+str(i+1))(x)
content_features.append(F.adaptive_avg_pool2d(x, 1).reshape([bs, -1]))
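        # Class Activation Map: global average / max pooling branches produce CAM logits and channel-wise attention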
gap = F.adaptive_avg_pool2d(x, 1)
gap_logit = self.gap_fc(gap.reshape([bs, -1]))
gap_weight = list(self.gap_fc.parameters())[0].transpose([1, 0])
gap = x * gap_weight.unsqueeze(2).unsqueeze(3)
gmp = F.adaptive_max_pool2d(x, 1)
gmp_logit = self.gmp_fc(gmp.reshape([bs, -1]))
gmp_weight = list(self.gmp_fc.parameters())[0].transpose([1, 0])
gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)
cam_logit = paddle.concat([gap_logit, gmp_logit], 1)
x = paddle.concat([gap, gmp], 1)
x = self.relu(self.conv1x1(x))
heatmap = paddle.sum(x, axis=1, keepdim=True)
if self.light:
x_ = F.adaptive_avg_pool2d(x, 1)
style_features = self.FC(x_.reshape([bs, -1]))
else:
style_features = self.FC(x.reshape([bs, -1]))
for i in range(self.n_blocks):
x = getattr(self, 'DecodeBlock'+str(i+1))(x, content_features[4-i-1], style_features)
out = self.UpBlock(x)
return out, cam_logit, heatmap
class ConvBlock(nn.Layer):
def __init__(self, dim_in, dim_out):
super(ConvBlock, self).__init__()
self.dim_in = dim_in
self.dim_out = dim_out
self.conv_block1 = self.__convblock(dim_in, dim_out//2)
self.conv_block2 = self.__convblock(dim_out//2, dim_out//4)
self.conv_block3 = self.__convblock(dim_out//4, dim_out//4)
if self.dim_in != self.dim_out:
self.conv_skip = nn.Sequential(
nn.InstanceNorm2D(dim_in, weight_attr=False, bias_attr=False),
nn.ReLU(),
nn.Conv2D(dim_in, dim_out, kernel_size=1, stride=1, bias_attr=False)
)
@staticmethod
def __convblock(dim_in, dim_out):
return nn.Sequential(
nn.InstanceNorm2D(dim_in, weight_attr=False, bias_attr=False),
nn.ReLU(),
nn.Pad2D([1, 1, 1, 1], 'reflect'),
nn.Conv2D(dim_in, dim_out, kernel_size=3, stride=1, bias_attr=False)
)
def forward(self, x):
residual = x
x1 = self.conv_block1(x)
x2 = self.conv_block2(x1)
x3 = self.conv_block3(x2)
out = paddle.concat([x1, x2, x3], 1)
if self.dim_in != self.dim_out:
residual = self.conv_skip(residual)
return residual + out
class HourGlassBlock(nn.Layer):
def __init__(self, dim_in):
super(HourGlassBlock, self).__init__()
self.n_skip = 4
self.n_block = 9
for i in range(self.n_skip):
setattr(self, 'ConvBlockskip'+str(i+1), ConvBlock(dim_in, dim_in))
for i in range(self.n_block):
setattr(self, 'ConvBlock'+str(i+1), ConvBlock(dim_in, dim_in))
def forward(self, x):
skips = []
for i in range(self.n_skip):
skips.append(getattr(self, 'ConvBlockskip'+str(i+1))(x))
x = F.avg_pool2d(x, 2)
x = getattr(self, 'ConvBlock'+str(i+1))(x)
x = self.ConvBlock5(x)
for i in range(self.n_skip):
x = getattr(self, 'ConvBlock'+str(i+6))(x)
x = F.upsample(x, scale_factor=2)
x = skips[self.n_skip-i-1] + x
return x
class HourGlass(nn.Layer):
def __init__(self, dim_in, dim_out, use_res=True):
super(HourGlass, self).__init__()
self.use_res = use_res
self.HG = nn.Sequential(
HourGlassBlock(dim_in),
ConvBlock(dim_out, dim_out),
nn.Conv2D(dim_out, dim_out, kernel_size=1, stride=1, bias_attr=False),
nn.InstanceNorm2D(dim_out, weight_attr=False, bias_attr=False),
nn.ReLU()
)
self.Conv1 = nn.Conv2D(dim_out, 3, kernel_size=1, stride=1)
if self.use_res:
self.Conv2 = nn.Conv2D(dim_out, dim_out, kernel_size=1, stride=1)
self.Conv3 = nn.Conv2D(3, dim_out, kernel_size=1, stride=1)
def forward(self, x):
ll = self.HG(x)
tmp_out = self.Conv1(ll)
if self.use_res:
ll = self.Conv2(ll)
tmp_out_ = self.Conv3(tmp_out)
return x + ll + tmp_out_
else:
return tmp_out
class ResnetBlock(nn.Layer):
def __init__(self, dim, use_bias=False):
super(ResnetBlock, self).__init__()
conv_block = []
conv_block += [
nn.Pad2D([1, 1, 1, 1], 'reflect'),
nn.Conv2D(dim, dim, kernel_size=3, stride=1, bias_attr=use_bias),
nn.InstanceNorm2D(dim, weight_attr=False, bias_attr=False),
nn.ReLU()
]
conv_block += [
nn.Pad2D([1, 1, 1, 1], 'reflect'),
nn.Conv2D(dim, dim, kernel_size=3, stride=1, bias_attr=use_bias),
nn.InstanceNorm2D(dim, weight_attr=False, bias_attr=False)
]
self.conv_block = nn.Sequential(*conv_block)
def forward(self, x):
out = x + self.conv_block(x)
return out
class ResnetSoftAdaLINBlock(nn.Layer):
def __init__(self, dim, use_bias=False):
super(ResnetSoftAdaLINBlock, self).__init__()
self.pad1 = nn.Pad2D([1, 1, 1, 1], 'reflect')
self.conv1 = nn.Conv2D(dim, dim, kernel_size=3, stride=1, bias_attr=use_bias)
self.norm1 = SoftAdaLIN(dim)
self.relu1 = nn.ReLU()
self.pad2 = nn.Pad2D([1, 1, 1, 1], 'reflect')
self.conv2 = nn.Conv2D(dim, dim, kernel_size=3, stride=1, bias_attr=use_bias)
self.norm2 = SoftAdaLIN(dim)
def forward(self, x, content_features, style_features):
out = self.pad1(x)
out = self.conv1(out)
out = self.norm1(out, content_features, style_features)
out = self.relu1(out)
out = self.pad2(out)
out = self.conv2(out)
out = self.norm2(out, content_features, style_features)
return out + x
class SoftAdaLIN(nn.Layer):
def __init__(self, num_features, eps=1e-5):
super(SoftAdaLIN, self).__init__()
self.norm = AdaLIN(num_features, eps)
self.w_gamma = self.create_parameter([1, num_features], default_initializer=nn.initializer.Constant(0.))
self.w_beta = self.create_parameter([1, num_features], default_initializer=nn.initializer.Constant(0.))
self.c_gamma = nn.Sequential(nn.Linear(num_features, num_features, bias_attr=False),
nn.ReLU(),
nn.Linear(num_features, num_features, bias_attr=False))
self.c_beta = nn.Sequential(nn.Linear(num_features, num_features, bias_attr=False),
nn.ReLU(),
nn.Linear(num_features, num_features, bias_attr=False))
self.s_gamma = nn.Linear(num_features, num_features, bias_attr=False)
self.s_beta = nn.Linear(num_features, num_features, bias_attr=False)
def forward(self, x, content_features, style_features):
content_gamma, content_beta = self.c_gamma(content_features), self.c_beta(content_features)
style_gamma, style_beta = self.s_gamma(style_features), self.s_beta(style_features)
w_gamma_, w_beta_ = self.w_gamma.expand([x.shape[0], -1]), self.w_beta.expand([x.shape[0], -1])
soft_gamma = (1. - w_gamma_) * style_gamma + w_gamma_ * content_gamma
soft_beta = (1. - w_beta_) * style_beta + w_beta_ * content_beta
out = self.norm(x, soft_gamma, soft_beta)
return out
class AdaLIN(nn.Layer):
def __init__(self, num_features, eps=1e-5):
super(AdaLIN, self).__init__()
self.eps = eps
self.rho = self.create_parameter([1, num_features, 1, 1], default_initializer=nn.initializer.Constant(0.9))
def forward(self, x, gamma, beta):
in_mean, in_var = paddle.mean(x, axis=[2, 3], keepdim=True), paddle.var(x, axis=[2, 3], keepdim=True)
out_in = (x - in_mean) / paddle.sqrt(in_var + self.eps)
ln_mean, ln_var = paddle.mean(x, axis=[1, 2, 3], keepdim=True), paddle.var(x, axis=[1, 2, 3], keepdim=True)
out_ln = (x - ln_mean) / paddle.sqrt(ln_var + self.eps)
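        # adaptively mix the Instance Norm and Layer Norm outputs using the learnable ratio rho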
out = self.rho.expand([x.shape[0], -1, -1, -1]) * out_in + \
(1-self.rho.expand([x.shape[0], -1, -1, -1])) * out_ln
out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)
return out
class LIN(nn.Layer):
def __init__(self, num_features, eps=1e-5):
super(LIN, self).__init__()
self.eps = eps
self.rho = self.create_parameter([1, num_features, 1, 1], default_initializer=nn.initializer.Constant(0.))
self.gamma = self.create_parameter([1, num_features, 1, 1], default_initializer=nn.initializer.Constant(1.))
self.beta = self.create_parameter([1, num_features, 1, 1], default_initializer=nn.initializer.Constant(0.))
def forward(self, x):
in_mean, in_var = paddle.mean(x, axis=[2, 3], keepdim=True), paddle.var(x, axis=[2, 3], keepdim=True)
out_in = (x - in_mean) / paddle.sqrt(in_var + self.eps)
ln_mean, ln_var = paddle.mean(x, axis=[1, 2, 3], keepdim=True), paddle.var(x, axis=[1, 2, 3], keepdim=True)
out_ln = (x - ln_mean) / paddle.sqrt(ln_var + self.eps)
out = self.rho.expand([x.shape[0], -1, -1, -1]) * out_in + \
(1-self.rho.expand([x.shape[0], -1, -1, -1])) * out_ln
out = out * self.gamma.expand([x.shape[0], -1, -1, -1]) + self.beta.expand([x.shape[0], -1, -1, -1])
return out
......@@ -79,3 +79,14 @@ class RhoClipper(object):
w = module.rho
w = w.clip(self.clip_min, self.clip_max)
module.rho.set_value(w)
# used for photo2cartoon training
if hasattr(module, 'w_gamma'):
w = module.w_gamma
w = w.clip(self.clip_min, self.clip_max)
module.w_gamma.set_value(w)
if hasattr(module, 'w_beta'):
w = module.w_beta
w = w.clip(self.clip_min, self.clip_max)
module.w_beta.set_value(w)