未验证 提交 e19a8892 编写于 作者: W wangna11BD 提交者: GitHub

Add stargan doc (#532)

* add stargan doc

* fix config
上级 b6204126
......@@ -23,7 +23,7 @@ GAN-Generative Adversarial Network, was praised by "the Father of Convolutional
## 🚀 Recent Updates
- 🔥 **Latest Release: [PP-MSVSR](./docs/en_US/tutorials/video_super_resolution.md)** 🔥
- 🔥 **Latest Release: [PP-MSVSR](./docs/en_US/tutorials/video_super_resolution.md)** 🔥
- **Video Super Resolution SOTA models**
<div align='center'>
<img src='https://user-images.githubusercontent.com/48054808/144848981-00c6ad21-0702-4381-9544-becb227ed9f0.gif' width='600'/>
......@@ -106,6 +106,7 @@ GAN-Generative Adversarial Network, was praised by "the Father of Convolutional
* [Video Super Resolution(VSR)](./docs/en_US/tutorials/video_super_resolution.md)
* [StyleGAN2](./docs/en_US/tutorials/styleganv2.md)
* [Pixel2Style2Pixel](./docs/en_US/tutorials/pixel2style2pixel.md)
* [StarGANv2](docs/en_US/tutorials/starganv2.md)
## Composite Application
......
......@@ -121,6 +121,7 @@ GAN--生成对抗网络,被“卷积网络之父”**Yann LeCun(杨立昆)
* 照片动漫化:[AnimeGANv2](./docs/zh_CN/tutorials/animegan.md)
* 人像动漫化:[U-GAT-IT](./docs/zh_CN/tutorials/ugatit.md)
* 人脸卡通化:[Photo2Cartoon](docs/zh_CN/tutorials/photo2cartoon.md)
* 多种风格迁移:[StarGANv2](docs/zh_CN/tutorials/starganv2.md)
* 动作迁移
* 人脸表情迁移:[First Order Motion Model](./docs/zh_CN/tutorials/motion_driving.md)
* 唇形合成:[Wav2Lip](docs/zh_CN/tutorials/wav2lip.md)
......
......@@ -49,11 +49,11 @@ dataset:
size: [*IMAGE_SIZE, *IMAGE_SIZE]
scale: [0.8, 1.0]
ratio: [0.9, 1.1]
interpolation: 'bilinear'
interpolation: 'bilinear'
keys: [image, image, image]
- name: Resize
size: [*IMAGE_SIZE, *IMAGE_SIZE]
interpolation: 'bilinear'
interpolation: 'bilinear'
keys: [image, image, image]
- name: RandomHorizontalFlip
prob: 0.5
......@@ -130,12 +130,12 @@ optimizer:
weight_decay: 0.0001
validate:
interval: 5000
interval: 3000
save_img: false
log_config:
interval: 5
visiual_interval: 100
interval: 100
visiual_interval: 3000
snapshot_config:
interval: 5
......@@ -52,11 +52,11 @@ dataset:
size: [*IMAGE_SIZE, *IMAGE_SIZE]
scale: [0.8, 1.0]
ratio: [0.9, 1.1]
interpolation: 'bilinear'
interpolation: 'bilinear'
keys: [image, image, image]
- name: Resize
size: [*IMAGE_SIZE, *IMAGE_SIZE]
interpolation: 'bilinear'
interpolation: 'bilinear'
keys: [image, image, image]
- name: RandomHorizontalFlip
prob: 0.5
......@@ -133,12 +133,12 @@ optimizer:
weight_decay: 0.0001
validate:
interval: 5000
interval: 3000
save_img: false
log_config:
interval: 5
visiual_interval: 100
interval: 100
visiual_interval: 3000
snapshot_config:
interval: 5
# StarGAN V2
## 1 Introduction
[StarGAN V2](https://arxiv.org/pdf/1912.01865.pdf)is an image-to-image translation model published on CVPR2020.
A good image-to-image translation model should learn a mapping between different visual domains while satisfying the following properties: 1) diversity of generated images and 2) scalability over multiple domains. Existing methods address either of the issues, having limited diversity or multiple models for all domains. StarGAN v2 is a single framework that tackles both and shows significantly improved results over the baselines. Experiments on CelebA-HQ and a new animal faces dataset (AFHQ) validate superiority of StarGAN v2 in terms of visual quality, diversity, and scalability.
## 2 How to use
### 2.1 Prepare dataset
The CelebAHQ dataset used by StarGAN V2 can be downloaded from [here](https://www.dropbox.com/s/f7pvjij2xlpff59/celeba_hq.zip?dl=0), and the AFHQ dataset can be downloaded from [here](https://www.dropbox.com/s/t9l9o3vsx2jai3z/afhq.zip?dl=0). Then unzip dataset to the ``PaddleGAN/data`` directory.
The structure of dataset is as following:
```
├── data
├── afhq
| ├── train
| | ├── cat
| | ├── dog
| | └── wild
| └── val
| ├── cat
| ├── dog
| └── wild
└── celeba_hq
├── train
| ├── female
| └── male
└── val
├── female
└── male
```
### 2.2 Train/Test
The example uses the AFHQ dataset as an example. If you want to use the CelebAHQ dataset, you can change the config file.
train model:
```
python -u tools/main.py --config-file configs/starganv2_afhq.yaml
```
test model:
```
python tools/main.py --config-file configs/starganv2_afhq.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
```
## 3 Results
![](https://user-images.githubusercontent.com/79366697/146308440-65259d70-d056-43d4-8cf5-a82530993910.jpg)
## 4 Model Download
| 模型 | 数据集 | 下载地址 |
|---|---|---|
| starganv2_afhq | AFHQ | [starganv2_afhq](https://paddlegan.bj.bcebos.com/models/starganv2_afhq.pdparams)
# References
- 1. [StarGAN v2: Diverse Image Synthesis for Multiple Domains](https://arxiv.org/abs/1912.01865)
```
@inproceedings{choi2020starganv2,
title={StarGAN v2: Diverse Image Synthesis for Multiple Domains},
author={Yunjey Choi and Youngjung Uh and Jaejun Yoo and Jung-Woo Ha},
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
year={2020}
}
```
# StarGAN V2
## 1 原理介绍
[StarGAN V2](https://arxiv.org/pdf/1912.01865.pdf)是发布在CVPR2020上的一个图像转换模型。
一个好的图像到图像转换模型应该学习不同视觉域之间的映射,同时满足以下属性:1)生成图像的多样性和 2)多个域的可扩展性。 现有方法只解决了其中一个问题,领域的多样性有限或对所有领域用多个模型。 StarGAN V2是一个单一的框架,可以同时解决这两个问题,并在基线上显示出显着改善的结果。 CelebAHQ 和新的动物面孔数据集 (AFHQ) 上的实验验证了StarGAN V2在视觉质量、多样性和可扩展性方面的优势。
## 2 如何使用
### 2.1 数据准备
StarGAN V2使用的CelebAHQ数据集可以从[这里](https://www.dropbox.com/s/f7pvjij2xlpff59/celeba_hq.zip?dl=0)下载,使用的AFHQ数据集可以从[这里](https://www.dropbox.com/s/t9l9o3vsx2jai3z/afhq.zip?dl=0)下载。将数据集下载解压后放到``PaddleGAN/data``文件夹下 。
数据的组成形式为:
```
├── data
├── afhq
| ├── train
| | ├── cat
| | ├── dog
| | └── wild
| └── val
| ├── cat
| ├── dog
| └── wild
└── celeba_hq
├── train
| ├── female
| └── male
└── val
├── female
└── male
```
### 2.2 训练/测试
示例以AFHQ数据集为例。如果您想使用CelebAHQ数据集,可以在换一下配置文件。
训练模型:
```
python -u tools/main.py --config-file configs/starganv2_afhq.yaml
```
测试模型:
```
python tools/main.py --config-file configs/starganv2_afhq.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
```
## 3 结果展示
![](https://user-images.githubusercontent.com/79366697/146308440-65259d70-d056-43d4-8cf5-a82530993910.jpg)
## 4 模型下载
| 模型 | 数据集 | 下载地址 |
|---|---|---|
| starganv2_afhq | AFHQ | [starganv2_afhq](https://paddlegan.bj.bcebos.com/models/starganv2_afhq.pdparams)
# 参考文献
- 1. [StarGAN v2: Diverse Image Synthesis for Multiple Domains](https://arxiv.org/abs/1912.01865)
```
@inproceedings{choi2020starganv2,
title={StarGAN v2: Diverse Image Synthesis for Multiple Domains},
author={Yunjey Choi and Youngjung Uh and Jaejun Yoo and Jung-Woo Ha},
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
year={2020}
}
```
......@@ -16,22 +16,6 @@ from ppgan.utils.download import get_path_from_url
FAN_WEIGHT_URL = "https://paddlegan.bj.bcebos.com/models/wing.pdparams"
class AvgPool2D(nn.Layer):
"""
AvgPool2D
Peplace avg_pool2d because paddle.grad will cause avg_pool2d to report an error when training.
In the future Paddle framework will supports avg_pool2d and remove this class.
"""
def __init__(self):
super(AvgPool2D, self).__init__()
self.filter = paddle.to_tensor([[1, 1], [1, 1]], dtype='float32')
def forward(self, x):
filter = self.filter.unsqueeze(0).unsqueeze(1).tile(
[x.shape[1], 1, 1, 1])
return F.conv2d(x, filter, stride=2, padding=0, groups=x.shape[1]) / 4
class ResBlk(nn.Layer):
def __init__(self,
dim_in,
......@@ -45,6 +29,7 @@ class ResBlk(nn.Layer):
self.downsample = downsample
self.learned_sc = dim_in != dim_out
self._build_weights(dim_in, dim_out)
self.maxpool = nn.AvgPool2D(kernel_size=2)
def _build_weights(self, dim_in, dim_out):
self.conv1 = nn.Conv2D(dim_in, dim_in, 3, 1, 1)
......@@ -63,7 +48,7 @@ class ResBlk(nn.Layer):
if self.learned_sc:
x = self.conv1x1(x)
if self.downsample:
x = AvgPool2D()(x)
x = self.maxpool(x)
return x
def _residual(self, x):
......@@ -72,7 +57,7 @@ class ResBlk(nn.Layer):
x = self.actv(x)
x = self.conv1(x)
if self.downsample:
x = AvgPool2D()(x)
x = self.maxpool(x)
if self.normalize:
x = self.norm2(x)
x = self.actv(x)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册