提交 28c896c5 编写于 作者: L LielinJiang

Merge branch 'master' of https://github.com/PaddlePaddle/PaddleGAN into release/0.1.0

English | [简体中文](./README_cn.md) 简体中文 | [English](./README_en.md)
# PaddleGAN # PaddleGAN
PaddleGAN is an development kit of Generative Adversarial Network based on PaddlePaddle. PaddleGAN 是一个基于飞桨的生成对抗网络开发工具包.
### Image Translation ### 图片变换
![](./docs/imgs/A2B.png) ![](./docs/imgs/A2B.png)
![](./docs/imgs/B2A.png) ![](./docs/imgs/B2A.png)
### Makeup shifter ### 妆容迁移
![](./docs/imgs/makeup_shifter.png) ![](./docs/imgs/makeup_shifter.png)
### Old video restore ### 老视频修复
![](./docs/imgs/color_sr_peking.gif) ![](./docs/imgs/color_sr_peking.gif)
### Super resolution ### 超分辨率
![](./docs/imgs/sr_demo.png) ![](./docs/imgs/sr_demo.png)
### Motion driving ### 动作驱动
![](./docs/imgs/first_order.gif) ![](./docs/imgs/first_order.gif)
Features: 特性:
- Highly Flexible:
Components are designed to be modular. Model architectures, as well as data
preprocess pipelines, can be easily customized with simple configuration
changes.
- Rich applications:
PaddleGAN provides rich of applications, such as image generation, image restore, image colorization, video interpolate, makeup shifter. - 高度的灵活性:
## Install 模块化设计,解耦各个网络组件,开发者轻松搭建、试用各种检测模型及优化策略,快速得到高性能、定制化的算法。
### 1. install paddlepaddle - 丰富的应用:
PaddleGAN work with: PaddleGAN 提供了非常多的应用,比如说图像生成,图像修复,图像上色,视频补帧,人脸妆容迁移等.
* PaddlePaddle >= 2.0.0-rc
* Python >= 3.5+
``` ## 安装
pip install -U paddlepaddle-gpu
```
### 2. install ppgan 请参考[安装文档](./docs/install.md)来进行PaddlePaddle和ppgan的安装
``` ## 数据准备
python -m pip install 'git+https://github.com/PaddlePaddle/PaddleGAN.git' 请参考[数据准备](./docs/data_prepare.md) 来准备对应的数据.
```
Or install it from a local clone
```
git clone https://github.com/PaddlePaddle/PaddleGAN
cd PaddleGAN
pip install -v -e . # or "python setup.py develop" ## 快速开始
``` 训练,预测,推理等请参考 [快速开始](./docs/get_started.md).
## Data Prepare ## 模型教程
Please refer to [data prepare](./docs/data_prepare.md) for dataset preparation.
## Get Start
Please refer [get started](./docs/get_started.md) for the basic usage of PaddleGAN.
## Model tutorial
* [Pixel2Pixel and CycleGAN](./docs/tutorials/pix2pix_cyclegan.md) * [Pixel2Pixel and CycleGAN](./docs/tutorials/pix2pix_cyclegan.md)
* [PSGAN](./docs/tutorials/psgan.md) * [PSGAN](./docs/tutorials/psgan.md)
* [Video restore](./docs/tutorails/video_restore.md) * [视频修复](./docs/tutorials/video_restore.md)
* [Motion driving](./docs/tutorials/motion_driving.md) * [动作驱动](./docs/tutorials/motion_driving.md)
## 许可证书
本项目的发布受[Apache 2.0 license](LICENSE)许可认证。
## License
PaddleGAN is released under the [Apache 2.0 license](LICENSE).
## Contributing ## 贡献代码
Contributions and suggestions are highly welcomed. Most contributions require you to agree to a [Contributor License Agreement (CLA)](https://cla-assistant.io/PaddlePaddle/PaddleGAN) declaring. 我们非常欢迎你可以为PaddleGAN提供任何贡献和建议。大多数贡献都需要你同意参与者许可协议(CLA)。当你提交拉取请求时,CLA机器人会自动检查你是否需要提供CLA。 只需要按照机器人提供的说明进行操作即可。CLA只需要同意一次,就能应用到所有的代码仓库上。关于更多的流程请参考[贡献指南](docs/CONTRIBUTE.md)
When you submit a pull request, a CLA-bot will automatically determine whether you need to provide a CLA. Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.
For more, please reference [contribution guidelines](docs/CONTRIBUTE.md).
## External Projects ## 外部项目
External gan projects in the community that base on PaddlePaddle: 外部基于飞桨的生成对抗网络模型
+ [PaddleGAN](https://github.com/PaddlePaddle/PaddleGAN) + [PaddleGAN](https://github.com/PaddlePaddle/PaddleGAN)
[English](./README.md) | 简体中文
# PaddleGAN
PaddleGAN 是一个基于飞桨的生成对抗网络开发工具包.
### 图片变换
![](./docs/imgs/A2B.png)
![](./docs/imgs/B2A.png)
### 妆容迁移
![](./docs/imgs/makeup_shifter.png)
### 老视频修复
![](./docs/imgs/color_sr_peking.gif)
### 超分辨率
![](./docs/imgs/sr_demo.png)
### 动作驱动
![](./docs/imgs/first_order.gif)
特性:
- 高度的灵活性:
模块化设计,解耦各个网络组件,开发者轻松搭建、试用各种检测模型及优化策略,快速得到高性能、定制化的算法。
- 丰富的应用:
PaddleGAN 提供了非常多的应用,比如说图像生成,图像修复,图像上色,视频补帧,人脸妆容迁移等.
## 安装
### 1. 安装 paddlepaddle
PaddleGAN 所需的版本:
* PaddlePaddle >= 2.0.0-rc
* Python >= 3.5+
```
pip install -U paddlepaddle-gpu
```
### 2. 安装ppgan
```
python -m pip install 'git+https://github.com/PaddlePaddle/PaddleGAN.git'
```
或者通过将项目克隆到本地
```
git clone https://github.com/PaddlePaddle/PaddleGAN
cd PaddleGAN
pip install -v -e . # or "python setup.py develop"
```
## 数据准备
请参考 [数据准备](./docs/data_prepare.md) 来准备对应的数据.
## 快速开始
训练,预测,推理等请参考 [快速开始](./docs/get_started.md).
## 模型教程
* [Pixel2Pixel and CycleGAN](./docs/tutorals/pix2pix_cyclegan.md)
* [PSGAN](./docs/tutorals/psgan.md)
* [视频修复](./docs/tutorails/video_restore.md)
* [动作驱动](./docs/tutorials/motion_driving.md)
## 许可证书
本项目的发布受[Apache 2.0 license](LICENSE)许可认证。
## 贡献代码
我们非常欢迎你可以为PaddleGAN提供任何贡献和建议。大多数贡献都需要你同意参与者许可协议(CLA)。当你提交拉取请求时,CLA机器人会自动检查你是否需要提供CLA。 只需要按照机器人提供的说明进行操作即可。CLA只需要同意一次,就能应用到所有的代码仓库上。关于更多的流程请参考[贡献指南](docs/CONTRIBUTE.md)
## 外部项目
外部基于飞桨的生成对抗网络模型
+ [PaddleGAN](https://github.com/PaddlePaddle/PaddleGAN)
English | [简体中文](./README.md)
# PaddleGAN
PaddleGAN is an development kit of Generative Adversarial Network based on PaddlePaddle.
### Image Translation
![](./docs/imgs/A2B.png)
![](./docs/imgs/B2A.png)
### Makeup shifter
![](./docs/imgs/makeup_shifter.png)
### Old video restore
![](./docs/imgs/color_sr_peking.gif)
### Super resolution
![](./docs/imgs/sr_demo.png)
### Motion driving
![](./docs/imgs/first_order.gif)
Features:
- Highly Flexible:
Components are designed to be modular. Model architectures, as well as data
preprocess pipelines, can be easily customized with simple configuration
changes.
- Rich applications:
PaddleGAN provides rich of applications, such as image generation, image restore, image colorization, video interpolate, makeup shifter.
## Install
Please refer to [install](./docs/install_en.md).
## Data Prepare
Please refer to [data prepare](./docs/data_prepare_en.md) for dataset preparation.
## Get Start
Please refer [get started](./docs/get_started_en.md) for the basic usage of PaddleGAN.
## Model tutorial
* [Pixel2Pixel and CycleGAN](./docs/tutorials/pix2pix_cyclegan.md)
* [PSGAN](./docs/tutorials/psgan_en.md)
* [Video restore](./docs/tutorails/video_restore.md)
* [Motion driving](./docs/tutorials/motion_driving_en.md)
## License
PaddleGAN is released under the [Apache 2.0 license](LICENSE).
## Contributing
Contributions and suggestions are highly welcomed. Most contributions require you to agree to a [Contributor License Agreement (CLA)](https://cla-assistant.io/PaddlePaddle/PaddleGAN) declaring.
When you submit a pull request, a CLA-bot will automatically determine whether you need to provide a CLA. Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.
For more, please reference [contribution guidelines](docs/CONTRIBUTE.md).
## External Projects
External gan projects in the community that base on PaddlePaddle:
+ [PaddleGAN](https://github.com/PaddlePaddle/PaddleGAN)
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
...@@ -49,6 +49,10 @@ parser.add_argument('--time_step', ...@@ -49,6 +49,10 @@ parser.add_argument('--time_step',
type=float, type=float,
default=0.5, default=0.5,
help='choose the time steps') help='choose the time steps')
parser.add_argument('--remove_duplicates',
action='store_true',
default=False,
help='whether to remove duplicated frames')
# DeepRemaster args # DeepRemaster args
parser.add_argument('--reference_dir', parser.add_argument('--reference_dir',
type=str, type=str,
...@@ -88,7 +92,8 @@ if __name__ == "__main__": ...@@ -88,7 +92,8 @@ if __name__ == "__main__":
paddle.enable_static() paddle.enable_static()
predictor = DAINPredictor(args.output, predictor = DAINPredictor(args.output,
weight_path=args.DAIN_weight, weight_path=args.DAIN_weight,
time_step=args.time_step) time_step=args.time_step,
remove_duplicates=args.remove_duplicates)
frames_path, temp_video_path = predictor.run(temp_video_path) frames_path, temp_video_path = predictor.run(temp_video_path)
paddle.disable_static() paddle.disable_static()
elif order == 'DeepRemaster': elif order == 'DeepRemaster':
......
...@@ -6,21 +6,20 @@ lambda_identity: 0.5 ...@@ -6,21 +6,20 @@ lambda_identity: 0.5
model: model:
name: CycleGANModel name: CycleGANModel
defaults: &defaults
norm_type: instance
input_nc: 3
generator: generator:
name: ResnetGenerator name: ResnetGenerator
output_nc: 3 output_nc: 3
n_blocks: 9 n_blocks: 9
ngf: 64 ngf: 64
use_dropout: False use_dropout: False
<<: *defaults norm_type: instance
input_nc: 3
discriminator: discriminator:
name: NLayerDiscriminator name: NLayerDiscriminator
ndf: 64 ndf: 64
n_layers: 3 n_layers: 3
<<: *defaults norm_type: instance
input_nc: 3
gan_mode: lsgan gan_mode: lsgan
dataset: dataset:
......
...@@ -6,21 +6,20 @@ lambda_identity: 0.5 ...@@ -6,21 +6,20 @@ lambda_identity: 0.5
model: model:
name: CycleGANModel name: CycleGANModel
defaults: &defaults
norm_type: instance
input_nc: 3
generator: generator:
name: ResnetGenerator name: ResnetGenerator
output_nc: 3 output_nc: 3
n_blocks: 9 n_blocks: 9
ngf: 64 ngf: 64
use_dropout: False use_dropout: False
<<: *defaults norm_type: instance
input_nc: 3
discriminator: discriminator:
name: NLayerDiscriminator name: NLayerDiscriminator
ndf: 64 ndf: 64
n_layers: 3 n_layers: 3
<<: *defaults norm_type: instance
input_nc: 3
gan_mode: lsgan gan_mode: lsgan
dataset: dataset:
...@@ -39,7 +38,7 @@ dataset: ...@@ -39,7 +38,7 @@ dataset:
size: [286, 286] size: [286, 286]
interpolation: 'bicubic' #cv2.INTER_CUBIC interpolation: 'bicubic' #cv2.INTER_CUBIC
- name: RandomCrop - name: RandomCrop
output_size: [256, 256] size: [256, 256]
- name: RandomHorizontalFlip - name: RandomHorizontalFlip
prob: 0.5 prob: 0.5
- name: Transpose - name: Transpose
...@@ -55,8 +54,7 @@ dataset: ...@@ -55,8 +54,7 @@ dataset:
output_nc: 3 output_nc: 3
serial_batches: False serial_batches: False
pool_size: 50 pool_size: 50
transform: transforms:
transform:
- name: Resize - name: Resize
size: [256, 256] size: [256, 256]
interpolation: 'bicubic' #cv2.INTER_CUBIC interpolation: 'bicubic' #cv2.INTER_CUBIC
......
import os
import argparse
from ppgan.utils.download import get_path_from_url
CYCLEGAN_URL_ROOT = 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/'
parser = argparse.ArgumentParser(description='download datasets')
parser.add_argument('--name',
type=str,
required=True,
help='dataset name, \
support dataset name: apple2orange, summer2winter_yosemite, \
horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, \
vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, \
ae_photos, cityscapes')
if __name__ == "__main__":
args = parser.parse_args()
data_url = CYCLEGAN_URL_ROOT + args.name + '.zip'
if args.name == 'cityscapes':
data_url = 'https://paddlegan.bj.bcebos.com/datasets/cityscapes.zip'
path = get_path_from_url(data_url)
dst = os.path.join('data', args.name)
print('symlink {} to {}'.format(path, dst))
os.symlink(path, dst)
import os
import argparse
from ppgan.utils.download import get_path_from_url
PIX2PIX_URL_ROOT = 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/'
parser = argparse.ArgumentParser(description='download datasets')
parser.add_argument('--name',
type=str,
required=True,
help='dataset name, \
support dataset name: cityscapes, night2day, edges2handbags, \
edges2shoes, facades, maps')
if __name__ == "__main__":
args = parser.parse_args()
data_url = PIX2PIX_URL_ROOT + args.name + '.tar.gz'
path = get_path_from_url(data_url)
dst = os.path.join('data', args.name)
print('symlink {} to {}'.format(path, dst))
os.symlink(path, dst)
# Applications接口说明
ppgan.apps包含超分、插针、上色、换妆、图像动画生成等应用,接口使用简洁,并内置了已训练好的模型,可以直接用来做应用。
## 公共用法
### CPU和GPU的切换
默认情况下,如果是GPU设备、并且安装了PaddlePaddle的GPU环境包,则默认使用GPU进行推理。否则,如果安装的是CPU环境包,则使用CPU进行推理。如果需要手动切换CPU、GPU,可以通过以下方式:
```
import paddle
paddle.set_device('cpu')
#paddle.set_device('gpu')
# from ppgan.apps import DeOldifyPredictor
# deoldify = DeOldifyPredictor()
# deoldify.run("docs/imgs/test_old.jpeg")
```
## ppgan.apps.DeOldifyPredictor
```python
ppgan.apps.DeOldifyPredictor(output='output', weight_path=None, render_factor=32)
```
> 构建DeOldify实例。DeOldify是一个基于GAN的老照片上色模型。该接口可以对图片或视频做上色。建议视频使用mp4格式。
>
> **示例**
>
> ```python
> from ppgan.apps import DeOldifyPredictor
> deoldify = DeOldifyPredictor()
> deoldify.run("docs/imgs/test_old.jpeg")
> ```
> **参数**
>
> > - output (str): 设置输出图片的保存路径,默认是output。注意,保存路径为设置output/DeOldify。
> > - weight_path (str): 指定模型路径,默认是None,则会自动下载内置的已经训练好的模型。
> > - render_factor (int): 图片渲染上色时的缩放因子,图片会缩放到边长为16xrender_factor的正方形, 再上色,例如render_factor默认值为32,输入图片先缩放到(16x32=512) 512x512大小的图片。通常来说,render_factor越小,计算速度越快,颜色看起来也更鲜活。较旧和较低质量的图像通常会因降低渲染因子而受益。渲染因子越高,图像质量越好,但颜色可能会稍微褪色。
### run
```python
run(input)
```
> 构建实例后的执行接口。
> **参数**
>
> > - input (str|np.ndarray|Image.Image): 输入的图片或视频文件。如果是图片,可以是图片的路径、np.ndarray、或PIL.Image类型。如果是视频,只能是视频文件路径。
> >
>
> **返回值**
>
> > - tuple(pred_img(np.array), out_paht(str)): 当属输入时图片时,返回预测后的图片,类型PIL.Image,以及图片的保存的路径。
> > - tuple(frame_path(str), out_path(str)): 当输入为视频时,frame_path为视频每帧上色后保存的图片路径,out_path为上色后视频的保存路径。
### run_image
```python
run_image(img)
```
> 图片上色的接口。
> **参数**
>
> > - img (str|np.ndarray|Image.Image): 输入图片,可以是图片的路径、np.ndarray、或PIL.Image类型。
> >
>
> **返回值**
>
> > - pred_img(PIL.Image): 返回预测后的图片,为PIL.Image类型。
### run_video
```python
run_video(video)
```
> 视频上色的接口。
> **参数**
>
> > - Video (str): 输入视频文件的路径。
>
> **返回值**
>
> > - tuple(frame_path(str), out_path(str)): frame_path为视频每帧上色后保存的图片路径,out_path为上色后视频的保存路径。
## ppgan.apps.DeepRemasterPredictor
```python
ppgan.apps.DeepRemasterPredictor(output='output', weight_path=None, colorization=False, reference_dir=None, mindim=360)
```
> 构建DeepRemasterPredictor实例。DeepRemaster是一个基于GAN的老照片/视频修复、上色模型,该模型可以提供一个参考色的图片作为输入。该接口目前只支持视频输入,建议使用mp4格式。
>
> **示例**
>
> ```
> from ppgan.apps import DeepRemasterPredictor
> deep_remaster = DeepRemasterPredictor()
> deep_remaster.run("docs/imgs/test_old.jpeg")
> ```
>
>
> **参数**
>
> > - output (str): 设置输出图片的保存路径,默认是output。注意,保存路径为设置output/DeepRemaster。
> > - weight_path (str): 指定模型路径,默认是None,则会自动下载内置的已经训练好的模型。
> > - colorization (bool): 是否打开上色功能,默认是False,既不打开,只执行修复功能。
> > - reference_dir(str|None): 打开上色功能时,输入参考色图片路径,也可以不设置参考色图片。
> > - mindim(int): 预测前图片会进行缩放,最小边长度。
### run
```python
run(video_path)
```
> 构建实例后的执行接口。
> **参数**
>
> > - video_path (str): 输入视频文件路径。
> >
> > 返回值
> >
> > - tuple(str, str)): 返回两个str类型,前者是视频上色后每帧图片的保存路径,后者是上色之后的视频保存路径。
## ppgan.apps.RealSRPredictor
```python
ppgan.apps.RealSRPredictor(output='output', weight_path=None)
```
> 构建RealSR实例。RealSR: Real-World Super-Resolution via Kernel Estimation and Noise Injection发表于CVPR 2020 Workshops的基于真实世界图像训练的超分辨率模型。此接口对输入图片或视频做4倍的超分辨率。建议视频使用mp4格式。
>
> **用例**
>
> ```
> from ppgan.apps import RealSRPredictor
> sr = RealSRPredictor()
> sr.run("docs/imgs/test_sr.jpeg")
> ```
> **参数**
>
> > - output (str): 设置输出图片的保存路径,默认是output。注意,保存路径为设置output/RealSR。
> > - weight_path (str): 指定模型路径,默认是None,则会自动下载内置的已经训练好的模型。
```python
run(video_path)
```
> 构建实例后的执行接口。
> **参数**
>
> > - video_path (str): 输入视频文件路径。
> >
>
> **返回值**
>
> > - tuple(pred_img(np.array), out_paht(str)): 当属输入时图片时,返回预测后的图片,类型PIL.Image,以及图片的保存的路径。
> > - tuple(frame_path(str), out_path(str)): 当输入为视频时,frame_path为超分后视频每帧图片的保存路径,out_path为超分后的视频保存路径。
### run_image
```python
run_image(img)
```
> 图片超分的接口。
> **参数**
>
> > - img (str|np.ndarray|Image.Image): 输入图片,可以是图片的路径、np.ndarray、或PIL.Image类型。
>
> **返回值**
>
> > - pred_img(PIL.Image): 返回预测后的图片,为PIL.Image类型。
### run_video
```python
run_video(video)
```
> 视频超分的接口。
> **参数**
>
> > - Video (str): 输入视频文件的路径。
>
> **返回值**
>
> > - tuple(frame_path(str), out_path(str)): frame_path为超分后视频每帧图片的保存路径,out_path为超分后的视频保存路径。
## ppgan.apps.EDVRPredictor
```python
ppgan.apps.EDVRPredictor(output='output', weight_path=None)
```
> 构建RealSR实例。EDVR: Video Restoration with Enhanced Deformable Convolutional Networks,论文链接: https://arxiv.org/abs/1905.02716 ,是一个针对视频超分的模型。该接口,对视频做2倍的超分。建议视频使用mp4格式。
>
> **示例**
>
> ```
> from ppgan.apps import EDVRPredictor
> sr = EDVRPredictor()
> # 测试一个视频文件
> sr.run("docs/imgs/test.mp4")
> ```
> **参数**
>
> > - output (str): 设置输出图片的保存路径,默认是output。注意,保存路径为设置output/EDVR。
> > - weight_path (str): 指定模型路径,默认是None,则会自动下载内置的已经训练好的模型。
```python
run(video_path)
```
> 构建实例后的执行接口。
> **参数**
>
> > - video_path (str): 输入视频文件路径。
>
> **返回值**
>
> > - tuple(str, str): 前者超分后的视频每帧图片的保存路径,后者为昨晚超分的视频路径。
## ppgan.apps.DAINPredictor
```python
ppgan.apps.DAINPredictor(output='output', weight_path=Nonetime_step=None, use_gpu=True, key_frame_thread=0remove_duplicates=False)
```
> 构建插针DAIN模型的实例。DAIN: Depth-Aware Video Frame Interpolation,论文链接: https://arxiv.org/abs/1904.00830 ,对视频做插针,获得帧率更高的视频。
>
> **示例**
>
> ```
> from ppgan.apps import DAINPredictor
> dain = DAINPredictor()
> # 测试一个视频文件
> dain.run("docs/imgs/test.mp4")
> ```
> **参数**
>
> > - output_path (str): 设置预测输出的保存路径,默认是output。注意,保存路径为设置output/DAIN。
> > - weight_path (str): 指定模型路径,默认是None,则会自动下载内置的已经训练好的模型。
> > - time_step (float): 帧率变化的倍数为 1./time_step,例如,如果time_step为0.5,则2倍插针,为0.25,则为4倍插针。
> > - use_gpu (bool): 是否使用GPU做预测,默认是True。
> > - remove_duplicates (bool): 是否去除重复帧,默认是False。
```python
run(video_path)
```
> 构建实例后的执行接口。
> **参数**
>
> > - video_path (str): 输入视频文件路径。
>
> **返回值**
>
> > - tuple(str, str): 当输入为视频时,frame_path为视频每帧上色后保存的图片路径,out_path为上色后视频的保存路径。
## ppgan.apps.FirstOrderPredictor
```python
ppgan.apps.FirstOrderPredictor(output='output', weight_path=Noneconfig=None, relative=False, adapt_scale=Falsefind_best_frame=False, best_frame=None)
```
> 构建FirsrOrder模型的实例,此模型用来做Image Animation,既给定一张源图片和一个驱动视频,生成一段视频,其中住体是源图片,动作是驱动视频中的动作。论文是First Order Motion Model for Image Animation,论文链接: https://arxiv.org/abs/2003.00196 。
>
> **示例**
>
> ```
> from ppgan.apps import FirstOrderPredictor
> animate = FirstOrderPredictor()
> # 测试一个视频文件
> animate.run("source.png","driving.mp4")
> ```
> **参数**
>
> > - output_path (str): 设置预测输出的保存路径,默认是output。注意,保存路径为设置output/result.mp4。
> > - weight_path (str): 指定模型路径,默认是None,则会自动下载内置的已经训练好的模型。
> > - config (dict|str|None): 设置模型的参数,可以是字典类型或YML文件,默认值是None,采用的默认的参数。当权重默认是None时,config也需采用默认值None。否则,这里的配置和对应权重保持一致
> > - relative (bool): 使用相对还是绝对关键点坐标,默认是False。
> > - adapt_scale (bool): 是否基于关键点凸包的自适应运动,默认是False。
> > - find_best_frame (bool): 是否从与源图片最匹配的帧开始生成,仅仅适用于人脸应用,需要人脸对齐的库。
> > - best_frame (int): 设置起始帧数,默认是None,从第1帧开始(从1开始计数)。
```python
run(source_imagedriving_video)
```
> 构建实例后的执行接口,预测视频保存位置为output/result.mp4。
> **参数**
>
> > - source_image (str): 输入源图片。
> > - driving_video (str): 输入驱动视频,支持mp4格式。
>
> **返回值**
>
> > 无。
## data prepare ## 数据准备
It is recommended to symlink the dataset root to `$PaddleGAN/data`. 现有的配置默认数据集的路径是在`$PaddleGAN/data`下,目录结构如下图所示。如果你已经下载好数据集了,建议将数据集软链接到 `$PaddleGAN/data`
``` ```
PaddleGAN PaddleGAN
...@@ -28,8 +28,65 @@ PaddleGAN ...@@ -28,8 +28,65 @@ PaddleGAN
``` ```
### cyclegan datasets 如果将数据集放在其他位置,比如 ```your/data/path```
more dataset for cyclegan you can download from [here](https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/) 你可以修改配置文件中的 ```dataroot``` 参数:
### pix2pix datasets ```
more dataset for pix2pix you can download from [here](hhttps://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/) dataset:
train:
name: PairedDataset
dataroot: your/data/path
num_workers: 4
```
### CycleGAN模型相关的数据集下载
#### 已有的数据集下载
##### 从网页下载
cyclgan模型相关的数据集可以在[这里](https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/)下载
##### 使用脚本下载
我们在 ```PaddleGAN/data``` 文件夹下提供了一个脚本 ```download_cyclegan_data.py``` 方便下载CycleGAN相关的
数据集。执行如下命令可以下载相关的数据集,目前支持的数据集名称有:apple2orange, summer2winter_yosemite,horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos, cityscapes。
执行如下命令,可以下载对应的数据集到 ```~/.cache/ppgan``` 并软连接到 ```PaddleGAN/data/``` 下。
```
python data/download_cyclegan_data.py --name horse2zebra
```
#### 使用自己的数据集
如果你使用自己的数据集,需要构造成如下目录的格式。注意 ```xxxA``````xxxB```文件数量,文件内容无需一一对应。
```
custom_datasets
├── testA
├── testB
├── trainA
└── trainB
```
### Pix2Pix相关的数据集下载
#### 已有的数据集下载
##### 从网页下载
pixel2pixel模型相关的数据集可以在[这里](hhttps://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/)下载
##### 使用脚本下载
我们在 ```PaddleGAN/data``` 文件夹下提供了一个脚本 ```download_pix2pix_data.py``` 方便下载pix2pix模型相关的数据集。执行如下命令可以下载相关的数据集,目前支持的数据集名称有:apple2orange, summer2winter_yosemite,horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos, cityscapes。
执行如下命令,可以下载对应的数据集到 ```~/.cache/ppgan``` 并软连接到 ```PaddleGAN/data/``` 下。
```
python data/download_pix2pix_data.py --name cityscapes
```
#### 使用自己的数据集
如果你使用自己的数据集,需要构造成如下目录的格式。同时图片应该制作成下图的样式,即左边为一种风格,另一边为相应转换的风格。
```
facades
├── test
├── train
└── val
```
![](./imgs/1.jpg)
## data prepare
The config will suppose your data put in `$PaddleGAN/data`. You can symlink your datasets to `$PaddleGAN/data`.
```
PaddleGAN
|-- configs
|-- data
| |-- cityscapes
| | ├── test
| | ├── testA
| | ├── testB
| | ├── train
| | ├── trainA
| | └── trainB
| ├── horse2zebra
| | ├── testA
| | ├── testB
| | ├── trainA
| | └── trainB
| └── facades
| ├── test
| ├── train
| └── val
|-- docs
|-- ppgan
|-- tools
```
if you put your datasets on other place,for example ```your/data/path```,
you can also change ```dataroot``` in config file:
```
dataset:
train:
name: PairedDataset
dataroot: your/data/path
num_workers: 4
```
### Datasets of CycleGAN
#### download existed datasets
##### download form website
datasets for CycleGAN you can download from [here](https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/)
##### download by script
You can use ```download_cyclegan_data.py``` in ```PaddleGAN/data``` to download datasets you wanted. Supported datasets are: apple2orange, summer2winter_yosemite,horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos, cityscapes。
run following command. Dataset will be downloaded to ```~/.cache/ppgan``` and symlink to ```PaddleGAN/data/``` .
```
python data/download_cyclegan_data.py --name horse2zebra
```
#### custom dataset
Data should be arranged in following way if you use custom dataset.
```
custom_datasets
├── testA
├── testB
├── trainA
└── trainB
```
### Datasets of Pix2Pix
#### download existed datasets
##### download from website
dataset for pix2pix you can download from [here](hhttps://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/)
##### download by script
You can use ```download_pix2pix_data.py``` in ```PaddleGAN/data``` to download datasets you wanted. Supported datasets are: apple2orange, summer2winter_yosemite,horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos, cityscapes.
run following command. Dataset will be downloaded to ```~/.cache/ppgan``` and symlink to ```PaddleGAN/data/``` .
```
python data/download_pix2pix_data.py --name cityscapes
```
#### custom datasets
Data should be arranged in following way if you use custom dataset. And image content shoubld be same with example image.
```
facades
├── test
├── train
└── val
```
![](./imgs/1.jpg)
## Getting started with PaddleGAN ## 快速开始使用PaddleGAN
### Train 注意:
* 开始使用PaddleGAN前请确保已经阅读过[安装文档](./install.md),并根据[数据准备文档](./data_prepare.md)准备好数据集。
* 以下教程以CycleGAN模型在Cityscapes数据集上的训练预测作为示例。
### 训练
#### 单卡训练
``` ```
python -u tools/main.py --config-file configs/cyclegan_cityscapes.yaml python -u tools/main.py --config-file configs/cyclegan_cityscapes.yaml
``` ```
#### 参数
continue train from last checkpoint - `--config-file (str)`: 配置文件的路径。
输出的日志,权重,可视化结果会默认保存在```./output_dir```中,可以通过配置文件中的```output_dir```参数修改:
```
output_dir: output_dir
```
保存的文件夹会根据模型名字和时间戳自动生成一个新目录,目录示例如下:
```
output_dir
└── CycleGANModel-2020-10-29-09-21
├── epoch_1_checkpoint.pkl
├── log.txt
└── visual_train
├── epoch001_fake_A.png
├── epoch001_fake_B.png
├── epoch001_idt_A.png
├── epoch001_idt_B.png
├── epoch001_real_A.png
├── epoch001_real_B.png
├── epoch001_rec_A.png
├── epoch001_rec_B.png
├── epoch002_fake_A.png
├── epoch002_fake_B.png
├── epoch002_idt_A.png
├── epoch002_idt_B.png
├── epoch002_real_A.png
├── epoch002_real_B.png
├── epoch002_rec_A.png
└── epoch002_rec_B.png
```
#### 恢复训练
训练过程中默认会保存上一个epoch的checkpoint,方便恢复训练
``` ```
python -u tools/main.py --config-file configs/cyclegan_cityscapes.yaml --resume your_checkpoint_path python -u tools/main.py --config-file configs/cyclegan_cityscapes.yaml --resume your_checkpoint_path
``` ```
#### 参数
multiple gpus train: - `--resume (str)`: 用来恢复训练的checkpoint路径。
#### 多卡训练:
``` ```
CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch tools/main.py --config-file configs/pix2pix_cityscapes.yaml CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch tools/main.py --config-file configs/cyclegan_cityscapes.yaml
``` ```
### Evaluate ### 预测
``` ```
python tools/main.py --config-file configs/cyclegan_cityscapes.yaml --evaluate-only --load your_weight_path python tools/main.py --config-file configs/cyclegan_cityscapes.yaml --evaluate-only --load your_weight_path
``` ```
#### 参数
- `--evaluate-only`: 是否仅进行预测。
- `--load (str)`: 训练好的权重路径。
## Getting started with PaddleGAN
Note:
* Before starting to use PaddleGAN, please make sure you have read the [install document](./install_en.md), and prepare the dataset according to the [data preparation document](./data_prepare_en.md)
* The following tutorial uses the train and evaluate of the CycleGAN model on the Cityscapes dataset as an example
### Train
#### Train with single gpu
```
python -u tools/main.py --config-file configs/cyclegan_cityscapes.yaml
```
#### Args
- `--config-file (str)`: path of config file。
The output log, weight, and visualization result will be saved in ```./output_dir``` by default, which can be modified by the ```output_dir``` parameter in the config file:
```
output_dir: output_dir
```
The saved folder will automatically generate a new directory based on the model name and timestamp. The directory example is as follows:
```
output_dir
└── CycleGANModel-2020-10-29-09-21
├── epoch_1_checkpoint.pkl
├── log.txt
└── visual_train
├── epoch001_fake_A.png
├── epoch001_fake_B.png
├── epoch001_idt_A.png
├── epoch001_idt_B.png
├── epoch001_real_A.png
├── epoch001_real_B.png
├── epoch001_rec_A.png
├── epoch001_rec_B.png
├── epoch002_fake_A.png
├── epoch002_fake_B.png
├── epoch002_idt_A.png
├── epoch002_idt_B.png
├── epoch002_real_A.png
├── epoch002_real_B.png
├── epoch002_rec_A.png
└── epoch002_rec_B.png
```
#### Recovery of training
The checkpoint of the previous epoch will be saved by default during the training process to facilitate the recovery of training
```
python -u tools/main.py --config-file configs/cyclegan_cityscapes.yaml --resume your_checkpoint_path
```
#### Args
- `--resume (str)`: path of checkpoint。
#### Train with multiple gpus:
```
CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch tools/main.py --config-file configs/cyclegan_cityscapes.yaml
```
### evaluate
```
python tools/main.py --config-file configs/cyclegan_cityscapes.yaml --evaluate-only --load your_weight_path
```
#### Args
- `--evaluate-only`: whether to evaluate only。
- `--load (str)`: path of weight。
## 安装PaddleGAN
### 要求
* PaddlePaddle >= 2.0.0-rc
* Python >= 3.5+
* CUDA >= 9.0
### 1. 安装PaddlePaddle
```
pip install -U paddlepaddle-gpu==2.0.0rc0
```
上面命令会默认安装cuda10.2的包,如果想安装其他cuda版本的包,可以参考下面的表格。
<table class="docutils"><tbody><th width="80"> CUDA </th><th valign="bottom" align="left" width="100">python3.8</th><th valign="bottom" align="left" width="100">python3.7</th><th valign="bottom" align="left" width="100">python3.6</th> <tr><td align="left">10.1</td><td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda10.1-cudnn7-mkl_gcc8.2%2Fpaddlepaddle_gpu-2.0.0rc0.post101-cp38-cp38-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda10.1-cudnn7-mkl_gcc8.2%2Fpaddlepaddle_gpu-2.0.0rc0.post101-cp37-cp37m-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda10.1-cudnn7-mkl_gcc8.2%2Fpaddlepaddle_gpu-2.0.0rc0.post101-cp36-cp36m-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"> </td> </tr> <tr><td align="left">10.0</td><td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda10-cudnn7-mkl%2Fpaddlepaddle_gpu-2.0.0rc0.post100-cp38-cp38-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda10-cudnn7-mkl%2Fpaddlepaddle_gpu-2.0.0rc0.post100-cp37-cp37m-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda10-cudnn7-mkl%2Fpaddlepaddle_gpu-2.0.0rc0.post100-cp36-cp36m-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"> </td> </tr> <tr><td align="left">9.0</td><td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda9-cudnn7-mkl%2Fpaddlepaddle_gpu-2.0.0rc0.post90-cp38-cp38-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda9-cudnn7-mkl%2Fpaddlepaddle_gpu-2.0.0rc0.post90-cp37-cp37m-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda9-cudnn7-mkl%2Fpaddlepaddle_gpu-2.0.0rc0.post90-cp36-cp36m-linux_x86_64.whl
</code></pre> </details> </td> </tr></tbody></table>
### 2. 安装ppgan
```
git clone https://github.com/PaddlePaddle/PaddleGAN
cd PaddleGAN
pip install -v -e . # or "python setup.py develop"
```
按照上述方法安装成功后,本地的修改也会自动同步到ppgan中
## Install PaddleGAN
### requirements
* PaddlePaddle >= 2.0.0-rc
* Python >= 3.5+
* CUDA >= 9.0
### 1. Install PaddlePaddle
```
pip install -U paddlepaddle-gpu==2.0.0rc0
```
Note: command above will install paddle with cuda10.2,if your installed cuda is different, you can choose an proper version to install from table below.
<table class="docutils"><tbody><th width="80"> CUDA </th><th valign="bottom" align="left" width="100">python3.8</th><th valign="bottom" align="left" width="100">python3.7</th><th valign="bottom" align="left" width="100">python3.6</th> <tr><td align="left">10.1</td><td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda10.1-cudnn7-mkl_gcc8.2%2Fpaddlepaddle_gpu-2.0.0rc0.post101-cp38-cp38-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda10.1-cudnn7-mkl_gcc8.2%2Fpaddlepaddle_gpu-2.0.0rc0.post101-cp37-cp37m-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda10.1-cudnn7-mkl_gcc8.2%2Fpaddlepaddle_gpu-2.0.0rc0.post101-cp36-cp36m-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"> </td> </tr> <tr><td align="left">10.0</td><td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda10-cudnn7-mkl%2Fpaddlepaddle_gpu-2.0.0rc0.post100-cp38-cp38-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda10-cudnn7-mkl%2Fpaddlepaddle_gpu-2.0.0rc0.post100-cp37-cp37m-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda10-cudnn7-mkl%2Fpaddlepaddle_gpu-2.0.0rc0.post100-cp36-cp36m-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"> </td> </tr> <tr><td align="left">9.0</td><td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda9-cudnn7-mkl%2Fpaddlepaddle_gpu-2.0.0rc0.post90-cp38-cp38-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda9-cudnn7-mkl%2Fpaddlepaddle_gpu-2.0.0rc0.post90-cp37-cp37m-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda9-cudnn7-mkl%2Fpaddlepaddle_gpu-2.0.0rc0.post90-cp36-cp36m-linux_x86_64.whl
</code></pre> </details> </td> </tr></tbody></table>
### 2. Install ppgan
```
git clone https://github.com/PaddlePaddle/PaddleGAN
cd PaddleGAN
pip install -v -e . # or "python setup.py develop"
```
## to be added # First order motion model
## 1. First order motion model原理
First order motion model的任务是image animation,给定一张源图片,给定一个驱动视频,生成一段视频,其中主角是源图片,动作是驱动视频中的动作。如下图所示,源图像通常包含一个主体,驱动视频包含一系列动作。
![](../imgs/fom_demo.png)
以左上角的人脸表情迁移为例,给定一个源人物,给定一个驱动视频,可以生成一个视频,其中主体是源人物,视频中源人物的表情是由驱动视频中的表情所确定的。通常情况下,我们需要对源人物进行人脸关键点标注、进行表情迁移的模型训练。
但是这篇文章提出的方法只需要在同类别物体的数据集上进行训练即可,比如实现太极动作迁移就用太极视频数据集进行训练,想要达到表情迁移的效果就使用人脸视频数据集voxceleb进行训练。训练好后,我们使用对应的预训练模型就可以达到前言中实时image animation的操作。
## 2. 使用方法
用户可以上传自己准备的视频和图片,并在如下命令中的source_image参数和driving_video参数分别换成自己的图片和视频路径,然后运行如下命令,就可以完成动作表情迁移,程序运行成功后,会在ouput文件夹生成名为result.mp4的视频文件,该文件即为动作迁移后的视频。本项目中提供了原始图片和驱动视频供展示使用。运行的命令如下所示:
`python -u tools/first-order-demo.py --driving_video ./ravel_10.mp4 --source_image ./sudaqiang.png --relative --adapt_scale`
**参数说明:**
- driving_video: 驱动视频,视频中人物的表情动作作为待迁移的对象
- source_image: 原始图片,视频中人物的表情动作将迁移到该原始图片中的人物上
- relative: 指示程序中使用视频和图片中人物关键点的相对坐标还是绝对坐标,建议使用相对坐标,若使用绝对坐标,会导致迁移后人物扭曲变形
- adapt_scale: 根据关键点凸包自适应运动尺度
## 3. 生成结果展示
![](../imgs/first_order.gif)
# Fist order motion model
## 1. First order motion model introduction
First order motion model is to complete the Image animation task, which consists of generating a video sequence so that an object in a source image is animated according to the motion of a driving video. The first order motion framework addresses this problem without using any annotation or prior information about the specific object to animate. Once trained on a set of videos depicting objects of the same category (e.g. faces, human bodies), this method can be applied to any object of this class. To achieve this, the innovative method decouple appearance and motion information using a self-supervised formulation. In addition, to support complex motions, it use a representation consisting of a set of learned keypoints along with their local affine transformations. A generator network models occlusions arising during target motions and combines the appearance extracted from the source image and the motion derived from the driving video.
![](../imgs/fom_demo.png)
## How to use
Users can upload the prepared source image and driving video, then substitute the path of source image and driving video for the `source_image` and `driving_video` parameter in the following running command. It will geneate a video file named `result.mp4` in the `output` folder, which is the animated video file.
`python -u tools/first-order-demo.py --driving_video ./ravel_10.mp4 --source_image ./sudaqiang.png --relative --adapt_scale`
**params:**
- driving_video: driving video, the motion of the driving video is to be migrated.
- source_image: source_image, the image will be animated according to the motion of the driving video.
- relative: indicate whether the relative or absolute coordinates of the key points in the video are used in the program. It is recommended to use relative coordinates. If absolute coordinates are used, the characters will be distorted after animation.
- adapt_scale: adapt movement scale based on convex hull of keypoints.
## 3. Animation results
![](../imgs/first_order.gif)
## to be added # 1 Pix2pix
## 1.1 Principle
Pix2pix uses paired images for image translation, which has two different styles of the same image as input, can be used for style transfer. Pix2pix is encouraged by cGAN, cGAN inputs a noisy image and a condition as the supervision information to the generation network, pix2pix uses another style of image as the supervision information input into the generation network, so the fake image is related to another style of image which is input as supervision information, thus realizing the process of image translation.
## 1.2 How to use
### 1.2.1 Prepare Datasets
Paired datasets used by Pix2pix can be download from [here](http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/)
For example, the structure of facades is as following:
```
facades
├── test
├── train
└── val
```
You can download from wget, download facades from wget for example:
```
wget https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/facades.zip --no-check-certificate
```
### 1.2.2 Train/Test
Datasets used in example is facades, you can change it to your own dataset in the config file.
Train a model:
```
python -u tools/main.py --config-file configs/pix2pix_facades.yaml
```
Test the model:
```
python tools/main.py --config-file configs/pix2pix_facades.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
```
## 1.3 Results
![](../imgs/horse2zebra.png)
[model download](TODO)
# 2 CycleGAN
## 2.1 Principle
CycleGAN uses unpaired pictures for image translation, input two different images with different styles, and automatically perform style transfer. CycleGAN consists of two generators and two discriminators, generator A is inputting images of style A and outputting images of style B, generator B is inputting images of style B and outputting images of style A. The biggest difference between CycleGAN and pix2pix is that CycleGAN can realize image translation without establishing a one-to-one mapping between the source domain and the target domain.
![](../imgs/cyclegan.png)
## 2.2 How to use
### 2.2.1 Prepare Datasets
Unpair datasets used by CycleGAN can be download from [here](https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/)
For example, the structure of cityscapes is as following:
```
cityscapes
├── test
├── testA
├── testB
├── train
├── trainA
└── trainB
```
You can download from wget, download facades from wget for example:
```
wget http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/facades.tar.gz --no-check-certificate
```
### 2.2.2 Train/Test
Datasets used in example is cityscapes, you can change it to your own dataset in the config file.
Train a model:
```
python -u tools/main.py --config-file configs/cyclegan_cityscapes.yaml
```
Test the model:
```
python tools/main.py --config-file configs/cyclegan_cityscapes.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
```
## 2.3 Results
![](../imgs/A2B.png)
[model download](TODO)
# References
1. [Image-to-Image Translation with Conditional Adversarial Networks](https://arxiv.org/abs/1611.07004)
2. [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593)
# 1 Pix2pix
## 1.1 原理介绍
Pix2pix利用成对的图片进行图像翻译,即输入为同一张图片的两种不同风格,可用于进行风格迁移。Pix2pix是在cGAN的基础上进行改进的,cGAN的生成网络不仅会输入一个噪声图片,同时还会输入一个条件作为监督信息,pix2pix则是把另外一种风格的图像作为监督信息输入生成网络中,这样生成的fake图像就会和作为监督信息的另一种风格的图像相关,从而实现了图像翻译的过程。
![](../imgs/pix2pix.png)
## 1.2 如何使用
### 1.2.1 数据准备
Pix2pix使用成对数据作为训练数据,训练数据可以从[这里](http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/)下载。
例如,pix2pix所使用的facades数据的组成形式为:
```
facades
├── test
├── train
└── val
```
也可以通过wget的方式进行数据下载,例如facades数据集的下载方式为:
```
wget http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/facades.tar.gz --no-check-certificate
```
### 1.2.2 训练/测试
示例以facades数据为例。如果您想使用自己的数据集,可以在配置文件中修改数据集为您自己的数据集。
训练模型:
```
python -u tools/main.py --config-file configs/pix2pix_facades.yaml
```
测试模型:
```
python tools/main.py --config-file configs/pix2pix_facades.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
```
## 1.3 结果展示
![](../imgs/horse2zebra.png)
[模型下载](TODO)
# 2 CycleGAN
## 2.1 原理介绍
CycleGAN可以利用非成对的图片进行图像翻译,即输入为两种不同风格的不同图片,自动进行风格转换。CycleGAN由两个生成网络和两个判别网络组成,生成网络A是输入A类风格的图片输出B类风格的图片,生成网络B是输入B类风格的图片输出A类风格的图片。CycleGAN和pix2pix最大的不同就是CycleGAN在源域和目标域之间无需建立数据间一对一的映射就可以实现图像翻译。
## 2.2 如何使用
### 2.2.1 数据准备
CycleGAN使用的是非成对的数据,训练数据可以从[这里](https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/)下载。
例如,cycleGAN所使用的cityscapes数据的组成形式为:
```
cityscapes
├── test
├── testA
├── testB
├── train
├── trainA
└── trainB
```
也可以通过wget的方式进行数据下载,例如facades数据集的下载方式为:
```
wget https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/facades.zip --no-check-certificate
```
### 2.2.2 训练/测试
示例以cityscapes数据为例。如果您想使用自己的数据集,可以在配置文件中修改数据集为您自己的数据集。
训练模型:
```
python -u tools/main.py --config-file configs/cyclegan_cityscapes.yaml
```
测试模型:
```
python tools/main.py --config-file configs/cyclegan_cityscapes.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
```
## 2.3 结果展示
![](../imgs/A2B.png)
[模型下载](TODO)
# 参考:
1. [Image-to-Image Translation with Conditional Adversarial Networks](https://arxiv.org/abs/1611.07004)
2. [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593)
[Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593)
## to be added # PSGAN
## 1. PSGAN原理
PSGAN模型的任务是妆容迁移, 即将任意参照图像上的妆容迁移到不带妆容的源图像上。很多人像美化应用都需要这种技术。近来的一些妆容迁移方法大都基于生成对抗网络(GAN)。它们通常采用 CycleGAN 的框架,并在两个数据集上进行训练,即无妆容图像和有妆容图像。但是,现有的方法存在一个局限性:只在正面人脸图像上表现良好,没有为处理源图像和参照图像之间的姿态和表情差异专门设计模块。PSGAN是一种全新的姿态稳健可感知空间的生生成对抗网络。PSGAN 主要分为三部分:妆容提炼网络(MDNet)、注意式妆容变形(AMM)模块和卸妆-再化妆网络(DRNet)。这三种新提出的模块能让 PSGAN 具备上述的完美妆容迁移模型所应具备的能力。
![](../imgs/psgan_arc.png)
## 2. 使用方法
### 2.1 测试
运行如下命令,就可以完成妆容迁移,程序运行成功后,会在当前文件夹生成妆容迁移后的图片文件。本项目中提供了原始图片和参考供展示使用,具体命令如下所示:
```
cd applications/
python tools/ps_demo.py \
--config-file configs/makeup.yaml \
--model_path /your/model/path \
--source_path /your/source/image/path \
--reference_dir /your/ref/image/path
```
**参数说明:**
- config-file: PSGAN网络到参数配置文件,格式为yaml
- model_path: 训练完成保存下来网络权重文件的路径
- source_path: 未化妆的原始图片文件全路径,包含图片文件名字
- reference_dir: 化妆的参考图片文件路径,不包含图片文件名字
### 2.2 训练
1. 从百度网盘下载原始换妆数据[data](https://pan.baidu.com/s/1ZF-DN9PvbBteOSfQodWnyw)(密码:rtdd)到PaddleGAN文件夹, 并解压
2. 下载landmarks数据[lmks](https://paddlegan.bj.bcebos.com/landmarks.tar),并解压
3. 运行如下命令进行文件夹及文件替换:
```
mv landmarks/makeup MT-Dataset/landmarks/makeup
mv landmarks/non-makeup MT-Dataset/landmarks/non-makeup
mv landmarks/train_makeup.txt MT-Dataset/makeup.txt
mv tlandmarks/train_non-makeup.txt MT-Dataset/non-makeup.txt
```
最后数据集目录如下所示:
```
data
├── images
│   ├── makeup
│   └── non-makeup
├── landmarks
│   ├── makeup
│   └── non-makeup
├── train_makeup.txt
├── train_non-makeup.txt
├── segs
│   ├── makeup
│   └── non-makeup
```
4. `python tools/main.py --config-file configs/makeup.yaml` ,训练参数设置参考makeup.yaml.
单卡batch_size=1训练部分log如下所示:
```
[10/29 05:39:40] ppgan.engine.trainer INFO: Epoch: 0, iters: 0 lr: 0.000200 D_A: 0.448 G_A: 0.973 rec: 1.258 idt: 0.624 D_B: 0.436 G_B: 0.889 G_A_his: 0.402 G_B_his: 0.472 G_bg_consis: 0.030 A_vgg: 0.027 B_vgg: 0.040 reader cost: 2.45463s batch cost: 4.20075s
[10/29 05:40:00] ppgan.engine.trainer INFO: Epoch: 0, iters: 10 lr: 0.000200 D_A: 0.200 G_A: 0.488 rec: 0.954 idt: 0.539 D_B: 0.179 G_B: 0.767 G_A_his: 0.224 G_B_his: 0.266 G_bg_consis: 0.033 A_vgg: 0.019 B_vgg: 0.026 reader cost: 0.55506s batch cost: 1.95968s
[10/29 05:40:22] ppgan.engine.trainer INFO: Epoch: 0, iters: 20 lr: 0.000200 D_A: 0.340 G_A: 0.339 rec: 1.293 idt: 0.698 D_B: 0.124 G_B: 0.174 G_A_his: 0.302 G_B_his: 0.233 G_bg_consis: 0.061 A_vgg: 0.032 B_vgg: 0.045 reader cost: 0.74937s batch cost: 2.13529s
[10/29 05:40:42] ppgan.engine.trainer INFO: Epoch: 0, iters: 30 lr: 0.000200 D_A: 0.238 G_A: 0.276 rec: 0.907 idt: 0.449 D_B: 0.324 G_B: 0.292 G_A_his: 0.263 G_B_his: 0.380 G_bg_consis: 0.029 A_vgg: 0.040 B_vgg: 0.049 reader cost: 0.69248s batch cost: 2.06999s
[10/29 05:41:03] ppgan.engine.trainer INFO: Epoch: 0, iters: 40 lr: 0.000200 D_A: 0.236 G_A: 0.111 rec: 0.865 idt: 0.470 D_B: 0.237 G_B: 0.465 G_A_his: 0.289 G_B_his: 0.211 G_bg_consis: 0.021 A_vgg: 0.042 B_vgg: 0.049 reader cost: 0.65904s batch cost: 2.07197s
[10/29 05:41:23] ppgan.engine.trainer INFO: Epoch: 0, iters: 50 lr: 0.000200 D_A: 0.341 G_A: 0.073 rec: 0.698 idt: 0.424 D_B: 0.153 G_B: 0.731 G_A_his: 0.198 G_B_his: 0.180 G_bg_consis: 0.019 A_vgg: 0.032 B_vgg: 0.047 reader cost: 0.52772s batch cost: 1.92949s
[10/29 05:41:43] ppgan.engine.trainer INFO: Epoch: 0, iters: 60 lr: 0.000200 D_A: 0.267 G_A: 0.475 rec: 0.843 idt: 0.462 D_B: 0.266 G_B: 0.534 G_A_his: 0.259 G_B_his: 0.219 G_bg_consis: 0.024 A_vgg: 0.031 B_vgg: 0.041 reader cost: 0.58212s batch cost: 2.02212s
[10/29 05:42:03] ppgan.engine.trainer INFO: Epoch: 0, iters: 70 lr: 0.000200 D_A: 0.116 G_A: 0.298 rec: 0.983 idt: 0.543 D_B: 0.097 G_B: 0.233 G_A_his: 0.210 G_B_his: 0.169 G_bg_consis: 0.046 A_vgg: 0.028 B_vgg: 0.034 reader cost: 0.56367s batch cost: 1.97049s
[10/29 05:42:23] ppgan.engine.trainer INFO: Epoch: 0, iters: 80 lr: 0.000200 D_A: 0.325 G_A: 0.339 rec: 0.744 idt: 0.417 D_B: 0.292 G_B: 0.310 G_A_his: 0.189 G_B_his: 0.206 G_bg_consis: 0.016 A_vgg: 0.029 B_vgg: 0.034 reader cost: 0.60760s batch cost: 2.04126s
[10/29 05:42:43] ppgan.engine.trainer INFO: Epoch: 0, iters: 90 lr: 0.000200 D_A: 0.177 G_A: 0.308 rec: 0.970 idt: 0.494 D_B: 0.199 G_B: 0.813 G_A_his: 0.116 G_B_his: 0.153 G_bg_consis: 0.036 A_vgg: 0.019 B_vgg: 0.042 reader cost: 0.62142s batch cost: 1.96606s
[10/29 05:43:03] ppgan.engine.trainer INFO: Epoch: 0, iters: 100 lr: 0.000200 D_A: 0.178 G_A: 0.382 rec: 1.358 idt: 0.607 D_B: 0.265 G_B: 0.405 G_A_his: 0.086 G_B_his: 0.161 G_bg_consis: 0.060 A_vgg: 0.025 B_vgg: 0.047 reader cost: 0.63939s batch cost: 2.00111s
```
注意:训练时makeup.yaml文件中`isTrain`参数值为`True`, 测试时修改该参数值为`False` .
### 2.3 模型
Model|Dataset|BatchSize|Inference speed|Download
---|:--:|:--:|:--:|:--:
PSGAN|MT-Dataset| 1 | 1.9s(GPU:P40) | [model]()
## 3. 妆容迁移结果展示
![](../imgs/makeup_shifter.png)
# PSGAN
## 1. PSGAN introduction
This paper is to address the makeup transfer task, which aims to transfer the makeup from a reference image to a source image. Existing methods have achieved promising progress in constrained scenarios, but transferring between images with large pose and expression differences is still challenging. To address these issues, we propose Pose and expression robust Spatial-aware GAN (PSGAN). It first utilizes Makeup Distill Network to disentangle the makeup of the reference image as two spatial-aware makeup matrices. Then, Attentive Makeup Morphing module is introduced to specify how the makeup of a pixel in the source image is morphed from the reference image. With the makeup matrices and the source image, Makeup Apply Network is used to perform makeup transfer.
![](../imgs/psgan_arc.png)
## 2. How to use
### 2.1 Test
Running the following command to complete the makeup transfer task. It will geneate the transfered image in the current path when the program running sucessfully.
```
cd applications
python tools/ps_demo.py \
--config-file configs/makeup.yaml \
--model_path /your/model/path \
--source_path /your/source/image/path \
--reference_dir /your/ref/image/path
```
**params:**
- config-file: PSGAN network configuration file, yaml format
- model_path: Saved model weight path
- source_path: Full path of the non-makeup image file, including the image file name
- reference_dir: Path of the make_up iamge file, don't including the image file name
### 2.2 Training
1. Downloading the original makeup transfer [data](https://pan.baidu.com/s/1ZF-DN9PvbBteOSfQodWnyw)(Password:rtdd) to the PaddleGAN folder, and uncompress it.
2. Downloading the landmarks [data](https://paddlegan.bj.bcebos.com/landmarks.tar), and uncompress it
3. Runnint the following command to substitute files:
```
mv landmarks/makeup MT-Dataset/landmarks/makeup
mv landmarks/non-makeup MT-Dataset/landmarks/non-makeup
mv landmarks/train_makeup.txt MT-Dataset/makeup.txt
mv tlandmarks/train_non-makeup.txt MT-Dataset/non-makeup.txt
```
The final data directory should be looked like:
```
data
├── images
│ ├── makeup
│ └── non-makeup
├── landmarks
│ ├── makeup
│ └── non-makeup
├── train_makeup.txt
├── train_non-makeup.txt
├── segs
│ ├── makeup
│ └── non-makeup
```
2. `python tools/main.py --config-file configs/makeup.yaml`
The training log looks like:
```
[10/29 05:39:40] ppgan.engine.trainer INFO: Epoch: 0, iters: 0 lr: 0.000200 D_A: 0.448 G_A: 0.973 rec: 1.258 idt: 0.624 D_B: 0.436 G_B: 0.889 G_A_his: 0.402 G_B_his: 0.472 G_bg_consis: 0.030 A_vgg: 0.027 B_vgg: 0.040 reader cost: 2.45463s batch cost: 4.20075s
[10/29 05:40:00] ppgan.engine.trainer INFO: Epoch: 0, iters: 10 lr: 0.000200 D_A: 0.200 G_A: 0.488 rec: 0.954 idt: 0.539 D_B: 0.179 G_B: 0.767 G_A_his: 0.224 G_B_his: 0.266 G_bg_consis: 0.033 A_vgg: 0.019 B_vgg: 0.026 reader cost: 0.55506s batch cost: 1.95968s
[10/29 05:40:22] ppgan.engine.trainer INFO: Epoch: 0, iters: 20 lr: 0.000200 D_A: 0.340 G_A: 0.339 rec: 1.293 idt: 0.698 D_B: 0.124 G_B: 0.174 G_A_his: 0.302 G_B_his: 0.233 G_bg_consis: 0.061 A_vgg: 0.032 B_vgg: 0.045 reader cost: 0.74937s batch cost: 2.13529s
[10/29 05:40:42] ppgan.engine.trainer INFO: Epoch: 0, iters: 30 lr: 0.000200 D_A: 0.238 G_A: 0.276 rec: 0.907 idt: 0.449 D_B: 0.324 G_B: 0.292 G_A_his: 0.263 G_B_his: 0.380 G_bg_consis: 0.029 A_vgg: 0.040 B_vgg: 0.049 reader cost: 0.69248s batch cost: 2.06999s
[10/29 05:41:03] ppgan.engine.trainer INFO: Epoch: 0, iters: 40 lr: 0.000200 D_A: 0.236 G_A: 0.111 rec: 0.865 idt: 0.470 D_B: 0.237 G_B: 0.465 G_A_his: 0.289 G_B_his: 0.211 G_bg_consis: 0.021 A_vgg: 0.042 B_vgg: 0.049 reader cost: 0.65904s batch cost: 2.07197s
[10/29 05:41:23] ppgan.engine.trainer INFO: Epoch: 0, iters: 50 lr: 0.000200 D_A: 0.341 G_A: 0.073 rec: 0.698 idt: 0.424 D_B: 0.153 G_B: 0.731 G_A_his: 0.198 G_B_his: 0.180 G_bg_consis: 0.019 A_vgg: 0.032 B_vgg: 0.047 reader cost: 0.52772s batch cost: 1.92949s
[10/29 05:41:43] ppgan.engine.trainer INFO: Epoch: 0, iters: 60 lr: 0.000200 D_A: 0.267 G_A: 0.475 rec: 0.843 idt: 0.462 D_B: 0.266 G_B: 0.534 G_A_his: 0.259 G_B_his: 0.219 G_bg_consis: 0.024 A_vgg: 0.031 B_vgg: 0.041 reader cost: 0.58212s batch cost: 2.02212s
[10/29 05:42:03] ppgan.engine.trainer INFO: Epoch: 0, iters: 70 lr: 0.000200 D_A: 0.116 G_A: 0.298 rec: 0.983 idt: 0.543 D_B: 0.097 G_B: 0.233 G_A_his: 0.210 G_B_his: 0.169 G_bg_consis: 0.046 A_vgg: 0.028 B_vgg: 0.034 reader cost: 0.56367s batch cost: 1.97049s
[10/29 05:42:23] ppgan.engine.trainer INFO: Epoch: 0, iters: 80 lr: 0.000200 D_A: 0.325 G_A: 0.339 rec: 0.744 idt: 0.417 D_B: 0.292 G_B: 0.310 G_A_his: 0.189 G_B_his: 0.206 G_bg_consis: 0.016 A_vgg: 0.029 B_vgg: 0.034 reader cost: 0.60760s batch cost: 2.04126s
[10/29 05:42:43] ppgan.engine.trainer INFO: Epoch: 0, iters: 90 lr: 0.000200 D_A: 0.177 G_A: 0.308 rec: 0.970 idt: 0.494 D_B: 0.199 G_B: 0.813 G_A_his: 0.116 G_B_his: 0.153 G_bg_consis: 0.036 A_vgg: 0.019 B_vgg: 0.042 reader cost: 0.62142s batch cost: 1.96606s
[10/29 05:43:03] ppgan.engine.trainer INFO: Epoch: 0, iters: 100 lr: 0.000200 D_A: 0.178 G_A: 0.382 rec: 1.358 idt: 0.607 D_B: 0.265 G_B: 0.405 G_A_his: 0.086 G_B_his: 0.161 G_bg_consis: 0.060 A_vgg: 0.025 B_vgg: 0.047 reader cost: 0.63939s batch cost: 2.00111s
```
Notation: In train phase, the `isTrain` value in makeup.yaml file is `True`, but in test phase, its value should be modified as `False`.
### 2.3 Model
Model|Dataset|BatchSize|Inference speed|Download
---|:--:|:--:|:--:|:--:
PSGAN|MT-Dataset| 1 | 1.9s(GPU:P40) | [model]()
## 3. Result
![](../imgs/makeup_shifter.png)
## to be added ## 老视频修复
老视频往往具有帧数少,无色彩,分辨率低等特点。于是针对这些特点,我们使用补帧,上色,超分等模型对视频进行修复。
### 使用applications中的video-enhance.py工具进行快速开始视频修复
```
cd applications
python tools/video-enhance.py --input you_video_path.mp4 --proccess_order DAIN DeOldify EDVR --output output_dir
```
#### 参数
- `--input (str)`: 输入的视频路径。
- `--output (str)`: 输出的视频路径。
- `--proccess_order`: 调用的模型名字和顺序,比如输入为 `DAIN DeOldify EDVR`,则会顺序调用 `DAINPredictor` `DeOldifyPredictor` `EDVRPredictor`
#### 效果展示
![](../imgs/color_sr_peking.gif)
### 快速体验
我们在ai studio制作了一个[ai studio 老北京视频修复教程](https://aistudio.baidu.com/aistudio/projectdetail/1161285)
### 注意事项
* 在使用本教程前,请确保您已经[安装完paddle和ppgan]()。
* 本教程的所有命令都基于PaddleGAN/applications主目录进行执行。
* 各个模型耗时较长,尤其使超分辨率模型,建议输入的视频分辨率低一些,时长短一些。
* 需要运行在gpu环境上
### ppgan提供的可用于视频修复的预测api简介
可以根据要修复的视频的特点,使用不同的模型与参数
### 补帧模型DAIN
DAIN 模型通过探索深度的信息来显式检测遮挡。并且开发了一个深度感知的流投影层来合成中间流。在视频补帧方面有较好的效果。
![](./imgs/dain_network.png)
```
ppgan.apps.DAINPredictor(
output='output',
weight_path=None,
time_step=None,
use_gpu=True,
remove_duplicates=False)
```
#### 参数
- `output (str,可选的)`: 输出的文件夹路径,默认值:`output`.
- `weight_path (None,可选的)`: 载入的权重路径,如果没有设置,则从云端下载默认的权重到本地。默认值:`None`
- `time_step (int)`: 补帧的时间系数,如果设置为0.5,则原先为每秒30帧的视频,补帧后变为每秒60帧。
- `remove_duplicates (bool,可选的)`: 是否删除重复帧,默认值:`False`.
### 上色模型DeOldifyPredictor
DeOldify 采用自注意力机制的生成对抗网络,生成器是一个U-NET结构的网络。在图像的上色方面有着较好的效果。
![](./imgs/deoldify_network.png)
```
ppgan.apps.DeOldifyPredictor(output='output', weight_path=None, render_factor=32)
```
#### 参数
- `output (str,可选的)`: 输出的文件夹路径,默认值:`output`.
- `weight_path (None,可选的)`: 载入的权重路径,如果没有设置,则从云端下载默认的权重到本地。默认值:`None`
- `render_factor (int)`: 会将该参数乘以16后作为输入帧的resize的值,如果该值设置为32,
则输入帧会resize到(32 * 16, 32 * 16)的尺寸再输入到网络中。
### 上色模型DeepRemasterPredictor
DeepRemaster 模型基于时空卷积神经网络和自注意力机制。并且能够根据输入的任意数量的参考帧对图片进行上色。
![](./imgs/remaster_network.png)
```
ppgan.apps.DeepRemasterPredictor(
output='output',
weight_path=None,
colorization=False,
reference_dir=None,
mindim=360):
```
#### 参数
- `output (str,可选的)`: 输出的文件夹路径,默认值:`output`.
- `weight_path (None,可选的)`: 载入的权重路径,如果没有设置,则从云端下载默认的权重到本地。默认值:`None`
- `colorization (bool)`: 是否对输入视频上色,如果选项设置为 `True` ,则参考帧的文件夹路径也必须要设置。默认值:`False`
- `reference_dir (bool)`: 参考帧的文件夹路径。默认值:`None`
- `mindim (bool)`: 输入帧重新resize后的短边的大小。默认值:360。
### 超分辨率模型RealSRPredictor
RealSR模型通过估计各种模糊内核以及实际噪声分布,为现实世界的图像设计一种新颖的真实图片降采样框架。基于该降采样框架,可以获取与真实世界图像共享同一域的低分辨率图像。并且提出了一个旨在提高感知度的真实世界超分辨率模型。对合成噪声数据和真实世界图像进行的大量实验表明,该模型能够有效降低了噪声并提高了视觉质量。
![](./imgs/realsr_network.png)
```
ppgan.apps.RealSRPredictor(output='output', weight_path=None)
```
#### 参数
- `output (str,可选的)`: 输出的文件夹路径,默认值:`output`.
- `weight_path (None,可选的)`: 载入的权重路径,如果没有设置,则从云端下载默认的权重到本地。默认值:`None`
-
### 超分辨率模型EDVRPredictor
EDVR模型提出了一个新颖的视频具有增强可变形卷积的还原框架:第一,为了处理大动作而设计的一个金字塔,级联和可变形(PCD)对齐模块,使用可变形卷积以从粗到精的方式在特征级别完成对齐;第二,提出时空注意力机制(TSA)融合模块,在时间和空间上都融合了注意机制,用以增强复原的功能。
EDVR模型是一个基于连续帧的超分模型,能够有效利用帧间的信息,速度比RealSR模型快。
![](./imgs/edvr_network.png)
```
ppgan.apps.EDVRPredictor(output='output', weight_path=None)
```
#### 参数
- `output (str,可选的)`: 输出的文件夹路径,默认值:`output`.
- `weight_path (None,可选的)`: 载入的权重路径,如果没有设置,则从云端下载默认的权重到本地。默认值:`None`
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .dain_predictor import DAINPredictor from .dain_predictor import DAINPredictor
from .deepremaster_predictor import DeepRemasterPredictor from .deepremaster_predictor import DeepRemasterPredictor
from .deoldify_predictor import DeOldifyPredictor from .deoldify_predictor import DeOldifyPredictor
......
...@@ -22,7 +22,7 @@ from imageio import imread, imsave ...@@ -22,7 +22,7 @@ from imageio import imread, imsave
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.utils.download import get_path_from_url from ppgan.utils.download import get_path_from_url
from ppgan.utils.video import video2frames, frames2video from ppgan.utils.video import video2frames, frames2video
from .base_predictor import BasePredictor from .base_predictor import BasePredictor
...@@ -32,20 +32,18 @@ DAIN_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DAIN_weight.tar' ...@@ -32,20 +32,18 @@ DAIN_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DAIN_weight.tar'
class DAINPredictor(BasePredictor): class DAINPredictor(BasePredictor):
def __init__(self, def __init__(self,
output_path='output', output='output',
weight_path=None, weight_path=None,
time_step=None, time_step=None,
use_gpu=True, use_gpu=True,
key_frame_thread=0.,
remove_duplicates=False): remove_duplicates=False):
self.output_path = os.path.join(output_path, 'DAIN') self.output_path = os.path.join(output, 'DAIN')
if weight_path is None: if weight_path is None:
cur_path = os.path.abspath(os.path.dirname(__file__)) weight_path = get_path_from_url(DAIN_WEIGHT_URL)
weight_path = get_path_from_url(DAIN_WEIGHT_URL, cur_path)
self.weight_path = weight_path self.weight_path = weight_path
self.time_step = time_step self.time_step = time_step
self.key_frame_thread = key_frame_thread self.key_frame_thread = 0
self.remove_duplicates = remove_duplicates self.remove_duplicates = remove_duplicates
self.build_inference_model() self.build_inference_model()
...@@ -134,15 +132,15 @@ class DAINPredictor(BasePredictor): ...@@ -134,15 +132,15 @@ class DAINPredictor(BasePredictor):
img_first = imread(first) img_first = imread(first)
img_second = imread(second) img_second = imread(second)
'''--------------Frame change test------------------------''' '''--------------Frame change test------------------------'''
img_first_gray = np.dot(img_first[..., :3], [0.299, 0.587, 0.114]) #img_first_gray = np.dot(img_first[..., :3], [0.299, 0.587, 0.114])
img_second_gray = np.dot(img_second[..., :3], [0.299, 0.587, 0.114]) #img_second_gray = np.dot(img_second[..., :3], [0.299, 0.587, 0.114])
img_first_gray = img_first_gray.flatten(order='C') #img_first_gray = img_first_gray.flatten(order='C')
img_second_gray = img_second_gray.flatten(order='C') #img_second_gray = img_second_gray.flatten(order='C')
corr = np.corrcoef(img_first_gray, img_second_gray)[0, 1] #corr = np.corrcoef(img_first_gray, img_second_gray)[0, 1]
key_frame = False #key_frame = False
if corr < self.key_frame_thread: #if corr < self.key_frame_thread:
key_frame = True # key_frame = True
'''-------------------------------------------------------''' '''-------------------------------------------------------'''
X0 = img_first.astype('float32').transpose((2, 0, 1)) / 255 X0 = img_first.astype('float32').transpose((2, 0, 1)) / 255
......
...@@ -22,7 +22,7 @@ from skimage import color ...@@ -22,7 +22,7 @@ from skimage import color
import paddle import paddle
from ppgan.models.generators.remaster import NetworkR, NetworkC from ppgan.models.generators.remaster import NetworkR, NetworkC
from paddle.utils.download import get_path_from_url from ppgan.utils.download import get_path_from_url
from .base_predictor import BasePredictor from .base_predictor import BasePredictor
DEEPREMASTER_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/deep_remaster.pdparams' DEEPREMASTER_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/deep_remaster.pdparams'
...@@ -77,8 +77,7 @@ class DeepRemasterPredictor(BasePredictor): ...@@ -77,8 +77,7 @@ class DeepRemasterPredictor(BasePredictor):
self.mindim = mindim self.mindim = mindim
if weight_path is None: if weight_path is None:
cur_path = os.path.abspath(os.path.dirname(__file__)) weight_path = get_path_from_url(DEEPREMASTER_WEIGHT_URL)
weight_path = get_path_from_url(DEEPREMASTER_WEIGHT_URL, cur_path)
self.weight_path = weight_path self.weight_path = weight_path
......
...@@ -20,7 +20,7 @@ from PIL import Image ...@@ -20,7 +20,7 @@ from PIL import Image
from tqdm import tqdm from tqdm import tqdm
import paddle import paddle
from paddle.utils.download import get_path_from_url from ppgan.utils.download import get_path_from_url
from ppgan.utils.video import frames2video, video2frames from ppgan.utils.video import frames2video, video2frames
from ppgan.models.generators.deoldify import build_model from ppgan.models.generators.deoldify import build_model
...@@ -36,8 +36,7 @@ class DeOldifyPredictor(BasePredictor): ...@@ -36,8 +36,7 @@ class DeOldifyPredictor(BasePredictor):
self.render_factor = render_factor self.render_factor = render_factor
self.model = build_model() self.model = build_model()
if weight_path is None: if weight_path is None:
cur_path = os.path.abspath(os.path.dirname(__file__)) weight_path = get_path_from_url(DEOLDIFY_WEIGHT_URL)
weight_path = get_path_from_url(DEOLDIFY_WEIGHT_URL, cur_path)
state_dict = paddle.load(weight_path) state_dict = paddle.load(weight_path)
self.model.load_dict(state_dict) self.model.load_dict(state_dict)
......
...@@ -19,7 +19,7 @@ import glob ...@@ -19,7 +19,7 @@ import glob
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
from paddle.utils.download import get_path_from_url from ppgan.utils.download import get_path_from_url
from ppgan.utils.video import frames2video, video2frames from ppgan.utils.video import frames2video, video2frames
from .base_predictor import BasePredictor from .base_predictor import BasePredictor
...@@ -138,8 +138,7 @@ class EDVRPredictor(BasePredictor): ...@@ -138,8 +138,7 @@ class EDVRPredictor(BasePredictor):
self.output = os.path.join(output, 'EDVR') self.output = os.path.join(output, 'EDVR')
if weight_path is None: if weight_path is None:
cur_path = os.path.abspath(os.path.dirname(__file__)) weight_path = get_path_from_url(EDVR_WEIGHT_URL)
weight_path = get_path_from_url(EDVR_WEIGHT_URL, cur_path)
self.weight_path = weight_path self.weight_path = weight_path
......
...@@ -25,7 +25,7 @@ from skimage.transform import resize ...@@ -25,7 +25,7 @@ from skimage.transform import resize
from scipy.spatial import ConvexHull from scipy.spatial import ConvexHull
import paddle import paddle
from paddle.utils.download import get_path_from_url from ppgan.utils.download import get_path_from_url
from ppgan.utils.animate import normalize_kp from ppgan.utils.animate import normalize_kp
from ppgan.modules.keypoint_detector import KPDetector from ppgan.modules.keypoint_detector import KPDetector
from ppgan.models.generators.occlusion_aware import OcclusionAwareGenerator from ppgan.models.generators.occlusion_aware import OcclusionAwareGenerator
...@@ -78,10 +78,11 @@ class FirstOrderPredictor(BasePredictor): ...@@ -78,10 +78,11 @@ class FirstOrderPredictor(BasePredictor):
} }
if weight_path is None: if weight_path is None:
vox_cpk_weight_url = 'https://paddlegan.bj.bcebos.com/applications/first_order_model/vox-cpk.pdparams' vox_cpk_weight_url = 'https://paddlegan.bj.bcebos.com/applications/first_order_model/vox-cpk.pdparams'
cur_path = os.path.abspath(os.path.dirname(__file__)) weight_path = get_path_from_url(vox_cpk_weight_url)
weight_path = get_path_from_url(vox_cpk_weight_url, cur_path)
self.weight_path = weight_path self.weight_path = weight_path
if not os.path.exists(output):
os.makedirs(output)
self.output = output self.output = output
self.relative = relative self.relative = relative
self.adapt_scale = adapt_scale self.adapt_scale = adapt_scale
......
...@@ -22,7 +22,7 @@ from tqdm import tqdm ...@@ -22,7 +22,7 @@ from tqdm import tqdm
import paddle import paddle
from ppgan.models.generators import RRDBNet from ppgan.models.generators import RRDBNet
from ppgan.utils.video import frames2video, video2frames from ppgan.utils.video import frames2video, video2frames
from paddle.utils.download import get_path_from_url from ppgan.utils.download import get_path_from_url
from .base_predictor import BasePredictor from .base_predictor import BasePredictor
REALSR_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DF2K_JPEG.pdparams' REALSR_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DF2K_JPEG.pdparams'
...@@ -34,8 +34,7 @@ class RealSRPredictor(BasePredictor): ...@@ -34,8 +34,7 @@ class RealSRPredictor(BasePredictor):
self.output = os.path.join(output, 'RealSR') self.output = os.path.join(output, 'RealSR')
self.model = RRDBNet(3, 3, 64, 23) self.model = RRDBNet(3, 3, 64, 23)
if weight_path is None: if weight_path is None:
cur_path = os.path.abspath(os.path.dirname(__file__)) weight_path = get_path_from_url(REALSR_WEIGHT_URL)
weight_path = get_path_from_url(REALSR_WEIGHT_URL, cur_path)
state_dict = paddle.load(weight_path) state_dict = paddle.load(weight_path)
self.model.load_dict(state_dict) self.model.load_dict(state_dict)
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# code was heavily based on https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix # code was heavily based on https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
import random import random
import numpy as np import numpy as np
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time import time
import paddle import paddle
import numbers import numbers
...@@ -59,14 +73,12 @@ class DictDataLoader(): ...@@ -59,14 +73,12 @@ class DictDataLoader():
place = paddle.CUDAPlace(ParallelEnv().dev_id) \ place = paddle.CUDAPlace(ParallelEnv().dev_id) \
if ParallelEnv().nranks > 1 else paddle.CUDAPlace(0) if ParallelEnv().nranks > 1 else paddle.CUDAPlace(0)
sampler = DistributedBatchSampler( sampler = DistributedBatchSampler(self.dataset,
self.dataset,
batch_size=batch_size, batch_size=batch_size,
shuffle=True if is_train else False, shuffle=True if is_train else False,
drop_last=True if is_train else False) drop_last=True if is_train else False)
self.dataloader = paddle.io.DataLoader( self.dataloader = paddle.io.DataLoader(self.dataset,
self.dataset,
batch_sampler=sampler, batch_sampler=sampler,
places=place, places=place,
num_workers=num_workers) num_workers=num_workers)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A modified image folder class """A modified image folder class
We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
...@@ -10,9 +23,20 @@ import os ...@@ -10,9 +23,20 @@ import os
import os.path import os.path
IMG_EXTENSIONS = [ IMG_EXTENSIONS = [
'.jpg', '.JPG', '.jpeg', '.JPEG', '.jpg',
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.JPG',
'.tif', '.TIF', '.tiff', '.TIFF', '.jpeg',
'.JPEG',
'.png',
'.PNG',
'.ppm',
'.PPM',
'.bmp',
'.BMP',
'.tif',
'.TIF',
'.tiff',
'.TIFF',
] ]
...@@ -38,12 +62,14 @@ def default_loader(path): ...@@ -38,12 +62,14 @@ def default_loader(path):
class ImageFolder(Dataset): class ImageFolder(Dataset):
def __init__(self,
def __init__(self, root, transform=None, return_paths=False, root,
transform=None,
return_paths=False,
loader=default_loader): loader=default_loader):
imgs = make_dataset(root) imgs = make_dataset(root)
if len(imgs) == 0: if len(imgs) == 0:
raise(RuntimeError("Found 0 images in: " + root + "\n" raise (RuntimeError("Found 0 images in: " + root + "\n"
"Supported image extensions are: " + "Supported image extensions are: " +
",".join(IMG_EXTENSIONS))) ",".join(IMG_EXTENSIONS)))
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2 import cv2
import paddle import paddle
import os.path import os.path
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2 import cv2
import paddle import paddle
from .base_dataset import BaseDataset, get_transform from .base_dataset import BaseDataset, get_transform
from .image_folder import make_dataset from .image_folder import make_dataset
from .builder import DATASETS from .builder import DATASETS
from .transforms.builder import build_transforms
@DATASETS.register() @DATASETS.register()
class SingleDataset(BaseDataset): class SingleDataset(BaseDataset):
""" """
""" """
def __init__(self, cfg): def __init__(self, cfg):
"""Initialize this dataset class. """Initialize this dataset class.
...@@ -20,7 +34,7 @@ class SingleDataset(BaseDataset): ...@@ -20,7 +34,7 @@ class SingleDataset(BaseDataset):
BaseDataset.__init__(self, cfg) BaseDataset.__init__(self, cfg)
self.A_paths = sorted(make_dataset(cfg.dataroot, cfg.max_dataset_size)) self.A_paths = sorted(make_dataset(cfg.dataroot, cfg.max_dataset_size))
input_nc = self.cfg.output_nc if self.cfg.direction == 'BtoA' else self.cfg.input_nc input_nc = self.cfg.output_nc if self.cfg.direction == 'BtoA' else self.cfg.input_nc
self.transform = get_transform(cfg.transform, grayscale=(input_nc == 1)) self.transform = build_transforms(self.cfg.transforms)
def __getitem__(self, index): def __getitem__(self, index):
"""Return a data point and its metadata information. """Return a data point and its metadata information.
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# import mmcv # import mmcv
import os import os
import cv2 import cv2
...@@ -33,13 +47,15 @@ def scandir(dir_path, suffix=None, recursive=False): ...@@ -33,13 +47,15 @@ def scandir(dir_path, suffix=None, recursive=False):
yield rel_path yield rel_path
else: else:
if recursive: if recursive:
yield from _scandir( yield from _scandir(entry.path,
entry.path, suffix=suffix, recursive=recursive) suffix=suffix,
recursive=recursive)
else: else:
continue continue
return _scandir(dir_path, suffix=suffix, recursive=recursive) return _scandir(dir_path, suffix=suffix, recursive=recursive)
def paired_paths_from_folder(folders, keys, filename_tmpl): def paired_paths_from_folder(folders, keys, filename_tmpl):
"""Generate paired paths from folders. """Generate paired paths from folders.
""" """
...@@ -70,6 +86,7 @@ def paired_paths_from_folder(folders, keys, filename_tmpl): ...@@ -70,6 +86,7 @@ def paired_paths_from_folder(folders, keys, filename_tmpl):
(f'{gt_key}_path', gt_path)])) (f'{gt_key}_path', gt_path)]))
return paths return paths
def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path): def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
"""Paired random crop. """Paired random crop.
...@@ -180,7 +197,6 @@ def augment(imgs, hflip=True, rotation=True, flows=None): ...@@ -180,7 +197,6 @@ def augment(imgs, hflip=True, rotation=True, flows=None):
@DATASETS.register() @DATASETS.register()
class SRImageDataset(Dataset): class SRImageDataset(Dataset):
"""Paired image dataset for image restoration.""" """Paired image dataset for image restoration."""
def __init__(self, cfg): def __init__(self, cfg):
super(SRImageDataset, self).__init__() super(SRImageDataset, self).__init__()
self.cfg = cfg self.cfg = cfg
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .transforms import PairedRandomCrop, PairedRandomHorizontalFlip from .transforms import PairedRandomCrop, PairedRandomHorizontalFlip
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy import copy
import traceback import traceback
import paddle import paddle
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys import sys
import random import random
import numbers import numbers
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2 import cv2
import random import random
import os.path import os.path
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import os
import time import time
import copy import copy
...@@ -55,11 +69,8 @@ class Trainer: ...@@ -55,11 +69,8 @@ class Trainer:
def distributed_data_parallel(self): def distributed_data_parallel(self):
strategy = paddle.distributed.prepare_context() strategy = paddle.distributed.prepare_context()
for name in self.model.model_names: for net_name, net in self.model.nets.items():
if isinstance(name, str): self.model.nets[net_name] = paddle.DataParallel(net, strategy)
net = getattr(self.model, 'net' + name)
setattr(self.model, 'net' + name,
paddle.DataParallel(net, strategy))
def train(self): def train(self):
reader_cost_averager = TimeAverager() reader_cost_averager = TimeAverager()
...@@ -77,9 +88,9 @@ class Trainer: ...@@ -77,9 +88,9 @@ class Trainer:
self.model.set_input(data) self.model.set_input(data)
self.model.optimize_parameters() self.model.optimize_parameters()
batch_cost_averager.record( batch_cost_averager.record(time.time() - step_start_time,
time.time() - step_start_time, num_samples=self.cfg.get(
num_samples=self.cfg.get('batch_size', 1)) 'batch_size', 1))
if i % self.log_interval == 0: if i % self.log_interval == 0:
self.data_time = reader_cost_averager.get_average() self.data_time = reader_cost_averager.get_average()
self.step_time = batch_cost_averager.get_average() self.step_time = batch_cost_averager.get_average()
...@@ -277,13 +288,13 @@ class Trainer: ...@@ -277,13 +288,13 @@ class Trainer:
self.start_epoch = state_dicts['epoch'] + 1 self.start_epoch = state_dicts['epoch'] + 1
for net_name, net in self.model.nets.items(): for net_name, net in self.model.nets.items():
net.set_dict(state_dicts[net_name]) net.set_state_dict(state_dicts[net_name])
for opt_name, opt in self.model.optimizers.items(): for opt_name, opt in self.model.optimizers.items():
opt.set_dict(state_dicts[opt_name]) opt.set_state_dict(state_dicts[opt_name])
def load(self, weight_path): def load(self, weight_path):
state_dicts = load(weight_path) state_dicts = load(weight_path)
for net_name, net in self.model.nets.items(): for net_name, net in self.model.nets.items():
net.set_dict(state_dicts[net_name]) net.set_state_dict(state_dicts[net_name])
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .dlib_utils import detect, crop, landmarks, crop_from_array from .dlib_utils import detect, crop, landmarks, crop_from_array
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np import numpy as np
import cv2 import cv2
from io import BytesIO from io import BytesIO
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp import os.path as osp
import numpy as np import numpy as np
...@@ -5,9 +19,12 @@ import cv2 ...@@ -5,9 +19,12 @@ import cv2
from PIL import Image from PIL import Image
import paddle import paddle
import paddle.vision.transforms as T import paddle.vision.transforms as T
from paddle.utils.download import get_path_from_url
import pickle import pickle
from .model import BiSeNet from .model import BiSeNet
BISENET_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/bisnet.pdparams'
class FaceParser: class FaceParser:
def __init__(self, device="cpu"): def __init__(self, device="cpu"):
...@@ -33,8 +50,8 @@ class FaceParser: ...@@ -33,8 +50,8 @@ class FaceParser:
18: 0 18: 0
} }
#self.dict = paddle.to_tensor(mapper) #self.dict = paddle.to_tensor(mapper)
self.save_pth = osp.split( self.save_pth = get_path_from_url(BISENET_WEIGHT_URL,
osp.realpath(__file__))[0] + '/resnet.pdparams' osp.split(osp.realpath(__file__))[0])
self.net = BiSeNet(n_classes=19) self.net = BiSeNet(n_classes=19)
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. #Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
#Licensed under the Apache License, Version 2.0 (the "License"); #Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License. #you may not use this file except in compliance with the License.
...@@ -26,6 +26,7 @@ from paddle.fluid.dygraph.base import to_variable ...@@ -26,6 +26,7 @@ from paddle.fluid.dygraph.base import to_variable
try: try:
from tqdm import tqdm from tqdm import tqdm
except: except:
def tqdm(x): def tqdm(x):
return x return x
...@@ -131,7 +132,13 @@ def calculate_fid_given_img(img_fake, ...@@ -131,7 +132,13 @@ def calculate_fid_given_img(img_fake,
return fid_value return fid_value
def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path, style=None): def _get_activations(files,
model,
batch_size,
dims,
use_gpu,
premodel_path,
style=None):
if len(files) % batch_size != 0: if len(files) % batch_size != 0:
print(('Warning: number of images is not a multiple of the ' print(('Warning: number of images is not a multiple of the '
'batch size. Some samples are going to be ignored.')) 'batch size. Some samples are going to be ignored.'))
...@@ -159,8 +166,7 @@ def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path, sty ...@@ -159,8 +166,7 @@ def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path, sty
img_list.append(np.array(im).astype('float32')) img_list.append(np.array(im).astype('float32'))
images = np.array( images = np.array(img_list)
img_list)
else: else:
images = np.array( images = np.array(
[imread(str(f)).astype(np.float32) for f in files[start:end]]) [imread(str(f)).astype(np.float32) for f in files[start:end]])
...@@ -179,7 +185,7 @@ def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path, sty ...@@ -179,7 +185,7 @@ def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path, sty
std = np.array([0.229, 0.224, 0.225]).astype('float32') std = np.array([0.229, 0.224, 0.225]).astype('float32')
images[:] = (images[:] - mean[:, None, None]) / std[:, None, None] images[:] = (images[:] - mean[:, None, None]) / std[:, None, None]
if style=='stargan': if style == 'stargan':
pred_arr[start:end] = inception_infer(images, premodel_path) pred_arr[start:end] = inception_infer(images, premodel_path)
else: else:
with fluid.dygraph.guard(): with fluid.dygraph.guard():
...@@ -197,7 +203,8 @@ def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path, sty ...@@ -197,7 +203,8 @@ def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path, sty
def inception_infer(x, model_path): def inception_infer(x, model_path):
exe = fluid.Executor() exe = fluid.Executor()
[inference_program, feed_target_names, fetch_targets] = fluid.io.load_inference_model(model_path, exe) [inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(model_path, exe)
results = exe.run(inference_program, results = exe.run(inference_program,
feed={feed_target_names[0]: x}, feed={feed_target_names[0]: x},
fetch_list=fetch_targets) fetch_list=fetch_targets)
...@@ -210,7 +217,7 @@ def _calculate_activation_statistics(files, ...@@ -210,7 +217,7 @@ def _calculate_activation_statistics(files,
batch_size=50, batch_size=50,
dims=2048, dims=2048,
use_gpu=False, use_gpu=False,
style = None): style=None):
act = _get_activations(files, model, batch_size, dims, use_gpu, act = _get_activations(files, model, batch_size, dims, use_gpu,
premodel_path, style) premodel_path, style)
mu = np.mean(act, axis=0) mu = np.mean(act, axis=0)
...@@ -218,8 +225,13 @@ def _calculate_activation_statistics(files, ...@@ -218,8 +225,13 @@ def _calculate_activation_statistics(files,
return mu, sigma return mu, sigma
def _compute_statistics_of_path(path, model, batch_size, dims, use_gpu, def _compute_statistics_of_path(path,
premodel_path, style=None): model,
batch_size,
dims,
use_gpu,
premodel_path,
style=None):
if path.endswith('.npz'): if path.endswith('.npz'):
f = np.load(path) f = np.load(path)
m, s = f['mu'][:], f['sigma'][:] m, s = f['mu'][:], f['sigma'][:]
...@@ -231,7 +243,8 @@ def _compute_statistics_of_path(path, model, batch_size, dims, use_gpu, ...@@ -231,7 +243,8 @@ def _compute_statistics_of_path(path, model, batch_size, dims, use_gpu,
filenames, '*.jpg') or fnmatch.filter(filenames, '*.png'): filenames, '*.jpg') or fnmatch.filter(filenames, '*.png'):
files.append(os.path.join(root, filename)) files.append(os.path.join(root, filename))
m, s = _calculate_activation_statistics(files, model, premodel_path, m, s = _calculate_activation_statistics(files, model, premodel_path,
batch_size, dims, use_gpu, style) batch_size, dims, use_gpu,
style)
return m, s return m, s
...@@ -241,7 +254,7 @@ def calculate_fid_given_paths(paths, ...@@ -241,7 +254,7 @@ def calculate_fid_given_paths(paths,
use_gpu, use_gpu,
dims, dims,
model=None, model=None,
style = None): style=None):
assert os.path.exists( assert os.path.exists(
premodel_path premodel_path
), 'pretrain_model path {} is not exists! Please download it first'.format( ), 'pretrain_model path {} is not exists! Please download it first'.format(
......
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. #Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
#Licensed under the Apache License, Version 2.0 (the "License"); #Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License. #you may not use this file except in compliance with the License.
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np import numpy as np
...@@ -29,6 +43,7 @@ def reorder_image(img, input_order='HWC'): ...@@ -29,6 +43,7 @@ def reorder_image(img, input_order='HWC'):
img = img.transpose(1, 2, 0) img = img.transpose(1, 2, 0)
return img return img
def bgr2ycbcr(img, y_only=False): def bgr2ycbcr(img, y_only=False):
"""Convert a BGR image to YCbCr image. """Convert a BGR image to YCbCr image.
...@@ -52,16 +67,17 @@ def bgr2ycbcr(img, y_only=False): ...@@ -52,16 +67,17 @@ def bgr2ycbcr(img, y_only=False):
and range as input image. and range as input image.
""" """
img_type = img.dtype img_type = img.dtype
img = _convert_input_type_range(img) #img = _convert_input_type_range(img)
if y_only: if y_only:
out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
else: else:
out_img = np.matmul( out_img = np.matmul(
img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
[65.481, -37.797, 112.0]]) + [16, 128, 128] [65.481, -37.797, 112.0]]) + [16, 128, 128]
out_img = _convert_output_type_range(out_img, img_type) #out_img = _convert_output_type_range(out_img, img_type)
return out_img return out_img
def to_y_channel(img): def to_y_channel(img):
"""Change to Y channel of YCbCr. """Change to Y channel of YCbCr.
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2 import cv2
import numpy as np import numpy as np
......
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. #Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
#Licensed under the Apache License, Version 2.0 (the "License"); #Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License. #you may not use this file except in compliance with the License.
...@@ -38,7 +38,8 @@ def parse_args(): ...@@ -38,7 +38,8 @@ def parse_args():
type=int, type=int,
default=1, default=1,
help='sample number in a batch for inference.') help='sample number in a batch for inference.')
parser.add_argument('--style', parser.add_argument(
'--style',
type=str, type=str,
help='calculation style: stargan or default (gan-compression style)') help='calculation style: stargan or default (gan-compression style)')
args = parser.parse_args() args = parser.parse_args()
...@@ -53,8 +54,12 @@ def main(): ...@@ -53,8 +54,12 @@ def main():
inference_model_path = args.inference_model inference_model_path = args.inference_model
batch_size = args.batch_size batch_size = args.batch_size
fid_value = calculate_fid_given_paths(paths, inference_model_path, fid_value = calculate_fid_given_paths(paths,
batch_size, args.use_gpu, 2048, style=args.style) inference_model_path,
batch_size,
args.use_gpu,
2048,
style=args.style)
print('FID: ', fid_value) print('FID: ', fid_value)
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .resnet_backbone import resnet18, resnet34, resnet50, resnet101, resnet152 from .resnet_backbone import resnet18, resnet34, resnet50, resnet101, resnet152
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# code was heavily based on https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix # code was heavily based on https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
import os import os
import paddle import paddle
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle import paddle
from ..utils.registry import Registry from ..utils.registry import Registry
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle import paddle
from paddle.distributed import ParallelEnv
from .base_model import BaseModel from .base_model import BaseModel
from .builder import MODELS from .builder import MODELS
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .nlayers import NLayerDiscriminator from .nlayers import NLayerDiscriminator
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy import copy
from ...utils.registry import Registry from ...utils.registry import Registry
DISCRIMINATORS = Registry("DISCRIMINATOR") DISCRIMINATORS = Registry("DISCRIMINATOR")
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy import copy
from ...utils.registry import Registry from ...utils.registry import Registry
GENERATORS = Registry("GENERATOR") GENERATORS = Registry("GENERATOR")
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np import numpy as np
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
...@@ -273,7 +287,7 @@ class PixelShuffle_ICNR(nn.Layer): ...@@ -273,7 +287,7 @@ class PixelShuffle_ICNR(nn.Layer):
self.shuf = PixelShuffle(scale) self.shuf = PixelShuffle(scale)
self.pad = ReplicationPad2d([1, 0, 1, 0]) self.pad = ReplicationPad2d([1, 0, 1, 0])
self.blur = nn.Pool2D(2, pool_stride=1, pool_type='avg') self.blur = nn.AvgPool2D(2, stride=1)
self.relu = relu(True, leaky=leaky) self.relu = relu(True, leaky=leaky)
def forward(self, x): def forward(self, x):
...@@ -339,7 +353,7 @@ class CustomPixelShuffle_ICNR(nn.Layer): ...@@ -339,7 +353,7 @@ class CustomPixelShuffle_ICNR(nn.Layer):
self.shuf = PixelShuffle(scale) self.shuf = PixelShuffle(scale)
self.pad = ReplicationPad2d([1, 0, 1, 0]) self.pad = ReplicationPad2d([1, 0, 1, 0])
self.blur = nn.Pool2D(2, pool_stride=1, pool_type='avg') self.blur = nn.AvgPool2D(2, stride=1)
self.relu = nn.LeakyReLU( self.relu = nn.LeakyReLU(
leaky) if leaky is not None else nn.ReLU() #relu(True, leaky=leaky) leaky) if leaky is not None else nn.ReLU() #relu(True, leaky=leaky)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np import numpy as np
import paddle import paddle
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle import paddle
from paddle import nn from paddle import nn
import paddle.nn.functional as F import paddle.nn.functional as F
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import functools import functools
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools import functools
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools import functools
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np import numpy as np
import paddle import paddle
...@@ -10,7 +24,6 @@ class GANLoss(nn.Layer): ...@@ -10,7 +24,6 @@ class GANLoss(nn.Layer):
The GANLoss class abstracts away the need to create the target label tensor The GANLoss class abstracts away the need to create the target label tensor
that has the same size as the input. that has the same size as the input.
""" """
def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0): def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
""" Initialize the GANLoss class. """ Initialize the GANLoss class.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -17,6 +17,7 @@ import paddle ...@@ -17,6 +17,7 @@ import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle.vision.models import vgg16 from paddle.vision.models import vgg16
from paddle.utils.download import get_path_from_url
from .base_model import BaseModel from .base_model import BaseModel
from .builder import MODELS from .builder import MODELS
...@@ -29,6 +30,8 @@ from ..utils.image_pool import ImagePool ...@@ -29,6 +30,8 @@ from ..utils.image_pool import ImagePool
from ..utils.preprocess import * from ..utils.preprocess import *
from ..datasets.makeup_dataset import MakeupDataset from ..datasets.makeup_dataset import MakeupDataset
VGGFACE_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/vggface.pdparams'
@MODELS.register() @MODELS.register()
class MakeupModel(BaseModel): class MakeupModel(BaseModel):
...@@ -50,8 +53,13 @@ class MakeupModel(BaseModel): ...@@ -50,8 +53,13 @@ class MakeupModel(BaseModel):
init_weights(self.nets['netG'], init_type='xavier', init_gain=1.0) init_weights(self.nets['netG'], init_type='xavier', init_gain=1.0)
if self.is_train: # define discriminators if self.is_train: # define discriminators
vgg = vgg16(pretrained=True) vgg = vgg16(pretrained=False)
self.vgg = vgg.features self.vgg = vgg.features
cur_path = os.path.abspath(os.path.dirname(__file__))
vgg_weight_path = get_path_from_url(VGGFACE_WEIGHT_URL, cur_path)
param = paddle.load(vgg_weight_path)
vgg.load_dict(param)
self.nets['netD_A'] = build_discriminator(cfg.model.discriminator) self.nets['netD_A'] = build_discriminator(cfg.model.discriminator)
self.nets['netD_B'] = build_discriminator(cfg.model.discriminator) self.nets['netD_B'] = build_discriminator(cfg.model.discriminator)
init_weights(self.nets['netD_A'], init_type='xavier', init_gain=1.0) init_weights(self.nets['netD_A'], init_type='xavier', init_gain=1.0)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle import paddle
from paddle.distributed import ParallelEnv
from .base_model import BaseModel from .base_model import BaseModel
from .builder import MODELS from .builder import MODELS
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict from collections import OrderedDict
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict from collections import OrderedDict
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math import math
import numpy as np import numpy as np
...@@ -256,8 +270,10 @@ def kaiming_init(layer, ...@@ -256,8 +270,10 @@ def kaiming_init(layer,
distribution='normal'): distribution='normal'):
assert distribution in ['uniform', 'normal'] assert distribution in ['uniform', 'normal']
if distribution == 'uniform': if distribution == 'uniform':
kaiming_uniform_( kaiming_uniform_(layer.weight,
layer.weight, a=a, mode=mode, nonlinearity=nonlinearity) a=a,
mode=mode,
nonlinearity=nonlinearity)
else: else:
kaiming_normal_(layer.weight, a=a, mode=mode, nonlinearity=nonlinearity) kaiming_normal_(layer.weight, a=a, mode=mode, nonlinearity=nonlinearity)
if hasattr(layer, 'bias') and layer.bias is not None: if hasattr(layer, 'bias') and layer.bias is not None:
...@@ -273,7 +289,6 @@ def init_weights(net, init_type='normal', init_gain=0.02): ...@@ -273,7 +289,6 @@ def init_weights(net, init_type='normal', init_gain=0.02):
We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
work better for some applications. Feel free to try yourself. work better for some applications. Feel free to try yourself.
""" """
def init_func(m): # define the initialization function def init_func(m): # define the initialization function
classname = m.__class__.__name__ classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 if hasattr(m, 'weight') and (classname.find('Conv') != -1
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle import paddle
import functools import functools
import paddle.nn as nn import paddle.nn as nn
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .optimizer import build_optimizer from .optimizer import build_optimizer
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle import paddle
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy import copy
import paddle import paddle
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np import numpy as np
from scipy.spatial import ConvexHull from scipy.spatial import ConvexHull
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
import os import os
import yaml import yaml
__all__ = ['get_config'] __all__ = ['get_config']
...@@ -65,7 +64,6 @@ def override(dl, ks, v): ...@@ -65,7 +64,6 @@ def override(dl, ks, v):
ks(list): list of keys ks(list): list of keys
v(str): value to be replaced v(str): value to be replaced
""" """
def str2num(v): def str2num(v):
try: try:
return eval(v) return eval(v)
...@@ -104,8 +102,8 @@ def override_config(config, options=None): ...@@ -104,8 +102,8 @@ def override_config(config, options=None):
""" """
if options is not None: if options is not None:
for opt in options: for opt in options:
assert isinstance(opt, str), ( assert isinstance(opt,
"option({}) should be a str".format(opt)) str), ("option({}) should be a str".format(opt))
assert "=" in opt, ( assert "=" in opt, (
"option({}) should contain a =" "option({}) should contain a ="
"to distinguish between key and value".format(opt)) "to distinguish between key and value".format(opt))
...@@ -122,8 +120,7 @@ def get_config(fname, overrides=None, show=True): ...@@ -122,8 +120,7 @@ def get_config(fname, overrides=None, show=True):
""" """
Read config from file Read config from file
""" """
assert os.path.exists(fname), ( assert os.path.exists(fname), ('config file({}) is not exist'.format(fname))
'config file({}) is not exist'.format(fname))
config = parse_config(fname) config = parse_config(fname)
override_config(config, overrides) override_config(config, overrides)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import os.path as osp
import shutil
import requests
import hashlib
import tarfile
import zipfile
import time
from tqdm import tqdm
import logging
from .logger import get_logger
logger = get_logger('ppgan')
PPGAN_HOME = os.path.expanduser("~/.cache/ppgan/")
DOWNLOAD_RETRY_LIMIT = 3
def is_url(path):
"""
Whether path is URL.
Args:
path (string): URL string or not.
"""
return path.startswith('http://') or path.startswith('https://')
def _map_path(url, root_dir):
# parse path after download under root_dir
fname = osp.split(url)[-1]
fpath = fname
return osp.join(root_dir, fpath)
def get_path_from_url(url, md5sum=None, check_exist=True):
""" Download from given url to root_dir.
if file or directory specified by url is exists under
root_dir, return the path directly, otherwise download
from url and decompress it, return the path.
Args:
url (str): download url
md5sum (str): md5 sum of download package
Returns:
str: a local path to save downloaded models & weights & datasets.
"""
from paddle.fluid.dygraph.parallel import ParallelEnv
assert is_url(url), "downloading from {} not a url".format(url)
root_dir = PPGAN_HOME
# parse path after download to decompress under root_dir
fullpath = _map_path(url, root_dir)
if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
logger.info("Found {}".format(fullpath))
else:
if ParallelEnv().local_rank == 0:
fullpath = _download(url, root_dir, md5sum)
else:
while not os.path.exists(fullpath):
time.sleep(1)
if ParallelEnv().local_rank == 0:
if tarfile.is_tarfile(fullpath) or zipfile.is_zipfile(fullpath):
fullpath = _decompress(fullpath)
return fullpath
def _download(url, path, md5sum=None):
"""
Download from url, save to path.
url (str): download url
path (str): download to given path
"""
if not osp.exists(path):
os.makedirs(path)
fname = osp.split(url)[-1]
fullname = osp.join(path, fname)
retry_cnt = 0
while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
if retry_cnt < DOWNLOAD_RETRY_LIMIT:
retry_cnt += 1
else:
raise RuntimeError("Download from {} failed. "
"Retry limit reached".format(url))
logger.info("Downloading {} from {} to {}".format(fname, url, fullname))
req = requests.get(url, stream=True)
if req.status_code != 200:
raise RuntimeError("Downloading from {} failed with code "
"{}!".format(url, req.status_code))
# For protecting download interupted, download to
# tmp_fullname firstly, move tmp_fullname to fullname
# after download finished
tmp_fullname = fullname + "_tmp"
total_size = req.headers.get('content-length')
with open(tmp_fullname, 'wb') as f:
if total_size:
with tqdm(total=(int(total_size) + 1023) // 1024) as pbar:
for chunk in req.iter_content(chunk_size=1024):
f.write(chunk)
pbar.update(1)
else:
for chunk in req.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
shutil.move(tmp_fullname, fullname)
return fullname
def _md5check(fullname, md5sum=None):
if md5sum is None:
return True
logger.info("File {} md5 checking...".format(fullname))
md5 = hashlib.md5()
with open(fullname, 'rb') as f:
for chunk in iter(lambda: f.read(4096), b""):
md5.update(chunk)
calc_md5sum = md5.hexdigest()
if calc_md5sum != md5sum:
logger.info("File {} md5 check failed, {}(calc) != "
"{}(base)".format(fullname, calc_md5sum, md5sum))
return False
return True
def _decompress(fname):
"""
Decompress for zip and tar file
"""
logger.info("Decompressing {}...".format(fname))
# For protecting decompressing interupted,
# decompress to fpath_tmp directory firstly, if decompress
# successed, move decompress files to fpath and delete
# fpath_tmp and remove download compress file.
if tarfile.is_tarfile(fname):
uncompressed_path = _uncompress_file_tar(fname)
elif zipfile.is_zipfile(fname):
uncompressed_path = _uncompress_file_zip(fname)
else:
raise TypeError("Unsupport compress file type {}".format(fname))
return uncompressed_path
def _uncompress_file_zip(filepath):
files = zipfile.ZipFile(filepath, 'r')
file_list = files.namelist()
file_dir = os.path.dirname(filepath)
if _is_a_single_file(file_list):
rootpath = file_list[0]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
elif _is_a_single_dir(file_list):
rootpath = os.path.splitext(file_list[0])[0].strip(os.sep).split(
os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
else:
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
if not os.path.exists(uncompressed_path):
os.makedirs(uncompressed_path)
for item in file_list:
files.extract(item, os.path.join(file_dir, rootpath))
files.close()
return uncompressed_path
def _uncompress_file_tar(filepath, mode="r:*"):
files = tarfile.open(filepath, mode)
file_list = files.getnames()
file_dir = os.path.dirname(filepath)
if _is_a_single_file(file_list):
rootpath = file_list[0]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
elif _is_a_single_dir(file_list):
rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
else:
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
if not os.path.exists(uncompressed_path):
os.makedirs(uncompressed_path)
for item in file_list:
files.extract(item, os.path.join(file_dir, rootpath))
files.close()
return uncompressed_path
def _is_a_single_file(file_list):
if len(file_list) == 1 and file_list[0].find(os.sep) < -1:
return True
return False
def _is_a_single_dir(file_list):
file_name = file_list[0].split(os.sep)[0]
for i in range(1, len(file_list)):
if file_name != file_list[i].split(os.sep)[0]:
return False
return True
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import os
import six import six
import pickle import pickle
...@@ -6,7 +20,11 @@ import paddle ...@@ -6,7 +20,11 @@ import paddle
def makedirs(dir): def makedirs(dir):
if not os.path.exists(dir): if not os.path.exists(dir):
# avoid error when train with multiple gpus
try:
os.makedirs(dir) os.makedirs(dir)
except:
pass
def save(state_dicts, file_name): def save(state_dicts, file_name):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random import random
import paddle import paddle
...@@ -8,7 +22,6 @@ class ImagePool(): ...@@ -8,7 +22,6 @@ class ImagePool():
This buffer enables us to update discriminators using a history of generated images This buffer enables us to update discriminators using a history of generated images
rather than the ones produced by the latest generators. rather than the ones produced by the latest generators.
""" """
def __init__(self, pool_size): def __init__(self, pool_size):
"""Initialize the ImagePool class """Initialize the ImagePool class
...@@ -44,7 +57,8 @@ class ImagePool(): ...@@ -44,7 +57,8 @@ class ImagePool():
else: else:
p = random.uniform(0, 1) p = random.uniform(0, 1)
if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer
random_id = random.randint(0, self.pool_size - 1) # randint is inclusive random_id = random.randint(0, self.pool_size -
1) # randint is inclusive
# FIXME: clone # FIXME: clone
# tmp = (self.images[random_id]).detach() #.clone() # tmp = (self.images[random_id]).detach() #.clone()
tmp = self.images[random_id] #.clone() tmp = self.images[random_id] #.clone()
...@@ -52,5 +66,6 @@ class ImagePool(): ...@@ -52,5 +66,6 @@ class ImagePool():
return_images.append(tmp) return_images.append(tmp)
else: # by another 50% chance, the buffer will return the current image else: # by another 50% chance, the buffer will return the current image
return_images.append(image) return_images.append(image)
return_images = paddle.concat(return_images, 0) # collect all the images and return return_images = paddle.concat(return_images,
0) # collect all the images and return
return return_images return return_images
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging import logging
import os import os
import sys import sys
from paddle.distributed import ParallelEnv from paddle.distributed import ParallelEnv
logger_initialized = {}
def setup_logger(output=None, name="ppgan"): def setup_logger(output=None, name="ppgan"):
""" """
...@@ -53,3 +69,11 @@ def setup_logger(output=None, name="ppgan"): ...@@ -53,3 +69,11 @@ def setup_logger(output=None, name="ppgan"):
logger.addHandler(fh) logger.addHandler(fh)
return logger return logger
def get_logger(name, output=None):
logger = logging.getLogger(name)
if name in logger_initialized:
return logger
return setup_logger(name=name, output=name)
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class Registry(object): class Registry(object):
""" """
The registry that provides name -> object mapping, to support third-party users' custom modules. The registry that provides name -> object mapping, to support third-party users' custom modules.
...@@ -13,7 +28,6 @@ class Registry(object): ...@@ -13,7 +28,6 @@ class Registry(object):
.. code-block:: python .. code-block:: python
BACKBONE_REGISTRY.register(MyBackbone) BACKBONE_REGISTRY.register(MyBackbone)
""" """
def __init__(self, name): def __init__(self, name):
""" """
Args: Args:
...@@ -26,7 +40,8 @@ class Registry(object): ...@@ -26,7 +40,8 @@ class Registry(object):
def _do_register(self, name, obj): def _do_register(self, name, obj):
assert ( assert (
name not in self._obj_map name not in self._obj_map
), "An object named '{}' was already registered in '{}' registry!".format(name, self._name) ), "An object named '{}' was already registered in '{}' registry!".format(
name, self._name)
self._obj_map[name] = obj self._obj_map[name] = obj
def register(self, obj=None, name=None): def register(self, obj=None, name=None):
...@@ -52,6 +67,8 @@ class Registry(object): ...@@ -52,6 +67,8 @@ class Registry(object):
def get(self, name): def get(self, name):
ret = self._obj_map.get(name) ret = self._obj_map.get(name)
if ret is None: if ret is None:
raise KeyError("No object named '{}' found in '{}' registry!".format(name, self._name)) raise KeyError(
"No object named '{}' found in '{}' registry!".format(
name, self._name))
return ret return ret
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import os
import time import time
import paddle import paddle
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import os
import sys import sys
def video2frames(video_path, outpath, **kargs): def video2frames(video_path, outpath, **kargs):
def _dict2str(kargs): def _dict2str(kargs):
cmd_str = '' cmd_str = ''
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np import numpy as np
from PIL import Image from PIL import Image
...@@ -17,7 +31,9 @@ def tensor2img(input_image, min_max=(-1., 1.), imtype=np.uint8): ...@@ -17,7 +31,9 @@ def tensor2img(input_image, min_max=(-1., 1.), imtype=np.uint8):
image_numpy = np.tile(image_numpy, (3, 1, 1)) image_numpy = np.tile(image_numpy, (3, 1, 1))
image_numpy = image_numpy.clip(min_max[0], min_max[1]) image_numpy = image_numpy.clip(min_max[0], min_max[1])
image_numpy = (image_numpy - min_max[0]) / (min_max[1] - min_max[0]) image_numpy = (image_numpy - min_max[0]) / (min_max[1] - min_max[0])
image_numpy = (np.transpose(image_numpy, (1, 2, 0))) * 255.0 # post-processing: tranpose and scaling image_numpy = (np.transpose(
image_numpy,
(1, 2, 0))) * 255.0 # post-processing: tranpose and scaling
else: # if it is a numpy array, do nothing else: # if it is a numpy array, do nothing
image_numpy = input_image image_numpy = input_image
return image_numpy.astype(imtype) return image_numpy.astype(imtype)
......
FILE=$1
URL=https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/$FILE.tar.gz
TAR_FILE=./$FILE.tar.gz
TARGET_DIR=./$FILE/
wget -N $URL -O $TAR_FILE --no-check-certificate
mkdir $TARGET_DIR
tar -zxvf $TAR_FILE -C ../data/
rm $TAR_FILE
rm -rf $TARGET_DIR
...@@ -36,9 +36,7 @@ setup( ...@@ -36,9 +36,7 @@ setup(
description='Awesome GAN toolkits based on PaddlePaddle', description='Awesome GAN toolkits based on PaddlePaddle',
url='https://github.com/PaddlePaddle/PaddleGAN', url='https://github.com/PaddlePaddle/PaddleGAN',
download_url='https://github.com/PaddlePaddle/PaddleGAN.git', download_url='https://github.com/PaddlePaddle/PaddleGAN.git',
keywords=[ keywords=['gan paddlegan'],
'gan paddlegan'
],
classifiers=[ classifiers=[
'Intended Audience :: Developers', 'Operating System :: OS Independent', 'Intended Audience :: Developers', 'Operating System :: OS Independent',
'Natural Language :: Chinese (Simplified)', 'Natural Language :: Chinese (Simplified)',
...@@ -46,4 +44,5 @@ setup( ...@@ -46,4 +44,5 @@ setup(
'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7', 'Topic :: Utilities' 'Programming Language :: Python :: 3.7', 'Topic :: Utilities'
], ) ],
\ No newline at end of file )
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册