Unverified commit 10de3c64, authored by 郑启航, committed by GitHub

Add AnimeGANv2 model (#102)

* add animeganv2 network and dataset
* animegan:refine code,add License
Co-authored-by: qingqing01 <dangqingqing@baidu.com>
Parent 8ff47a50
import paddle
import os
import sys
sys.path.insert(0, os.getcwd())
from ppgan.apps import AnimeGANPredictor
import argparse
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_image", type=str, help="path to source image")
    parser.add_argument("--output_path",
                        type=str,
                        default='output_dir',
                        help="path to output image dir")
    parser.add_argument("--weight_path",
                        type=str,
                        default=None,
                        help="path to model checkpoint")
    parser.add_argument("--use_adjust_brightness",
                        action="store_false",
                        help="disable brightness adjustment when this flag is set.")
    parser.add_argument("--cpu",
                        dest="cpu",
                        action="store_true",
                        help="cpu mode.")
    args = parser.parse_args()

    if args.cpu:
        paddle.set_device('cpu')

    predictor = AnimeGANPredictor(args.output_path, args.weight_path,
                                  args.use_adjust_brightness)
    predictor.run(args.input_image)
epochs: 30
output_dir: output_dir
pretrain_ckpt: output_dir/AnimeGANV2PreTrainModel-2020-11-29-17-02/epoch_2_checkpoint.pdparams
g_adv_weight: 300.
d_adv_weight: 300.
con_weight: 1.5
sty_weight: 2.5
color_weight: 10.
tv_weight: 1.

model:
  name: AnimeGANV2Model
  generator:
    name: AnimeGenerator
  discriminator:
    name: AnimeDiscriminator
  gan_mode: lsgan

dataset:
  train:
    name: AnimeGANV2Dataset
    num_workers: 4
    batch_size: 4
    dataroot: data/animedataset
    style: Hayao
    phase: train
    direction: AtoB
    transform_real:
      - name: Transpose
      - name: Normalize
        mean: [127.5, 127.5, 127.5]
        std: [127.5, 127.5, 127.5]
    transform_anime:
      - name: Add
        value: [-4.4346957, -8.665916, 13.100612]
      - name: Transpose
      - name: Normalize
        mean: [127.5, 127.5, 127.5]
        std: [127.5, 127.5, 127.5]
    transform_gray:
      - name: Grayscale
        num_output_channels: 3
      - name: Transpose
      - name: Normalize
        mean: [127.5, 127.5, 127.5]
        std: [127.5, 127.5, 127.5]
  test:
    name: SingleDataset
    dataroot: data/animedataset/test/HR_photo
    max_dataset_size: inf
    direction: BtoA
    input_nc: 3
    output_nc: 3
    serial_batches: False
    pool_size: 50
    transforms:
      - name: ResizeToScale
        size: [256, 256]
        scale: 32
        interpolation: bilinear
      - name: Transpose
      - name: Normalize
        mean: [127.5, 127.5, 127.5]
        std: [127.5, 127.5, 127.5]

optimizer:
  name: Adam
  beta1: 0.5

lr_scheduler:
  name: linear
  learning_rate: 0.00002
  start_epoch: 100
  decay_epochs: 100

log_config:
  interval: 100
  visiual_interval: 100

snapshot_config:
  interval: 5
epochs: 2
output_dir: output_dir
con_weight: 1
pretrain_ckpt: null

model:
  name: AnimeGANV2PreTrainModel
  generator:
    name: AnimeGenerator
  discriminator:
    name: AnimeDiscriminator
  gan_mode: lsgan

dataset:
  train:
    name: AnimeGANV2Dataset
    num_workers: 4
    batch_size: 4
    dataroot: data/animedataset
    style: Hayao
    phase: train
    direction: AtoB
    transform_real:
      - name: Transpose
      - name: Normalize
        mean: [127.5, 127.5, 127.5]
        std: [127.5, 127.5, 127.5]
    transform_anime:
      - name: Add
        value: [-4.4346957, -8.665916, 13.100612]
      - name: Transpose
      - name: Normalize
        mean: [127.5, 127.5, 127.5]
        std: [127.5, 127.5, 127.5]
    transform_gray:
      - name: Grayscale
        num_output_channels: 3
      - name: Transpose
      - name: Normalize
        mean: [127.5, 127.5, 127.5]
        std: [127.5, 127.5, 127.5]
  test:
    name: SingleDataset
    dataroot: data/animedataset/test/test_photo
    max_dataset_size: inf
    direction: BtoA
    input_nc: 3
    output_nc: 3
    serial_batches: False
    pool_size: 50
    transforms:
      - name: Resize
        size: [256, 256]
        interpolation: "bicubic" #cv2.INTER_CUBIC
      - name: Transpose
      - name: Normalize
        mean: [127.5, 127.5, 127.5]
        std: [127.5, 127.5, 127.5]

optimizer:
  name: Adam
  beta1: 0.5

lr_scheduler:
  name: linear
  learning_rate: 0.0002
  start_epoch: 100
  decay_epochs: 100

log_config:
  interval: 100
  visiual_interval: 100

snapshot_config:
  interval: 5
# 1 AnimeGANv2
## 1.1 Introduction
[AnimeGAN](https://github.com/TachibanaYoshino/AnimeGANv2) improves on the [CVPR 2018 paper CartoonGAN](https://openaccess.thecvf.com/content_cvpr_2018/papers/Chen_CartoonGAN_Generative_Adversarial_CVPR_2018_paper.pdf), mainly to address over-stylization and color-artifact regions. For details, refer to the [Zhihu article](https://zhuanlan.zhihu.com/p/76574388?from_voters_page=true) written by the paper's author. Building on AnimeGAN, AnimeGANv2 adds a `total variation loss` to the generator loss.
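The `total variation loss` penalizes differences between neighboring pixels of the generated image and so suppresses high-frequency artifacts. A minimal sketch of it, mirroring the `variation_loss` method of `AnimeGANV2Model` added in this commit (the random tensor below is only a stand-in for a generated NCHW image batch):
```python
import paddle

def variation_loss(image, ksize=1):
    # differences between vertically / horizontally neighboring pixels (NCHW layout)
    dh = image[:, :, :-ksize, :] - image[:, :, ksize:, :]
    dw = image[:, :, :, :-ksize] - image[:, :, :, ksize:]
    return paddle.mean(paddle.abs(dh)) + paddle.mean(paddle.abs(dw))

fake = paddle.rand([1, 3, 256, 256])  # stand-in for a generated image batch
print(variation_loss(fake))           # scalar term added to the generator loss with weight tv_weight
```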
## 1.2 How to use
### 1.2.1 Quick start
After installing PaddleGAN, you can run the following Python code to generate a stylized image, where `PATH_OF_IMAGE` is the path of your source image.
```python
from ppgan.apps import AnimeGANPredictor
predictor = AnimeGANPredictor()
predictor.run(PATH_OF_IMAGE)
```
Or run the following command to get the same result:
```sh
python applications/tools/animeganv2.py --input_image ${PATH_OF_IMAGE}
```
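The predictor behind this command exposes the same options as the CLI flags. A minimal sketch with all arguments spelled out (defaults taken from this commit; the image path is a placeholder):
```python
from ppgan.apps import AnimeGANPredictor

# weight_path=None downloads the default Hayao-style weights;
# use_adjust_brightness matches the output brightness to the source photo.
predictor = AnimeGANPredictor(output_path='output_dir',
                              weight_path=None,
                              use_adjust_brightness=True)
predictor.run('PATH_OF_IMAGE')  # the stylized result is written to output_dir/anime.png
```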
### 1.2.2 Prepare the dataset
Download the dataset provided by the author from [here](https://github.com/TachibanaYoshino/AnimeGAN/releases/tag/dataset-1), then unzip it into the `data` directory:
```sh
wget https://github.com/TachibanaYoshino/AnimeGAN/releases/download/dataset-1/dataset.zip
cd PaddleGAN
unzip YOUR_DATASET_DIR/dataset.zip -d data/animedataset
```
For example, the structure of `animedataset` is as follows:
```sh
animedataset
├── Hayao
│ ├── smooth
│ └── style
├── Paprika
│ ├── smooth
│ └── style
├── Shinkai
│ ├── smooth
│ └── style
├── SummerWar
│ ├── smooth
│ └── style
├── test
│ ├── HR_photo
│ ├── label_map
│ ├── real
│ ├── test_photo
│ └── test_photo256
├── train_photo
└── val
```
### 1.2.3 Training
The following example trains a model for the Hayao style.
1. To ensure the generator can reproduce the original image, warm up the model first:
```sh
python tools/main.py --config-file configs/animeganv2_pretrain.yaml
```
2. After the warm-up, start training the GAN:
**NOTE:** you must set the `pretrain_ckpt` parameter in `configs/animeganv2.yaml` first, so that the GAN reuses the warmed-up generator weights (a sketch of how they are loaded follows the command below).
Set `batch size=4` and `learning rate=0.00002`. Training for 30 epochs on a GTX2060S GPU reproduces the result. For other hyperparameters, please refer to `configs/animeganv2.yaml`.
```sh
python tools/main.py --config-file configs/animeganv2.yaml
```
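For reference, a minimal sketch of what setting `pretrain_ckpt` does, following the `AnimeGANV2Model` code in this commit; the checkpoint path is the example value from `configs/animeganv2.yaml` and assumes the warm-up stage has already produced it:
```python
from ppgan.models.generators import AnimeGenerator
from ppgan.utils.filesystem import load

# example path from configs/animeganv2.yaml, produced by the warm-up stage
ckpt = 'output_dir/AnimeGANV2PreTrainModel-2020-11-29-17-02/epoch_2_checkpoint.pdparams'
netG = AnimeGenerator()
state_dicts = load(ckpt)
netG.set_state_dict(state_dicts['netG'])  # the GAN stage starts from the warmed-up generator
```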
3. Change the target style
Modify the `style` parameter in `configs/animeganv2.yaml`; currently `Hayao`, `Paprika`, `Shinkai` and `SummerWar` are supported. If you want to use your own dataset, point the configuration file to it instead.
**NOTE:** After changing the target style, first compute the mean value of the target style dataset, then update the `transform_anime->Add->value` parameter in `configs/animeganv2.yaml` accordingly.
The following example shows how to obtain the mean value of the `Hayao` style (a sketch of how this value is used follows the output below):
```sh
python tools/animegan_picmean.py --dataset data/animedataset/Hayao/style
image_num: 1792
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1792/1792 [00:04<00:00, 444.95it/s]
RGB mean diff
[-4.4346957 -8.665916 13.100612 ]
```
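The printed values are the per-channel offsets that the `transform_anime` pipeline adds to every style image before normalization. A minimal sketch of that preprocessing step, using the `Add` transform registered in this commit (the image path is a placeholder):
```python
import cv2
import paddle.vision.transforms as T
from ppgan.datasets.transforms import Add

# value is the RGB mean diff computed above for the Hayao style
transform = T.Compose([
    Add(value=[-4.4346957, -8.665916, 13.100612]),
    T.Transpose(),
    T.Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5]),
])
img = cv2.cvtColor(cv2.imread('style_image.jpg'), cv2.COLOR_BGR2RGB)
out = transform(img)  # HWC uint8 -> CHW float in [-1, 1], shifted by the style mean diff
```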
### 1.2.4 Test
Test the model on `data/animedataset/test/HR_photo`:
```sh
python tools/main.py --config-file configs/animeganv2.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
```
## 1.3 Results
| Original image | Stylized image |
| ----------------------------------- | ---------------------------------- |
| ![](../../imgs/animeganv2_test.jpg) | ![](../../imgs/animeganv2_res.jpg) |
......@@ -351,3 +351,24 @@ ppgan.apps.FaceParsePredictor(output_path='output')
> ```
> **Returns:**
> > - mask(numpy.ndarray): the parsed face-component mask matrix, of type numpy.ndarray
## ppgan.apps.AnimeGANPredictor
```python
ppgan.apps.AnimeGANPredictor(output_path='output_dir',weight_path=None,use_adjust_brightness=True)
```
> Uses AnimeGANv2 to turn a scenery photo into an anime-style image. Paper: AnimeGAN: A Novel Lightweight GAN for Photo Animation, https://link.springer.com/chapter/10.1007/978-981-15-5577-0_18.
> **Parameters:**
>
> > - input_image: path of the input image to be processed
> **Example:**
>
> ```
> from ppgan.apps import AnimeGANPredictor
> predictor = AnimeGANPredictor()
> predictor.run('docs/imgs/animeganv2_test.jpg')
> ```
> **Returns:**
> > - anime_image(numpy.ndarray): the stylized scenery image
# 1 AnimeGANv2
## 1.1 Introduction
[AnimeGAN](https://github.com/TachibanaYoshino/AnimeGANv2) improves on the 2018 [CVPR paper CartoonGAN](https://openaccess.thecvf.com/content_cvpr_2018/papers/Chen_CartoonGAN_Generative_Adversarial_CVPR_2018_paper.pdf), mainly eliminating over-stylization and color-artifact regions. For details, see the author's [Zhihu article](https://zhuanlan.zhihu.com/p/76574388?from_voters_page=true). AnimeGANv2 is the author's newer model that adds a `total variation loss` on top of AnimeGAN.
## 1.2 How to use
### 1.2.1 Quick start
After installing `PaddleGAN`, run the following code to generate the stylized image `output_dir/anime.png`, where `PATH_OF_IMAGE` is the path of the image you want to convert.
```python
from ppgan.apps import AnimeGANPredictor
predictor = AnimeGANPredictor()
predictor.run(PATH_OF_IMAGE)
```
Or run the following command in a terminal to get the same result:
```sh
python applications/tools/animeganv2.py --input_image ${PATH_OF_IMAGE}
```
### 1.2.2 Prepare the dataset
Download the training data provided by the author from [here](https://github.com/TachibanaYoshino/AnimeGAN/releases/tag/dataset-1).
After downloading, unzip it into the `data` directory:
```sh
wget https://github.com/TachibanaYoshino/AnimeGAN/releases/download/dataset-1/dataset.zip
cd PaddleGAN
unzip YOUR_DATASET_DIR/dataset.zip -d data/animedataset
```
After unzipping, the data is organized as follows:
```sh
animedataset
├── Hayao
│ ├── smooth
│ └── style
├── Paprika
│ ├── smooth
│ └── style
├── Shinkai
│ ├── smooth
│ └── style
├── SummerWar
│ ├── smooth
│ └── style
├── test
│ ├── HR_photo
│ ├── label_map
│ ├── real
│ ├── test_photo
│ └── test_photo256
├── train_photo
└── val
```
### 1.2.3 Training
The following example trains a model for the Hayao style.
1. To ensure the generator can reproduce the original image, warm up the model first:
```sh
python tools/main.py --config-file configs/animeganv2_pretrain.yaml
```
2. After the warm-up, train the style-transfer model:
**NOTE:** you must first modify the `pretrain_ckpt` parameter in `configs/animeganv2.yaml` so that it points to the correct **warm-up model weights**.
Set `batch size=4` and `learning rate=0.00002`. Training for 30 epochs on a GTX2060S GPU gives good results; for other hyperparameters, please refer to `configs/animeganv2.yaml`.
```sh
python tools/main.py --config-file configs/animeganv2.yaml
```
3. Change the target style
Modify the `style` parameter in `configs/animeganv2.yaml` to change the style (currently `Hayao`, `Paprika`, `Shinkai` and `SummerWar` can be chosen). If you want to use your own dataset, point the configuration file to it instead.
**NOTE:** After changing the target style, you must compute the pixel mean of the target style dataset and update the `transform_anime->Add->value` parameter in `configs/animeganv2.yaml`.
The following example shows how to obtain the pixel mean of `Hayao`-style images:
```sh
python tools/animegan_picmean.py --dataset data/animedataset/Hayao/style
image_num: 1792
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1792/1792 [00:04<00:00, 444.95it/s]
RGB mean diff
[-4.4346957 -8.665916 13.100612 ]
```
### 1.2.4 Test
Test the model:
```sh
python tools/main.py --config-file configs/animeganv2.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
```
## 1.3 Results
| Original image | Stylized image |
| ----------------------------------- | ---------------------------------- |
| ![](../../imgs/animeganv2_test.jpg) | ![](../../imgs/animeganv2_res.jpg) |
......@@ -19,3 +19,4 @@ from .realsr_predictor import RealSRPredictor
from .edvr_predictor import EDVRPredictor
from .first_order_predictor import FirstOrderPredictor
from .face_parse_predictor import FaceParsePredictor
from .animegan_predictor import AnimeGANPredictor
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import numpy as np
import cv2
import paddle
from .base_predictor import BasePredictor
from ppgan.datasets.transforms import ResizeToScale
import paddle.vision.transforms as T
from ppgan.models.generators import AnimeGenerator
from ppgan.utils.download import get_path_from_url
class AnimeGANPredictor(BasePredictor):
    def __init__(self,
                 output_path='output_dir',
                 weight_path=None,
                 use_adjust_brightness=True):
        self.output_path = output_path
        self.input_size = (256, 256)
        self.use_adjust_brightness = use_adjust_brightness
        if weight_path is None:
            vox_cpk_weight_url = 'https://paddlegan.bj.bcebos.com/models/animeganv2_hayao.pdparams'
            weight_path = get_path_from_url(vox_cpk_weight_url)
        self.weight_path = weight_path
        self.generator = self.load_checkpoints()
        self.transform = T.Compose([
            ResizeToScale((256, 256), 32),
            T.Transpose(),
            T.Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5])
        ])

    def load_checkpoints(self):
        generator = AnimeGenerator()
        checkpoint = paddle.load(self.weight_path)
        generator.set_state_dict(checkpoint['netG'])
        generator.eval()
        return generator

    @staticmethod
    def calc_avg_brightness(img):
        R = img[..., 0].mean()
        G = img[..., 1].mean()
        B = img[..., 2].mean()

        brightness = 0.299 * R + 0.587 * G + 0.114 * B
        return brightness, B, G, R

    @staticmethod
    def adjust_brightness(dst, src):
        brightness1, B1, G1, R1 = AnimeGANPredictor.calc_avg_brightness(src)
        brightness2, B2, G2, R2 = AnimeGANPredictor.calc_avg_brightness(dst)
        brightness_difference = brightness1 / brightness2
        dstf = dst * brightness_difference
        dstf = np.clip(dstf, 0, 255)
        dstf = np.uint8(dstf)
        return dstf

    def run(self, image):
        image = cv2.cvtColor(cv2.imread(image, flags=cv2.IMREAD_COLOR),
                             cv2.COLOR_BGR2RGB)
        transformed_image = self.transform(image)
        anime = (self.generator(paddle.to_tensor(transformed_image[None, ...]))
                 * 0.5 + 0.5)[0].numpy() * 255
        anime = anime.transpose((1, 2, 0))
        if anime.shape[:2] != image.shape[:2]:
            # to original size
            anime = T.resize(anime, image.shape[:2])

        if self.use_adjust_brightness:
            anime = self.adjust_brightness(anime, image)
        else:
            anime = anime.astype('uint8')
        if not os.path.exists(self.output_path):
            os.makedirs(self.output_path)
        save_path = os.path.join(self.output_path, 'anime.png')
        cv2.imwrite(save_path, cv2.cvtColor(anime, cv2.COLOR_RGB2BGR))
        return image
......@@ -17,3 +17,4 @@ from .single_dataset import SingleDataset
from .paired_dataset import PairedDataset
from .sr_image_dataset import SRImageDataset
from .makeup_dataset import MakeupDataset
from .animeganv2_dataset import AnimeGANV2Dataset
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import cv2
import numpy as np
import os.path
from .base_dataset import BaseDataset
from .image_folder import ImageFolder
from .builder import DATASETS
from .transforms.builder import build_transforms
@DATASETS.register()
class AnimeGANV2Dataset(BaseDataset):
    """Dataset for AnimeGANv2: pairs real photos with anime-style images,
    their grayscale versions and smoothed grayscale versions.
    """
    def __init__(self, cfg):
        """Initialize this dataset class.

        Args:
            cfg (dict) -- stores all the experiment flags
        """
        BaseDataset.__init__(self, cfg)
        self.style = cfg.style

        self.transform_real = build_transforms(self.cfg.transform_real)
        self.transform_anime = build_transforms(self.cfg.transform_anime)
        self.transform_gray = build_transforms(self.cfg.transform_gray)

        self.real_root = os.path.join(self.root, 'train_photo')
        self.anime_root = os.path.join(self.root, f'{self.style}', 'style')
        self.smooth_root = os.path.join(self.root, f'{self.style}', 'smooth')

        self.real = ImageFolder(self.real_root,
                                transform=self.transform_real,
                                loader=self.loader)
        self.anime = ImageFolder(self.anime_root,
                                 transform=self.transform_anime,
                                 loader=self.loader)
        self.anime_gray = ImageFolder(self.anime_root,
                                      transform=self.transform_gray,
                                      loader=self.loader)
        self.smooth_gray = ImageFolder(self.smooth_root,
                                       transform=self.transform_gray,
                                       loader=self.loader)
        self.sizes = [
            len(fold) for fold in [self.real, self.anime, self.smooth_gray]
        ]
        self.size = max(self.sizes)
        self.reshuffle()

    @staticmethod
    def loader(path):
        return cv2.cvtColor(cv2.imread(path, flags=cv2.IMREAD_COLOR),
                            cv2.COLOR_BGR2RGB)

    def reshuffle(self):
        indexs = []
        for cur_size in self.sizes:
            x = np.arange(0, cur_size)
            np.random.shuffle(x)
            if cur_size != self.size:
                pad_num = self.size - cur_size
                pad = np.random.choice(cur_size, pad_num, replace=True)
                x = np.concatenate((x, pad))
                np.random.shuffle(x)
            indexs.append(x.tolist())

        self.indexs = list(zip(*indexs))

    def __getitem__(self, index):
        try:
            index = self.indexs.pop()
        except IndexError as e:
            self.reshuffle()
            index = self.indexs.pop()

        real_idx, anime_idx, smooth_idx = index

        return {
            'real': self.real[real_idx],
            'anime': self.anime[anime_idx],
            'anime_gray': self.anime_gray[anime_idx],
            'smooth_gray': self.smooth_gray[smooth_idx]
        }

    def __len__(self):
        return self.size
......@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .transforms import PairedRandomCrop, PairedRandomHorizontalFlip
from .transforms import PairedRandomCrop, PairedRandomHorizontalFlip, Add, ResizeToScale
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import division
from . import functional_cv2 as F_cv2
from paddle.vision.transforms.functional import _is_numpy_image, _is_pil_image
__all__ = ['add']
def add(pic, value):
    if not (_is_pil_image(pic) or _is_numpy_image(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(
            type(pic)))

    if _is_pil_image(pic):
        raise NotImplementedError('add not support pil image')
    else:
        return F_cv2.add(pic, value)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import division
import numpy as np
def add(image, value):
    return np.clip(image + value, 0, 255).astype('uint8')
......@@ -19,6 +19,7 @@ import collections
import numpy as np
import paddle.vision.transforms as T
import ppgan.datasets.transforms.functional as custom_F
import paddle.vision.transforms.functional as F
from .builder import TRANSFORMS
......@@ -35,6 +36,7 @@ TRANSFORMS.register(T.RandomCrop)
TRANSFORMS.register(T.RandomHorizontalFlip)
TRANSFORMS.register(T.Normalize)
TRANSFORMS.register(T.Transpose)
TRANSFORMS.register(T.Grayscale)
@TRANSFORMS.register()
......@@ -72,3 +74,73 @@ class PairedRandomHorizontalFlip(T.RandomHorizontalFlip):
        if self.params['flip']:
            return F.hflip(image)
        return image


@TRANSFORMS.register()
class Add(T.BaseTransform):
    def __init__(self, value, keys=None):
        """Initialize Add Transform

        Parameters:
            value (List[int]) -- the [r, g, b] value added to the image pixel-wise.
        """
        super().__init__(keys=keys)
        self.value = value

    def _get_params(self, inputs):
        params = {}
        params['value'] = self.value
        return params

    def _apply_image(self, image):
        return custom_F.add(image, self.params['value'])


@TRANSFORMS.register()
class ResizeToScale(T.BaseTransform):
    def __init__(self,
                 size: int,
                 scale: int,
                 interpolation='bilinear',
                 keys=None):
        """Initialize ResizeToScale Transform

        Parameters:
            size (List[int]) -- the minimum target size
            scale (int) -- the stride scale
            interpolation (Optional[str]) -- interpolation method
        """
        super().__init__(keys=keys)
        if isinstance(size, int):
            self.size = (size, size)
        else:
            self.size = size
        self.scale = scale
        self.interpolation = interpolation

    def _get_params(self, inputs):
        image = inputs[self.keys.index('image')]
        hw = image.shape[:2]
        params = {}
        params['taget_size'] = self.reduce_to_scale(hw, self.size[::-1],
                                                    self.scale)
        return params

    @staticmethod
    def reduce_to_scale(img_hw, min_hw, scale):
        im_h, im_w = img_hw
        if im_h <= min_hw[0]:
            im_h = min_hw[0]
        else:
            x = im_h % scale
            im_h = im_h - x

        if im_w < min_hw[1]:
            im_w = min_hw[1]
        else:
            y = im_w % scale
            im_w = im_w - y
        return (im_h, im_w)

    def _apply_image(self, image):
        return F.resize(image, self.params['taget_size'], self.interpolation)
......@@ -20,3 +20,4 @@ from .sr_model import SRModel
from .makeup_model import MakeupModel
from .ugatit_model import UGATITModel
from .dc_gan_model import DCGANModel
from .animeganv2_model import AnimeGANV2Model, AnimeGANV2PreTrainModel
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
from paddle import nn
from .base_model import BaseModel
from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .losses import GANLoss
from ..modules.caffevgg import CaffeVGG19
from ..solver import build_optimizer
from ..modules.init import init_weights
from ..utils.filesystem import load
@MODELS.register()
class AnimeGANV2Model(BaseModel):
    def __init__(self, cfg):
        """Initialize the AnimeGANV2 class.

        Parameters:
            cfg (config dict) -- stores all the experiment flags; needs to be a subclass of Dict
        """
        super(AnimeGANV2Model, self).__init__(cfg)
        # define networks (both generator and discriminator)
        self.nets['netG'] = build_generator(cfg.model.generator)
        init_weights(self.nets['netG'])

        # the discriminator, losses and optimizers are only needed at training time
        if self.is_train:
            self.nets['netD'] = build_discriminator(cfg.model.discriminator)
            init_weights(self.nets['netD'])

            self.pretrained = CaffeVGG19()

            self.losses = {}
            # define loss functions
            self.criterionGAN = GANLoss(cfg.model.gan_mode)
            self.criterionL1 = nn.L1Loss()
            self.criterionHub = nn.SmoothL1Loss()

            # build optimizers
            self.build_lr_scheduler()
            self.optimizers['optimizer_G'] = build_optimizer(
                cfg.optimizer,
                self.lr_scheduler,
                parameter_list=self.nets['netG'].parameters())
            self.optimizers['optimizer_D'] = build_optimizer(
                cfg.optimizer,
                self.lr_scheduler,
                parameter_list=self.nets['netD'].parameters())

            if self.cfg.pretrain_ckpt:
                state_dicts = load(self.cfg.pretrain_ckpt)
                self.nets['netG'].set_state_dict(state_dicts['netG'])
                print('Load pretrained generator from', self.cfg.pretrain_ckpt)

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.
        """
        if self.is_train:
            self.real = paddle.to_tensor(input['real'])
            self.anime = paddle.to_tensor(input['anime'])
            self.anime_gray = paddle.to_tensor(input['anime_gray'])
            self.smooth_gray = paddle.to_tensor(input['smooth_gray'])
        else:
            self.real = paddle.to_tensor(input['A'])
            self.image_paths = input['A_paths']

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake = self.nets['netG'](self.real)  # G(A)

        # put items to visual dict
        self.visual_items['real'] = self.real
        self.visual_items['fake'] = self.fake

    def test(self):
        self.fake = self.nets['netG'](self.real)  # G(A)

        # put items to visual dict
        self.visual_items['real'] = self.real
        self.visual_items['fake'] = self.fake

    @staticmethod
    def gram(x):
        b, c, h, w = x.shape
        x_tmp = x.reshape((b, c, (h * w)))
        gram = paddle.matmul(x_tmp, x_tmp, transpose_y=True)
        return gram / (c * h * w)

    def style_loss(self, style, fake):
        return self.criterionL1(self.gram(style), self.gram(fake))

    def con_sty_loss(self, real, anime, fake):
        real_feature_map = self.pretrained(real)
        fake_feature_map = self.pretrained(fake)
        anime_feature_map = self.pretrained(anime)

        c_loss = self.criterionL1(real_feature_map, fake_feature_map)
        s_loss = self.style_loss(anime_feature_map, fake_feature_map)

        return c_loss, s_loss

    @staticmethod
    def rgb2yuv(rgb):
        kernel = paddle.to_tensor([[0.299, -0.14714119, 0.61497538],
                                   [0.587, -0.28886916, -0.51496512],
                                   [0.114, 0.43601035, -0.10001026]],
                                  dtype='float32')
        rgb = paddle.transpose(rgb, (0, 2, 3, 1))
        yuv = paddle.matmul(rgb, kernel)
        return yuv

    @staticmethod
    def denormalize(image):
        return image * 0.5 + 0.5

    def color_loss(self, con, fake):
        con = self.rgb2yuv(self.denormalize(con))
        fake = self.rgb2yuv(self.denormalize(fake))
        return (self.criterionL1(con[:, :, :, 0], fake[:, :, :, 0]) +
                self.criterionHub(con[:, :, :, 1], fake[:, :, :, 1]) +
                self.criterionHub(con[:, :, :, 2], fake[:, :, :, 2]))

    @staticmethod
    def variation_loss(image, ksize=1):
        dh = image[:, :, :-ksize, :] - image[:, :, ksize:, :]
        dw = image[:, :, :, :-ksize] - image[:, :, :, ksize:]
        return (paddle.mean(paddle.abs(dh)) + paddle.mean(paddle.abs(dw)))

    def backward_D(self):
        """Calculate GAN loss for the discriminator"""
        # fake images are detached so no gradients flow back to the generator
        real_logit = self.nets['netD'](self.anime)
        gray_logit = self.nets['netD'](self.anime_gray)
        fake_logit = self.nets['netD'](self.fake.detach())
        smooth_logit = self.nets['netD'](self.smooth_gray)

        d_real_loss = (self.cfg.d_adv_weight * 1.2 *
                       self.criterionGAN(real_logit, True))
        d_gray_loss = (self.cfg.d_adv_weight * 1.2 *
                       self.criterionGAN(gray_logit, False))
        d_fake_loss = (self.cfg.d_adv_weight * 1.2 *
                       self.criterionGAN(fake_logit, False))
        d_blur_loss = (self.cfg.d_adv_weight * 0.8 *
                       self.criterionGAN(smooth_logit, False))

        self.loss_D = d_real_loss + d_gray_loss + d_fake_loss + d_blur_loss

        self.loss_D.backward()

        self.losses['d_loss'] = self.loss_D
        self.losses['d_real_loss'] = d_real_loss
        self.losses['d_fake_loss'] = d_fake_loss
        self.losses['d_gray_loss'] = d_gray_loss
        self.losses['d_blur_loss'] = d_blur_loss

    def backward_G(self):
        fake_logit = self.nets['netD'](self.fake)
        c_loss, s_loss = self.con_sty_loss(self.real, self.anime_gray,
                                           self.fake)
        c_loss = self.cfg.con_weight * c_loss
        s_loss = self.cfg.sty_weight * s_loss
        tv_loss = self.cfg.tv_weight * self.variation_loss(self.fake)
        col_loss = self.cfg.color_weight * self.color_loss(self.real, self.fake)
        g_loss = (self.cfg.g_adv_weight * self.criterionGAN(fake_logit, True))

        self.loss_G = c_loss + s_loss + col_loss + g_loss + tv_loss

        self.loss_G.backward()

        self.losses['g_loss'] = self.loss_G
        self.losses['c_loss'] = c_loss
        self.losses['s_loss'] = s_loss
        self.losses['col_loss'] = col_loss
        self.losses['tv_loss'] = tv_loss

    def optimize_parameters(self):
        # compute fake images: G(A)
        self.forward()

        # update D
        self.optimizers['optimizer_D'].clear_grad()
        self.backward_D()
        self.optimizers['optimizer_D'].step()

        # update G
        self.optimizers['optimizer_G'].clear_grad()
        self.backward_G()
        self.optimizers['optimizer_G'].step()


@MODELS.register()
class AnimeGANV2PreTrainModel(AnimeGANV2Model):
    def backward_G(self):
        real_feature_map = self.pretrained(self.real)
        fake_feature_map = self.pretrained(self.fake)

        init_c_loss = self.criterionL1(real_feature_map, fake_feature_map)
        loss = self.cfg.con_weight * init_c_loss
        loss.backward()
        self.losses['init_c_loss'] = init_c_loss

    def optimize_parameters(self):
        self.forward()

        # update G
        self.optimizers['optimizer_G'].clear_grad()
        self.backward_G()
        self.optimizers['optimizer_G'].step()
......@@ -15,3 +15,4 @@
from .nlayers import NLayerDiscriminator
from .discriminator_ugatit import UGATITDiscriminator
from .dcdiscriminator import DCDiscriminator
from .discriminator_animegan import AnimeDiscriminator
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle.nn as nn
import paddle.nn.functional as F
from .builder import DISCRIMINATORS
from ...modules.utils import spectral_norm
@DISCRIMINATORS.register()
class AnimeDiscriminator(nn.Layer):
    def __init__(self, channel: int = 64, nblocks: int = 3) -> None:
        super().__init__()
        channel = channel // 2
        last_channel = channel
        f = [
            spectral_norm(
                nn.Conv2D(3, channel, 3, stride=1, padding=1, bias_attr=False)),
            nn.LeakyReLU(0.2)
        ]
        in_h = 256
        for i in range(1, nblocks):
            f.extend([
                spectral_norm(
                    nn.Conv2D(last_channel,
                              channel * 2,
                              3,
                              stride=2,
                              padding=1,
                              bias_attr=False)),
                nn.LeakyReLU(0.2),
                spectral_norm(
                    nn.Conv2D(channel * 2,
                              channel * 4,
                              3,
                              stride=1,
                              padding=1,
                              bias_attr=False)),
                nn.GroupNorm(1, channel * 4),
                nn.LeakyReLU(0.2)
            ])
            last_channel = channel * 4
            channel = channel * 2
            in_h = in_h // 2

        self.body = nn.Sequential(*f)

        self.head = nn.Sequential(*[
            spectral_norm(
                nn.Conv2D(last_channel,
                          channel * 2,
                          3,
                          stride=1,
                          padding=1,
                          bias_attr=False)),
            nn.GroupNorm(1, channel * 2),
            nn.LeakyReLU(0.2),
            spectral_norm(
                nn.Conv2D(
                    channel * 2, 1, 3, stride=1, padding=1, bias_attr=False))
        ])

    def forward(self, x):
        x = self.body(x)
        x = self.head(x)
        return x
......@@ -18,4 +18,5 @@ from .rrdb_net import RRDBNet
from .makeup import GeneratorPSGANAttention
from .resnet_ugatit import ResnetUGATITGenerator
from .dcgenerator import DCGenerator
from .wav2lip import Wav2Lip
from .generater_animegan import AnimeGenerator, AnimeGeneratorLite
from .wav2lip import Wav2Lip
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .builder import GENERATORS
class Conv2DNormLReLU(nn.Layer):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int = 3,
                 stride: int = 1,
                 padding: int = 1,
                 bias_attr=False) -> None:
        super().__init__()
        self.conv = nn.Conv2D(in_channels,
                              out_channels,
                              kernel_size,
                              stride,
                              padding,
                              bias_attr=bias_attr)
        # NOTE layer norm is crucial for animegan!
        self.norm = nn.GroupNorm(1, out_channels)
        self.lrelu = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        x = self.lrelu(x)
        return x


class ResBlock(nn.Layer):
    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.body = nn.Sequential(
            Conv2DNormLReLU(in_channels, out_channels, 1, padding=0),
            Conv2DNormLReLU(out_channels, out_channels, 3),
            nn.Conv2D(out_channels, out_channels // 2, 1, bias_attr=False))

    def forward(self, x0):
        x = self.body(x0)
        return x0 + x


class InvertedresBlock(nn.Layer):
    def __init__(self,
                 in_channels: int,
                 expansion: float,
                 out_channels: int,
                 bias_attr=False):
        super().__init__()
        self.in_channels = in_channels
        self.expansion = expansion
        self.out_channels = out_channels
        self.bottle_channels = round(self.expansion * self.in_channels)
        self.body = nn.Sequential(
            # pw
            Conv2DNormLReLU(self.in_channels,
                            self.bottle_channels,
                            kernel_size=1,
                            bias_attr=bias_attr),
            # dw
            nn.Conv2D(self.bottle_channels,
                      self.bottle_channels,
                      kernel_size=3,
                      stride=1,
                      padding=0,
                      groups=self.bottle_channels,
                      bias_attr=True),
            nn.GroupNorm(1, self.bottle_channels),
            nn.LeakyReLU(0.2),
            # pw & linear
            nn.Conv2D(self.bottle_channels,
                      self.out_channels,
                      kernel_size=1,
                      padding=0,
                      bias_attr=False),
            nn.GroupNorm(1, self.out_channels),
        )

    def forward(self, x0):
        x = self.body(x0)
        if self.in_channels == self.out_channels:
            out = paddle.add(x0, x)
        else:
            out = x
        return out
@GENERATORS.register()
class AnimeGeneratorLite(nn.Layer):
    def __init__(self) -> None:
        super().__init__()
        self.A = nn.Sequential(Conv2DNormLReLU(3, 32, 7, padding=3),
                               Conv2DNormLReLU(32, 32, stride=2),
                               Conv2DNormLReLU(32, 32))

        self.B = nn.Sequential(Conv2DNormLReLU(32, 64, stride=2),
                               Conv2DNormLReLU(64, 64), Conv2DNormLReLU(64, 64))

        self.C = nn.Sequential(ResBlock(64, 128), ResBlock(64, 128),
                               ResBlock(64, 128), ResBlock(64, 128))

        self.D = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear'),
                               Conv2DNormLReLU(64, 64), Conv2DNormLReLU(64, 64),
                               Conv2DNormLReLU(64, 64))

        self.E = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear'),
                               Conv2DNormLReLU(64, 32), Conv2DNormLReLU(32, 32),
                               Conv2DNormLReLU(32, 32, 7, padding=3))

        self.out = nn.Sequential(nn.Conv2D(32, 3, 1, bias_attr=False),
                                 nn.Tanh())

    def forward(self, x):
        x = self.A(x)
        x = self.B(x)
        x = self.C(x)
        x = self.D(x)
        x = self.E(x)
        x = self.out(x)
        return x


@GENERATORS.register()
class AnimeGenerator(nn.Layer):
    def __init__(self) -> None:
        super().__init__()
        self.A = nn.Sequential(Conv2DNormLReLU(3, 32, 7, padding=3),
                               Conv2DNormLReLU(32, 64, stride=2),
                               Conv2DNormLReLU(64, 64))

        self.B = nn.Sequential(Conv2DNormLReLU(64, 128, stride=2),
                               Conv2DNormLReLU(128, 128),
                               Conv2DNormLReLU(128, 128))

        self.C = nn.Sequential(InvertedresBlock(128, 2, 256),
                               InvertedresBlock(256, 2, 256),
                               InvertedresBlock(256, 2, 256),
                               InvertedresBlock(256, 2, 256),
                               Conv2DNormLReLU(256, 128))

        self.D = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear'),
                               Conv2DNormLReLU(128, 128),
                               Conv2DNormLReLU(128, 128))

        self.E = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear'),
                               Conv2DNormLReLU(128, 64),
                               Conv2DNormLReLU(64, 64),
                               Conv2DNormLReLU(64, 32, 7, padding=3))

        self.out = nn.Sequential(nn.Conv2D(32, 3, 1, bias_attr=False),
                                 nn.Tanh())

    def forward(self, x):
        x = self.A(x)
        x = self.B(x)
        x = self.C(x)
        x = self.D(x)
        x = self.E(x)
        x = self.out(x)
        return x
import paddle
import paddle.nn as nn
import numpy as np
from ppgan.utils.download import get_path_from_url
model_urls = {
    'caffevgg19': ('https://paddlegan.bj.bcebos.com/models/vgg19_no_fc.npy',
                   '8ea1ef2374f8684b6cea9f300849be81')
}


class CaffeVGG19(nn.Layer):
    cfg = [
        64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512,
        512, 'M', 512, 512, 512, 512, 'M'
    ]

    def __init__(self, output_index: int = 26) -> None:
        super().__init__()
        arch = 'caffevgg19'
        weights_path = get_path_from_url(model_urls[arch][0],
                                         model_urls[arch][1])
        data_dict: dict = np.load(weights_path,
                                  encoding='latin1',
                                  allow_pickle=True).item()
        self.features = self.make_layers(self.cfg, data_dict)
        del data_dict
        self.features = nn.Sequential(*self.features.sublayers()[:output_index])
        mean = paddle.to_tensor([103.939, 116.779, 123.68])
        self.mean = mean.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

    def _process(self, x):
        rgb = (x * 0.5 + 0.5) * 255  # value to 255
        bgr = paddle.stack((rgb[:, 2, :, :], rgb[:, 1, :, :], rgb[:, 0, :, :]),
                           1)  # rgb to bgr
        return bgr - self.mean  # vgg norm

    def _forward_impl(self, x):
        x = self._process(x)
        # NOTE get output without relu activation
        x = self.features(x)
        return x

    def forward(self, x):
        return self._forward_impl(x)

    @staticmethod
    def get_conv_filter(data_dict, name):
        return data_dict[name][0]

    @staticmethod
    def get_bias(data_dict, name):
        return data_dict[name][1]

    @staticmethod
    def get_fc_weight(data_dict, name):
        return data_dict[name][0]

    def make_layers(self, cfg, data_dict, batch_norm=False) -> nn.Sequential:
        layers = []
        in_channels = 3
        block = 1
        number = 1
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2D(kernel_size=2, stride=2)]
                block += 1
                number = 1
            else:
                conv2d = nn.Conv2D(in_channels, v, kernel_size=3, padding=1)
                # set weight and bias from the pretrained caffe data dict
                weight = paddle.to_tensor(
                    self.get_conv_filter(data_dict, f'conv{block}_{number}'))
                weight = weight.transpose((3, 2, 0, 1))
                bias = paddle.to_tensor(
                    self.get_bias(data_dict, f'conv{block}_{number}'))
                conv2d.weight.set_value(weight)
                conv2d.bias.set_value(bias)
                number += 1
                if batch_norm:
                    layers += [conv2d, nn.BatchNorm2D(v), nn.ReLU()]
                else:
                    layers += [conv2d, nn.ReLU()]
                in_channels = v
        return nn.Sequential(*layers)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import cv2
import numpy as np
from glob import glob
from tqdm import tqdm
import argparse
def read_img(image_path):
    img = cv2.imread(image_path)
    assert len(img.shape) == 3

    B = img[..., 0].mean()
    G = img[..., 1].mean()
    R = img[..., 2].mean()

    return B, G, R


def main(dataset):
    file_list = glob(os.path.join(dataset, '*.jpg'))
    image_num = len(file_list)
    print('image_num:', image_num)

    B_total = 0
    G_total = 0
    R_total = 0
    for f in tqdm(file_list):
        bgr = read_img(f)
        B_total += bgr[0]
        G_total += bgr[1]
        R_total += bgr[2]

    B_mean, G_mean, R_mean = B_total / image_num, G_total / image_num, R_total / image_num
    mean = (B_mean + G_mean + R_mean) / 3

    print('RGB mean diff')
    print(
        np.asfarray((mean - R_mean, mean - G_mean, mean - B_mean),
                    dtype='float32'))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset',
                        help="get the mean values of rgb from dataset",
                        type=str,
                        default='')
    args = parser.parse_args()
    main(args.dataset)